blob: 7ea5ff3f4b623644eb44195994cf72a717302109 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import mxnet as mx
from common import TemporaryDirectory
from mxnet import nd
from mxnet.gluon import nn, loss
from mxnet.gluon.contrib.estimator import estimator, event_handler
def _get_test_network(net=None):
    """Return a small 3-layer dense network for testing.

    Parameters
    ----------
    net : gluon container block, optional
        Container to add the layers to (e.g. ``nn.HybridSequential()``).
        A fresh ``nn.Sequential`` is created when omitted.

    Returns
    -------
    The container with three Dense layers added.
    """
    # BUGFIX: the previous default ``net=nn.Sequential()`` was evaluated once
    # at function-definition time, so every no-argument call appended three
    # more layers to the same shared instance. Use the None sentinel instead.
    if net is None:
        net = nn.Sequential()
    net.add(nn.Dense(128, activation='relu', flatten=False),
            nn.Dense(64, activation='relu'),
            nn.Dense(10, activation='relu'))
    return net
def _get_test_data():
    """Return a DataLoader over 32 all-ones samples (dim 100) with zero labels, batch size 8."""
    samples = nd.ones((32, 100))
    targets = nd.zeros((32, 1))
    dataset = mx.gluon.data.dataset.ArrayDataset(samples, targets)
    return mx.gluon.data.DataLoader(dataset, batch_size=8)
def test_checkpoint_handler():
    """CheckpointHandler saves params/states per epoch or per batch period."""
    with TemporaryDirectory() as save_dir:
        prefix = 'test_epoch'
        prefix_path = os.path.join(save_dir, prefix)
        dataloader = _get_test_data()
        network = _get_test_network()
        loss_fn = loss.SoftmaxCrossEntropyLoss()
        accuracy = mx.metric.Accuracy()
        trainer = estimator.Estimator(network, loss=loss_fn, metrics=accuracy)
        ckpt_handler = event_handler.CheckpointHandler(
            model_dir=save_dir,
            model_prefix=prefix,
            monitor=accuracy,
            save_best=True,
            epoch_period=1)
        trainer.fit(dataloader, event_handlers=[ckpt_handler], epochs=1)
        # one epoch over 32 samples at batch size 8 -> 4 batches
        assert ckpt_handler.current_epoch == 1
        assert ckpt_handler.current_batch == 4
        assert os.path.isfile(prefix_path + '-best.params')
        assert os.path.isfile(prefix_path + '-best.states')
        assert os.path.isfile(prefix_path + '-epoch0batch4.params')
        assert os.path.isfile(prefix_path + '-epoch0batch4.states')

        # batch-period checkpointing with a hybridized net also exports a symbol file
        prefix = 'test_batch'
        prefix_path = os.path.join(save_dir, prefix)
        network = _get_test_network(nn.HybridSequential())
        network.hybridize()
        trainer = estimator.Estimator(network, loss=loss_fn, metrics=accuracy)
        ckpt_handler = event_handler.CheckpointHandler(
            model_dir=save_dir,
            model_prefix=prefix,
            epoch_period=None,
            batch_period=2,
            max_checkpoints=2)
        trainer.fit(dataloader, event_handlers=[ckpt_handler], batches=10)
        assert ckpt_handler.current_batch == 10
        assert ckpt_handler.current_epoch == 3
        # save_best not enabled -> no best files; max_checkpoints=2 drops the oldest
        assert not os.path.isfile(prefix_path + 'best.params')
        assert not os.path.isfile(prefix_path + 'best.states')
        assert not os.path.isfile(prefix_path + '-epoch0batch0.params')
        assert not os.path.isfile(prefix_path + '-epoch0batch0.states')
        assert os.path.isfile(prefix_path + '-symbol.json')
        assert os.path.isfile(prefix_path + '-epoch1batch7.params')
        assert os.path.isfile(prefix_path + '-epoch1batch7.states')
        assert os.path.isfile(prefix_path + '-epoch2batch9.params')
        assert os.path.isfile(prefix_path + '-epoch2batch9.states')
def test_resume_checkpoint():
    """Training resumes from the latest saved checkpoint when requested."""
    with TemporaryDirectory() as save_dir:
        prefix = 'test_net'
        prefix_path = os.path.join(save_dir, prefix)
        dataloader = _get_test_data()
        network = _get_test_network()
        loss_fn = loss.SoftmaxCrossEntropyLoss()
        accuracy = mx.metric.Accuracy()
        trainer = estimator.Estimator(network, loss=loss_fn, metrics=accuracy)
        ckpt_handler = event_handler.CheckpointHandler(
            model_dir=save_dir,
            model_prefix=prefix,
            monitor=accuracy,
            max_checkpoints=1)
        trainer.fit(dataloader, event_handlers=[ckpt_handler], epochs=2)
        # two epochs of 4 batches each -> last checkpoint at epoch1/batch8
        assert os.path.isfile(prefix_path + '-epoch1batch8.params')
        assert os.path.isfile(prefix_path + '-epoch1batch8.states')
        ckpt_handler = event_handler.CheckpointHandler(
            model_dir=save_dir,
            model_prefix=prefix,
            monitor=accuracy,
            max_checkpoints=1,
            resume_from_checkpoint=True)
        trainer.fit(dataloader, event_handlers=[ckpt_handler], epochs=5)
        # should only continue to train 3 epochs and last checkpoint file is epoch4
        assert trainer.max_epoch == 3
        assert os.path.isfile(prefix_path + '-epoch4batch20.states')
def test_early_stopping():
    """EarlyStoppingHandler halts training once the monitored metric stops improving."""
    dataloader = _get_test_data()
    network = _get_test_network()
    loss_fn = loss.SoftmaxCrossEntropyLoss()
    accuracy = mx.metric.Accuracy()
    trainer = estimator.Estimator(network, loss=loss_fn, metrics=accuracy)
    stopper = event_handler.EarlyStoppingHandler(monitor=accuracy,
                                                 patience=0,
                                                 mode='min')
    trainer.fit(dataloader, event_handlers=[stopper], epochs=5)
    # with zero patience in 'min' mode, training stops after epoch 1
    assert stopper.current_epoch == 2
    assert stopper.stopped_epoch == 1
    stopper = event_handler.EarlyStoppingHandler(monitor=accuracy,
                                                 patience=2,
                                                 mode='auto')
    trainer.fit(dataloader, event_handlers=[stopper], epochs=1)
    assert stopper.current_epoch == 1
def test_logging():
    """LoggingHandler writes a log file and tracks epoch/batch counters."""
    with TemporaryDirectory() as log_dir:
        dataloader = _get_test_data()
        log_name = 'test_log'
        log_path = os.path.join(log_dir, log_name)
        network = _get_test_network()
        loss_fn = loss.SoftmaxCrossEntropyLoss()
        accuracy = mx.metric.Accuracy()
        trainer = estimator.Estimator(network, loss=loss_fn, metrics=accuracy)
        train_metrics, val_metrics = trainer.prepare_loss_and_metrics()
        log_handler = event_handler.LoggingHandler(file_name=log_name,
                                                   file_location=log_dir,
                                                   train_metrics=train_metrics,
                                                   val_metrics=val_metrics)
        trainer.fit(dataloader, event_handlers=[log_handler], epochs=3)
        assert log_handler.batch_index == 0
        assert log_handler.current_epoch == 3
        assert os.path.isfile(log_path)
def test_custom_handler():
    """A user-defined handler can stop training at a given batch or epoch count."""
    class CustomStopHandler(event_handler.TrainBegin,
                            event_handler.BatchEnd,
                            event_handler.EpochEnd):
        """Request a training stop after ``batch_stop`` batches or ``epoch_stop`` epochs."""

        def __init__(self, batch_stop=None, epoch_stop=None):
            self.batch_stop = batch_stop
            self.epoch_stop = epoch_stop
            self.num_batch = 0
            self.num_epoch = 0
            self.stop_training = False

        def train_begin(self, estimator, *args, **kwargs):
            # reset counters so a handler instance behaves the same across fit() calls
            self.num_batch = 0
            self.num_epoch = 0

        def batch_end(self, estimator, *args, **kwargs):
            self.num_batch += 1
            self.stop_training = self.stop_training or self.num_batch == self.batch_stop
            return self.stop_training

        def epoch_end(self, estimator, *args, **kwargs):
            self.num_epoch += 1
            self.stop_training = self.stop_training or self.num_epoch == self.epoch_stop
            return self.stop_training

    # total data size is 32, batch size is 8 -> 4 batches per epoch
    dataloader = _get_test_data()
    network = _get_test_network()
    loss_fn = loss.SoftmaxCrossEntropyLoss()
    accuracy = mx.metric.Accuracy()
    trainer = estimator.Estimator(network, loss=loss_fn, metrics=accuracy)
    stop_handler = CustomStopHandler(3, 2)
    trainer.fit(dataloader, event_handlers=[stop_handler], epochs=3)
    # stopped mid-first-epoch at batch 3, before any epoch_end fired twice
    assert stop_handler.num_batch == 3
    assert stop_handler.num_epoch == 1
    stop_handler = CustomStopHandler(100, 5)
    trainer.fit(dataloader, event_handlers=[stop_handler], epochs=10)
    # epoch limit (5) is hit before the batch limit (100)
    assert stop_handler.num_batch == 5 * 4
    assert stop_handler.num_epoch == 5