# blob: 538d8b976d0c88d1c5007e72c88bf990a9879f65 [file] [log] [blame]
# pylint: skip-file
from __future__ import print_function
import mxnet as mx
import numpy as np
import logging
import data
from autoencoder import AutoEncoderModel
if __name__ == '__main__':
    # DEBUG is the verbose setting; set to INFO to see less information
    # during training.
    logging.basicConfig(level=logging.DEBUG)

    # The model is built on GPU 0 from a list of integers (presumably the
    # width of each layer, 784 down to 10 -- TODO confirm in autoencoder.py).
    device = mx.gpu(0)
    dims = [784, 500, 500, 2000, 10]
    model = AutoEncoderModel(device, dims, pt_dropout=0.2,
                             internal_act='relu', output_act='relu')

    # Only the first value returned by get_mnist() is used; the second is
    # discarded. The first 60000 rows train the model, the rest validate it.
    samples, _ = data.get_mnist()
    split = 60000
    train_set, val_set = samples[:split], samples[split:]

    # Each training stage receives its own freshly constructed scheduler
    # (arguments presumably mean "step, factor" -- verify against MXNet).
    pretrain_schedule = mx.misc.FactorScheduler(20000, 0.1)
    model.layerwise_pretrain(train_set, 256, 50000, 'sgd', l_rate=0.1,
                             decay=0.0, lr_scheduler=pretrain_schedule)

    finetune_schedule = mx.misc.FactorScheduler(20000, 0.1)
    model.finetune(train_set, 256, 100000, 'sgd', l_rate=0.1,
                   decay=0.0, lr_scheduler=finetune_schedule)

    # Round-trip the model through the same path (save, then load) before
    # evaluating it.
    checkpoint = 'mnist_pt.arg'
    model.save(checkpoint)
    model.load(checkpoint)

    # Report the error on the training split first, then the validation one.
    for label, subset in (("Training error:", train_set),
                          ("Validation error:", val_set)):
        print(label, model.eval(subset))