| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| ''' Unit tests for Gluon Estimator ''' |
| |
| import sys |
| import unittest |
| |
| import mxnet as mx |
| from mxnet import gluon |
| from mxnet.gluon import nn |
| from mxnet.gluon.contrib.estimator import * |
| from nose.tools import assert_raises |
| |
| |
| def _get_test_network(): |
| net = nn.Sequential() |
| net.add(nn.Dense(4, activation='relu', flatten=False)) |
| return net |
| |
| |
| def _get_test_data(): |
| batch_size = 4 |
| in_data = mx.nd.random.uniform(shape=(10, 3)) |
| out_data = mx.nd.random.uniform(shape=(10, 4)) |
| # Input dataloader |
| dataset = gluon.data.dataset.ArrayDataset(in_data, out_data) |
| dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size) |
| dataiter = mx.io.NDArrayIter(data=in_data, label=out_data, batch_size=batch_size) |
| return dataloader, dataiter |
| |
| |
| def test_fit(): |
| ''' test estimator with different train data types ''' |
| net = _get_test_network() |
| dataloader, dataiter = _get_test_data() |
| num_epochs = 1 |
| ctx = mx.cpu() |
| loss = gluon.loss.L2Loss() |
| acc = mx.metric.Accuracy() |
| net.initialize(ctx=ctx) |
| trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=acc, |
| trainer=trainer, |
| context=ctx) |
| |
| est.fit(train_data=dataloader, |
| epochs=num_epochs) |
| |
| with assert_raises(ValueError): |
| est.fit(train_data=dataiter, |
| epochs=num_epochs) |
| |
| # Input NDArray |
| with assert_raises(ValueError): |
| est.fit(train_data=[mx.nd.ones(shape=(10, 3))], |
| epochs=num_epochs) |
| |
| |
| def test_validation(): |
| ''' test different validation data types''' |
| net = _get_test_network() |
| dataloader, dataiter = _get_test_data() |
| num_epochs = 1 |
| ctx = mx.cpu() |
| loss = gluon.loss.L2Loss() |
| acc = mx.metric.Accuracy() |
| net.initialize(ctx=ctx) |
| trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=acc, |
| trainer=trainer, |
| context=ctx) |
| # Input dataloader |
| est.fit(train_data=dataloader, |
| val_data=dataloader, |
| epochs=num_epochs) |
| |
| # using validation handler |
| train_metrics, val_metrics = est.prepare_loss_and_metrics() |
| validation_handler = ValidationHandler(val_data=dataloader, eval_fn=est.evaluate, |
| val_metrics=val_metrics) |
| |
| with assert_raises(ValueError): |
| est.fit(train_data=dataiter, |
| val_data=dataiter, |
| epochs=num_epochs) |
| # Input NDArray |
| with assert_raises(ValueError): |
| est.fit(train_data=[mx.nd.ones(shape=(10, 3))], |
| val_data=[mx.nd.ones(shape=(10, 3))], |
| epochs=num_epochs) |
| |
| |
| @unittest.skipIf(sys.version_info.major < 3, 'Test on python 3') |
| def test_initializer(): |
| ''' test with no initializer, inconsistent initializer ''' |
| net = _get_test_network() |
| train_data, _ = _get_test_data() |
| num_epochs = 1 |
| ctx = mx.cpu() |
| |
| loss = gluon.loss.L2Loss() |
| acc = mx.metric.Accuracy() |
| # no initializer |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=acc, |
| context=ctx) |
| est.fit(train_data=train_data, |
| epochs=num_epochs) |
| |
| # different initializer for net and estimator |
| net = _get_test_network() |
| net.initialize(mx.init.Xavier(), ctx=ctx) |
| trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) |
| # catch reinit warning |
| with warnings.catch_warnings(record=True) as w: |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=acc, |
| initializer=mx.init.MSRAPrelu(), |
| trainer=trainer, |
| context=ctx) |
| assert 'Network already fully initialized' in str(w[-1].message) |
| # net partially initialized, fine tuning use case |
| net = gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=ctx) |
| net.output = gluon.nn.Dense(10) #last layer not initialized |
| est = Estimator(net, loss=loss, metrics=acc, context=ctx) |
| dataset = gluon.data.ArrayDataset(mx.nd.zeros((10, 3, 224, 224)), mx.nd.zeros((10, 10))) |
| train_data = gluon.data.DataLoader(dataset=dataset, batch_size=5) |
| est.fit(train_data=train_data, |
| epochs=num_epochs) |
| |
| |
| @unittest.skipIf(sys.version_info.major < 3, 'Test on python 3') |
| def test_trainer(): |
| ''' test with no trainer and invalid trainer ''' |
| net = _get_test_network() |
| train_data, _ = _get_test_data() |
| num_epochs = 1 |
| ctx = mx.cpu() |
| |
| loss = gluon.loss.L2Loss() |
| acc = mx.metric.Accuracy() |
| net.initialize(ctx=ctx) |
| # input no trainer |
| with warnings.catch_warnings(record=True) as w: |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=acc, |
| context=ctx) |
| assert 'No trainer specified' in str(w[-1].message) |
| est.fit(train_data=train_data, |
| epochs=num_epochs) |
| |
| # input invalid trainer |
| trainer = 'sgd' |
| with assert_raises(ValueError): |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=acc, |
| trainer=trainer, |
| context=ctx) |
| |
| |
| def test_metric(): |
| ''' test with no metric, list of metrics, invalid metric ''' |
| net = _get_test_network() |
| train_data, _ = _get_test_data() |
| num_epochs = 1 |
| ctx = mx.cpu() |
| |
| loss = gluon.loss.L2Loss() |
| net.initialize(ctx=ctx) |
| trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) |
| # input no metric |
| est = Estimator(net=net, |
| loss=loss, |
| trainer=trainer, |
| context=ctx) |
| est.fit(train_data=train_data, |
| epochs=num_epochs) |
| # input list of metrics |
| metrics = [mx.metric.Accuracy(), mx.metric.Accuracy()] |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=metrics, |
| trainer=trainer, |
| context=ctx) |
| est.fit(train_data=train_data, |
| epochs=num_epochs) |
| # input invalid metric |
| with assert_raises(ValueError): |
| est = Estimator(net=net, |
| loss=loss, |
| metrics='acc', |
| trainer=trainer, |
| context=ctx) |
| # test default metric |
| loss = gluon.loss.SoftmaxCrossEntropyLoss() |
| est = Estimator(net=net, |
| loss=loss, |
| trainer=trainer, |
| context=ctx) |
| est.prepare_loss_and_metrics() |
| assert isinstance(est.train_metrics[0], mx.metric.Accuracy) |
| |
| |
| def test_loss(): |
| ''' test with invalid loss ''' |
| net = _get_test_network() |
| ctx = mx.cpu() |
| acc = mx.metric.Accuracy() |
| net.initialize(ctx=ctx) |
| trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) |
| # input invalid loss |
| with assert_raises(ValueError): |
| est = Estimator(net=net, |
| loss='mse', |
| metrics=acc, |
| trainer=trainer, |
| context=ctx) |
| |
| |
| def test_context(): |
| ''' test with no context, list of context, invalid context ''' |
| net = _get_test_network() |
| loss = gluon.loss.L2Loss() |
| metrics = mx.metric.Accuracy() |
| # input no context |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=metrics) |
| # input list of context |
| gpus = mx.context.num_gpus() |
| ctx = [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()] |
| net = _get_test_network() |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=metrics, |
| context=ctx) |
| # input invalid context |
| with assert_raises(ValueError): |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=metrics, |
| context='cpu') |
| |
| with assert_raises(AssertionError): |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=metrics, |
| context=[mx.gpu(0), mx.gpu(100)]) |
| |
| |
| def test_categorize_handlers(): |
| class CustomHandler1(TrainBegin): |
| |
| def train_begin(self): |
| print("custom train begin") |
| |
| class CustomHandler2(EpochBegin, BatchBegin, TrainEnd): |
| |
| def epoch_begin(self): |
| print("custom epoch begin") |
| |
| def batch_begin(self): |
| print("custom batch begin") |
| |
| def train_end(self): |
| print("custom train end") |
| |
| class CustomHandler3(EpochBegin, BatchBegin, BatchEnd, TrainEnd): |
| |
| def epoch_begin(self): |
| print("custom epoch begin") |
| |
| def batch_begin(self): |
| print("custom batch begin") |
| |
| def batch_end(self): |
| print("custom batch end") |
| |
| def train_end(self): |
| print("custom train end") |
| |
| net = nn.Sequential() |
| net.add(nn.Dense(10)) |
| loss = gluon.loss.SoftmaxCrossEntropyLoss() |
| est = Estimator(net, loss=loss) |
| event_handlers = [CustomHandler1(), CustomHandler2(), CustomHandler3()] |
| train_begin, epoch_begin, batch_begin, \ |
| batch_end, epoch_end, train_end = est._categorize_handlers(event_handlers) |
| assert len(train_begin) == 1 |
| assert len(epoch_begin) == 2 |
| assert len(batch_begin) == 2 |
| assert len(batch_end) == 1 |
| assert len(train_end) == 2 |
| |
| |
| @unittest.skipIf(sys.version_info.major < 3, 'Test on python 3') |
| def test_default_handlers(): |
| net = _get_test_network() |
| train_data, _ = _get_test_data() |
| |
| num_epochs = 1 |
| ctx = mx.cpu() |
| |
| net.initialize(ctx=ctx) |
| trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) |
| |
| train_acc = mx.metric.RMSE() |
| loss = gluon.loss.L2Loss() |
| |
| est = Estimator(net=net, |
| loss=loss, |
| metrics=train_acc, |
| trainer=trainer, |
| context=ctx) |
| # no handler |
| with warnings.catch_warnings(record=True) as w: |
| est.fit(train_data=train_data, epochs=num_epochs) |
| assert 'You are training with the' in str(w[-1].message) |
| |
| # handler with prepared loss and metrics |
| # use mix of default and user defined handlers |
| train_metrics, val_metrics = est.prepare_loss_and_metrics() |
| logging = LoggingHandler(train_metrics=train_metrics, val_metrics=val_metrics) |
| with warnings.catch_warnings(record=True) as w: |
| est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) |
| assert 'You are training with the' in str(w[-1].message) |
| # provide metric handler by default |
| assert 'MetricHandler' in str(w[-1].message) |
| |
| # handler with all user defined metrics |
| # use mix of default and user defined handlers |
| metric = MetricHandler(train_metrics=[train_acc]) |
| logging = LoggingHandler(train_metrics=[train_acc], val_metrics=[mx.metric.RMSE("val acc")]) |
| est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[metric, logging]) |
| |
| # handler with mixed metrics, some handler use metrics prepared by estimator |
| # some handler use metrics user prepared |
| logging = LoggingHandler(train_metrics=train_metrics, val_metrics=[mx.metric.RMSE("val acc")]) |
| with assert_raises(ValueError): |
| est.fit(train_data=train_data, epochs=num_epochs, event_handlers=[logging]) |
| |
| # test handler order |
| train_metrics, val_metrics = est.prepare_loss_and_metrics() |
| early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) |
| handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping]) |
| assert len(handlers) == 4 |
| assert isinstance(handlers[0], MetricHandler) |
| assert isinstance(handlers[3], LoggingHandler) |