blob: 6a2ab2ff3c3c09487bfcae50d8e2cf936c970063 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
''' Unit tests for Gluon Batch Processor '''
import sys
import unittest
import warnings
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.contrib.estimator import *
from mxnet.gluon.contrib.estimator.event_handler import *
from mxnet.gluon.contrib.estimator.batch_processor import BatchProcessor
import pytest
# Presumably deactivates global NumPy shape/array semantics so that importing this
# module starts from MXNet's default mode; each test below opts back in explicitly
# via the @mx.util.use_np decorator.
mx.npx.reset_np()
def _get_test_network():
    '''Build the small model used as the network under test.

    Returns an ``nn.Sequential`` holding a single ``Dense`` layer with 4
    output units, ReLU activation and ``flatten=False``. Parameters are left
    uninitialized; callers initialize them on the device they need.
    '''
    dense = nn.Dense(4, activation='relu', flatten=False)
    model = nn.Sequential()
    model.add(dense)
    return model
def _get_test_data():
    '''Create a DataLoader and an NDArrayIter over the same random samples.

    Inputs have shape (10, 3) and targets have shape (10, 4), both drawn with
    ``mx.np.random.uniform``; both containers serve them in batches of 4.

    Returns
    -------
    tuple
        ``(dataloader, dataiter)``: a Gluon ``DataLoader`` and an
        ``mx.io.NDArrayIter`` wrapping identical arrays.
    '''
    batch_size = 4
    features = mx.np.random.uniform(size=(10, 3))
    targets = mx.np.random.uniform(size=(10, 4))
    # Gluon DataLoader over (features, targets) pairs
    pairs = gluon.data.dataset.ArrayDataset(features, targets)
    loader = gluon.data.DataLoader(pairs, batch_size=batch_size)
    # Legacy iterator over the very same arrays
    iterator = mx.io.NDArrayIter(data=features, label=targets, batch_size=batch_size)
    return loader, iterator
@mx.util.use_np
def test_batch_processor_fit():
    '''Test Estimator.fit with a BatchProcessor on different train data types.

    Training on a Gluon DataLoader must succeed; passing an
    ``mx.io.NDArrayIter`` or a plain list of arrays as ``train_data`` is
    expected to raise ValueError.
    '''
    net = _get_test_network()
    dataloader, dataiter = _get_test_data()
    num_epochs = 1
    device = mx.cpu()
    loss = gluon.loss.L2Loss()
    acc = mx.gluon.metric.Accuracy()
    net.initialize(device=device)
    processor = BatchProcessor()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=net,
                    loss=loss,
                    train_metrics=acc,
                    trainer=trainer,
                    device=device,
                    batch_processor=processor)
    # Input dataloader: the supported type, must train without error.
    est.fit(train_data=dataloader,
            epochs=num_epochs)
    # Input DataIter: not supported.
    with pytest.raises(ValueError):
        est.fit(train_data=dataiter,
                epochs=num_epochs)
    # Input list of arrays: not supported. mx.np (not legacy mx.nd) keeps the
    # arrays consistent with the use_np fixtures from _get_test_data.
    with pytest.raises(ValueError):
        est.fit(train_data=[mx.np.ones(shape=(10, 3))],
                epochs=num_epochs)
@mx.util.use_np
def test_batch_processor_validation():
    '''Test Estimator.fit with a BatchProcessor on different validation data types.

    Training with a Gluon DataLoader for both ``train_data`` and ``val_data``
    must succeed; passing an ``mx.io.NDArrayIter`` or a plain list of arrays
    is expected to raise ValueError.
    '''
    net = _get_test_network()
    dataloader, dataiter = _get_test_data()
    num_epochs = 1
    device = mx.cpu()
    loss = gluon.loss.L2Loss()
    acc = mx.gluon.metric.Accuracy()
    val_loss = gluon.loss.L1Loss()
    net.initialize(device=device)
    processor = BatchProcessor()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=net,
                    loss=loss,
                    train_metrics=acc,
                    trainer=trainer,
                    device=device,
                    val_loss=val_loss,
                    batch_processor=processor)
    # Input dataloader: supported for both training and validation data.
    est.fit(train_data=dataloader,
            val_data=dataloader,
            epochs=num_epochs)
    # TODO(review): a ValidationHandler passed explicitly via ``event_handlers``
    # is not exercised here; add a dedicated case if that path needs coverage.
    # Input DataIter: not supported.
    with pytest.raises(ValueError):
        est.fit(train_data=dataiter,
                val_data=dataiter,
                epochs=num_epochs)
    # Input list of arrays: not supported. mx.np (not legacy mx.nd) keeps the
    # arrays consistent with the use_np fixtures from _get_test_data.
    with pytest.raises(ValueError):
        est.fit(train_data=[mx.np.ones(shape=(10, 3))],
                val_data=[mx.np.ones(shape=(10, 3))],
                epochs=num_epochs)