| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Generate MXNet implementation of CapsNet |
| Reference 1: https://www.cs.toronto.edu/~fritz/absps/transauto6.pdf |
| Reference 2: https://arxiv.org/pdf/1710.09829.pdf |
| """ |
| import os |
| import re |
| import gzip |
| import struct |
| import numpy as np |
| import scipy.ndimage as ndi |
| import mxnet as mx |
| from capsulelayers import primary_caps, CapsuleLayer |
| |
| from mxboard import SummaryWriter |
| |
| |
| def margin_loss(y_true, y_pred): |
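    """Margin loss from Sabour et al. (2017), Eq. 4:
    L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
    with m+ = 0.9, m- = 0.1, lambda = 0.5. y_true is one-hot; y_pred holds the
    capsule lengths."""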
| loss = y_true * mx.sym.square(mx.sym.maximum(0., 0.9 - y_pred)) +\ |
| 0.5 * (1 - y_true) * mx.sym.square(mx.sym.maximum(0., y_pred - 0.1)) |
    return mx.sym.mean(data=mx.sym.sum(loss, axis=1))
| |
| |
| def capsnet(batch_size, n_class, num_routing, recon_loss_weight): |
| """Create CapsNet""" |
| # data.shape = [batch_size, 1, 28, 28] |
| data = mx.sym.Variable('data') |
| |
| input_shape = (1, 28, 28) |
| # Conv2D layer |
| # net.shape = [batch_size, 256, 20, 20] |
| conv1 = mx.sym.Convolution(data=data, |
| num_filter=256, |
| kernel=(9, 9), |
| layout='NCHW', |
| name='conv1') |
| conv1 = mx.sym.Activation(data=conv1, act_type='relu', name='conv1_act') |
| # net.shape = [batch_size, 256, 6, 6] |
| |
| primarycaps = primary_caps(data=conv1, |
| dim_vector=8, |
| n_channels=32, |
| kernel=(9, 9), |
| strides=[2, 2], |
| name='primarycaps') |
| primarycaps.infer_shape(data=(batch_size, 1, 28, 28)) |
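    # primarycaps: 32 channels of 6x6 8-D capsules, i.e. (batch_size, 1152, 8);
    # the discarded infer_shape calls act as early shape assertions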
| # CapsuleLayer |
| kernel_initializer = mx.init.Xavier(rnd_type='uniform', factor_type='avg', magnitude=3) |
| bias_initializer = mx.init.Zero() |
| digitcaps = CapsuleLayer(num_capsule=10, |
| dim_vector=16, |
| batch_size=batch_size, |
| kernel_initializer=kernel_initializer, |
| bias_initializer=bias_initializer, |
| num_routing=num_routing)(primarycaps) |
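    # digitcaps: one 16-D capsule per class, (batch_size, 10, 16), computed with
    # num_routing iterations of dynamic routing-by-agreement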
| |
| # out_caps : (batch_size, 10) |
    out_caps = mx.sym.sqrt(data=mx.sym.sum(mx.sym.square(digitcaps), axis=2))
| out_caps.infer_shape(data=(batch_size, 1, 28, 28)) |
| |
| y = mx.sym.Variable('softmax_label', shape=(batch_size,)) |
| y_onehot = mx.sym.one_hot(y, n_class) |
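    # reshape the one-hot labels from (batch_size, n_class) to (batch_size, n_class, 1);
    # in MXNet Reshape codes, -4 splits a dim into the two sizes that follow,
    # -1 infers a size, -3 (used below) merges two dims and 0 copies a dim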
| y_reshaped = mx.sym.Reshape(data=y_onehot, shape=(batch_size, -4, n_class, -1)) |
| y_reshaped.infer_shape(softmax_label=(batch_size,)) |
| |
| # inputs_masked : (batch_size, 16) |
| inputs_masked = mx.sym.linalg_gemm2(y_reshaped, digitcaps, transpose_a=True) |
| inputs_masked = mx.sym.Reshape(data=inputs_masked, shape=(-3, 0)) |
| x_recon = mx.sym.FullyConnected(data=inputs_masked, num_hidden=512, name='x_recon') |
| x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act') |
| x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=1024, name='x_recon2') |
| x_recon = mx.sym.Activation(data=x_recon, act_type='relu', name='x_recon_act2') |
| x_recon = mx.sym.FullyConnected(data=x_recon, num_hidden=np.prod(input_shape), name='x_recon3') |
| x_recon = mx.sym.Activation(data=x_recon, act_type='sigmoid', name='x_recon_act3') |
| |
| data_flatten = mx.sym.flatten(data=data) |
| squared_error = mx.sym.square(x_recon-data_flatten) |
| recon_error = mx.sym.mean(squared_error) |
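    # keep a gradient-blocked copy so the reconstruction error can be reported
    # as a metric output without adding a second gradient path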
| recon_error_stopped = recon_error |
| recon_error_stopped = mx.sym.BlockGrad(recon_error_stopped) |
    loss = mx.sym.MakeLoss((1 - recon_loss_weight) * margin_loss(y_onehot, out_caps) +
                           recon_loss_weight * recon_error)
| |
| out_caps_blocked = out_caps |
| out_caps_blocked = mx.sym.BlockGrad(out_caps_blocked) |
| return mx.sym.Group([out_caps_blocked, loss, recon_error_stopped]) |
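
# A minimal inference sketch (illustrative, not executed). Assuming a checkpoint
# saved by do_training() below with model_prefix='capsnet', prediction could look
# roughly like this:
#
#   sym, arg_params, aux_params = mx.model.load_checkpoint('capsnet', 0)
#   mod = mx.mod.Module(symbol=sym, context=mx.cpu(),
#                       data_names=('data',), label_names=('softmax_label',))
#   mod.bind(data_shapes=[('data', (100, 1, 28, 28))],
#            label_shapes=[('softmax_label', (100,))], for_training=False)
#   mod.set_params(arg_params, aux_params)
#   mod.forward(data_batch)  # a DataBatch of 100 MNIST images
#   digits = mod.get_outputs()[0].asnumpy().argmax(axis=1)  # capsule lengths -> class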
| |
| |
| def download_data(url, force_download=False): |
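    """Download `url` into the working directory unless the file already exists"""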
| fname = url.split("/")[-1] |
| if force_download or not os.path.exists(fname): |
| mx.test_utils.download(url, fname) |
| return fname |
| |
| |
def read_data(label_url, image_url):
    """Read MNIST IDX files: a big-endian header followed by raw bytes"""
    with gzip.open(download_data(label_url)) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.frombuffer(flbl.read(), dtype=np.int8)
    with gzip.open(download_data(image_url), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8)
        image = image.reshape(len(label), rows, cols)
    return label, image
| |
| |
| def to4d(img): |
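    """Reshape flat MNIST images to (N, 1, 28, 28) float32 scaled to [0, 1]"""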
| return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32)/255 |
| |
| |
| class LossMetric(mx.metric.EvalMetric): |
| """Evaluate the loss function""" |
| def __init__(self, batch_size, num_gpus): |
| super(LossMetric, self).__init__('LossMetric') |
| self.batch_size = batch_size |
| self.num_gpu = num_gpus |
| self.sum_metric = 0 |
| self.num_inst = 0 |
| self.loss = 0.0 |
| self.batch_sum_metric = 0 |
| self.batch_num_inst = 0 |
| self.batch_loss = 0.0 |
| self.recon_loss = 0.0 |
| self.n_batch = 0 |
| |
| def update(self, labels, preds): |
| """Update the hyper-parameters and loss of CapsNet""" |
| batch_sum_metric = 0 |
| batch_num_inst = 0 |
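        # preds layout matches the symbol Group: preds[0] = capsule lengths
        # (class scores), preds[1] = total loss, preds[2] = reconstruction error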
| for label, pred_outcaps in zip(labels[0], preds[0]): |
| label_np = int(label.asnumpy()) |
| pred_label = int(np.argmax(pred_outcaps.asnumpy())) |
| batch_sum_metric += int(label_np == pred_label) |
| batch_num_inst += 1 |
        batch_loss = float(preds[1].asnumpy())
        recon_loss = float(preds[2].asnumpy())
| self.sum_metric += batch_sum_metric |
| self.num_inst += batch_num_inst |
| self.loss += batch_loss |
| self.recon_loss += recon_loss |
| self.batch_sum_metric = batch_sum_metric |
| self.batch_num_inst = batch_num_inst |
| self.batch_loss = batch_loss |
| self.n_batch += 1 |
| |
    def get_name_value(self):
        """Return (accuracy, mean loss, mean reconstruction error) over the epoch so far"""
        acc = float(self.sum_metric) / float(self.num_inst)
| mean_loss = self.loss / float(self.n_batch) |
| mean_recon_loss = self.recon_loss / float(self.n_batch) |
| return acc, mean_loss, mean_recon_loss |
| |
    def get_batch_log(self, n_batch):
        """Print per-batch accuracy and loss, then reset the per-batch counters"""
        print('n_batch: %d batch_acc: %f batch_loss: %f' %
              (n_batch, float(self.batch_sum_metric) / float(self.batch_num_inst),
               float(self.batch_loss) / float(self.batch_num_inst)))
| self.batch_sum_metric = 0 |
| self.batch_num_inst = 0 |
| self.batch_loss = 0.0 |
| |
| def reset(self): |
| self.sum_metric = 0 |
| self.num_inst = 0 |
| self.loss = 0.0 |
| self.recon_loss = 0.0 |
| self.n_batch = 0 |
| |
| |
| class SimpleLRScheduler(mx.lr_scheduler.LRScheduler): |
| """A simple lr schedule that simply return `dynamic_lr`. We will set `dynamic_lr` |
| dynamically based on performance on the validation set. |
| """ |
| |
| def __init__(self, learning_rate=0.001): |
| super(SimpleLRScheduler, self).__init__() |
| self.learning_rate = learning_rate |
| |
| def __call__(self, num_update): |
| return self.learning_rate |
| |
| |
| def do_training(num_epoch, optimizer, kvstore, learning_rate, model_prefix, decay): |
| """Perform CapsNet training""" |
| summary_writer = SummaryWriter(args.tblog_dir) |
| lr_scheduler = SimpleLRScheduler(learning_rate) |
| optimizer_params = {'lr_scheduler': lr_scheduler} |
| module.init_params() |
| module.init_optimizer(kvstore=kvstore, |
| optimizer=optimizer, |
| optimizer_params=optimizer_params) |
    for n_epoch in range(num_epoch):
| train_iter.reset() |
| val_iter.reset() |
| loss_metric.reset() |
| for n_batch, data_batch in enumerate(train_iter): |
| module.forward_backward(data_batch) |
| module.update() |
| module.update_metric(loss_metric, data_batch.label) |
| loss_metric.get_batch_log(n_batch) |
| train_acc, train_loss, train_recon_err = loss_metric.get_name_value() |
| loss_metric.reset() |
| for n_batch, data_batch in enumerate(val_iter): |
| module.forward(data_batch) |
| module.update_metric(loss_metric, data_batch.label) |
| loss_metric.get_batch_log(n_batch) |
| val_acc, val_loss, val_recon_err = loss_metric.get_name_value() |
| |
| summary_writer.add_scalar('train_acc', train_acc, n_epoch) |
| summary_writer.add_scalar('train_loss', train_loss, n_epoch) |
| summary_writer.add_scalar('train_recon_err', train_recon_err, n_epoch) |
| summary_writer.add_scalar('val_acc', val_acc, n_epoch) |
| summary_writer.add_scalar('val_loss', val_loss, n_epoch) |
| summary_writer.add_scalar('val_recon_err', val_recon_err, n_epoch) |
| |
| print('Epoch[%d] train acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, train_acc, train_loss, |
| train_recon_err)) |
| print('Epoch[%d] val acc: %.4f loss: %.6f recon_err: %.6f' % (n_epoch, val_acc, val_loss, val_recon_err)) |
| print('SAVE CHECKPOINT') |
| |
| module.save_checkpoint(prefix=model_prefix, epoch=n_epoch) |
        # exponential learning-rate decay, applied once per epoch
        lr_scheduler.learning_rate = learning_rate * (decay ** (n_epoch + 1))
| |
| |
| def apply_transform(x, transform_matrix, fill_mode='nearest', cval=0.): |
| """Apply transform on nd.array""" |
    # x is channel-first (C, H, W); transform each channel independently
| final_affine_matrix = transform_matrix[:2, :2] |
| final_offset = transform_matrix[:2, 2] |
    channel_images = [ndi.affine_transform(
| x_channel, |
| final_affine_matrix, |
| final_offset, |
| order=0, |
| mode=fill_mode, |
| cval=cval) for x_channel in x] |
    x = np.stack(channel_images, axis=0)
    return x
| |
| |
def random_shift(x, width_shift_fraction, height_shift_fraction):
    """Randomly translate a (C, H, W) image by up to the given height/width fractions"""
    tx = np.random.uniform(-height_shift_fraction, height_shift_fraction) * x.shape[1]  # rows
    ty = np.random.uniform(-width_shift_fraction, width_shift_fraction) * x.shape[2]  # cols
| shift_matrix = np.array([[1, 0, tx], |
| [0, 1, ty], |
| [0, 0, 1]]) |
| x = apply_transform(x, shift_matrix, 'nearest') |
| return x |
| |
| |
| def _shuffle(data, idx): |
| """Shuffle the data.""" |
| shuffle_data = [] |
| |
| for idx_k, idx_v in data: |
| shuffle_data.append((idx_k, mx.ndarray.array(idx_v.asnumpy()[idx], idx_v.context))) |
| |
| return shuffle_data |
| |
| |
| class MNISTCustomIter(mx.io.NDArrayIter): |
| """Create custom iterator of mnist dataset""" |
    def __init__(self, data, label, batch_size, shuffle):
        # default to evaluation mode before NDArrayIter's setup runs, since
        # reset() consults self.is_train
        self.is_train = False
        super(MNISTCustomIter, self).__init__(data=data, label=label,
                                              batch_size=batch_size, shuffle=shuffle)
| |
| def reset(self): |
| """Reset class MNISTCustomIter(mx.io.NDArrayIter):""" |
| # shuffle data |
| if self.is_train: |
| np.random.shuffle(self.idx) |
| self.data = _shuffle(self.data, self.idx) |
| self.label = _shuffle(self.label, self.idx) |
| |
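        # mirror NDArrayIter.reset(): with last_batch_handle='roll_over', samples
        # left over from the previous pass stay at the front of the new epoch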
| if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data: |
| self.cursor = -self.batch_size + (self.cursor % self.num_data) % self.batch_size |
| else: |
| self.cursor = -self.batch_size |
| |
| def set_is_train(self, is_train): |
| """Set training flag""" |
| self.is_train = is_train |
| |
| def next(self): |
| """Generate next of iterator""" |
| if self.iter_next(): |
| if self.is_train: |
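                # augment training batches with small random translations
                # (up to 10% of the image size in each direction)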
| data_raw_list = self.getdata() |
| data_shifted = [] |
| for data_raw in data_raw_list[0]: |
| data_shifted.append(random_shift(data_raw.asnumpy(), 0.1, 0.1)) |
| return mx.io.DataBatch(data=[mx.nd.array(data_shifted)], label=self.getlabel(), |
| pad=self.getpad(), index=None) |
| else: |
| return mx.io.DataBatch(data=self.getdata(), label=self.getlabel(), pad=self.getpad(), index=None) |
| |
| else: |
| raise StopIteration |
| |
| |
| if __name__ == "__main__": |
| # Read mnist data set |
| path = 'http://yann.lecun.com/exdb/mnist/' |
| (train_lbl, train_img) = read_data(path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz') |
| (val_lbl, val_img) = read_data(path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz') |
| |
| # set batch size |
| import argparse |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--batch_size', default=100, type=int) |
| parser.add_argument('--devices', default='gpu0', type=str) |
| parser.add_argument('--num_epoch', default=100, type=int) |
| parser.add_argument('--lr', default=0.001, type=float) |
| parser.add_argument('--num_routing', default=3, type=int) |
| parser.add_argument('--model_prefix', default='capsnet', type=str) |
| parser.add_argument('--decay', default=0.9, type=float) |
| parser.add_argument('--tblog_dir', default='tblog', type=str) |
| parser.add_argument('--recon_loss_weight', default=0.392, type=float) |
| args = parser.parse_args() |
| for k, v in sorted(vars(args).items()): |
| print("{0}: {1}".format(k, v)) |
| contexts = re.split(r'\W+', args.devices) |
| for i, ctx in enumerate(contexts): |
        if ctx.startswith('gpu'):
| contexts[i] = mx.context.gpu(int(ctx[3:])) |
| else: |
| contexts[i] = mx.context.cpu() |
| num_gpu = len(contexts) |
| |
    if args.batch_size % num_gpu != 0:
        raise ValueError('batch_size must be divisible by the number of devices')
| |
| # generate train_iter, val_iter |
| train_iter = MNISTCustomIter(data=to4d(train_img), label=train_lbl, batch_size=int(args.batch_size), shuffle=True) |
| train_iter.set_is_train(True) |
| val_iter = MNISTCustomIter(data=to4d(val_img), label=val_lbl, batch_size=int(args.batch_size), shuffle=True) |
| val_iter.set_is_train(False) |
| # define capsnet |
| final_net = capsnet(batch_size=int(args.batch_size/num_gpu), |
| n_class=10, |
| num_routing=args.num_routing, |
| recon_loss_weight=args.recon_loss_weight) |
| # set metric |
    loss_metric = LossMetric(args.batch_size // num_gpu, 1)
| |
| # run model |
| module = mx.mod.Module(symbol=final_net, context=contexts, data_names=('data',), label_names=('softmax_label',)) |
| module.bind(data_shapes=train_iter.provide_data, |
                label_shapes=train_iter.provide_label,
| for_training=True) |
| |
| do_training(num_epoch=args.num_epoch, optimizer='adam', kvstore='device', learning_rate=args.lr, |
| model_prefix=args.model_prefix, decay=args.decay) |