blob: d8a45149d2e83149bb5a0ba02c410d234e01872f [file] [log] [blame]
# pylint: skip-file
from __future__ import print_function
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path = [os.path.join(curr_path, "../autoencoder")] + sys.path
import mxnet as mx
import numpy as np
import data
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
import model
from autoencoder import AutoEncoderModel
from solver import Solver, Monitor
import logging
def cluster_acc(Y_pred, Y):
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
D = max(Y_pred.max(), Y.max())+1
w = np.zeros((D,D), dtype=np.int64)
for i in range(Y_pred.size):
w[Y_pred[i], int(Y[i])] += 1
ind = linear_assignment(w.max() - w)
return sum([w[i,j] for i,j in ind])*1.0/Y_pred.size, w
class DECModel(model.MXModel):
class DECLoss(mx.operator.NumpyOp):
def __init__(self, num_centers, alpha):
super(DECModel.DECLoss, self).__init__(need_top_grad=False)
self.num_centers = num_centers
self.alpha = alpha
def forward(self, in_data, out_data):
z = in_data[0]
mu = in_data[1]
q = out_data[0]
self.mask = 1.0/(1.0+cdist(z, mu)**2/self.alpha)
q[:] = self.mask**((self.alpha+1.0)/2.0)
q[:] = (q.T/q.sum(axis=1)).T
def backward(self, out_grad, in_data, out_data, in_grad):
q = out_data[0]
z = in_data[0]
mu = in_data[1]
p = in_data[2]
dz = in_grad[0]
dmu = in_grad[1]
self.mask *= (self.alpha+1.0)/self.alpha*(p-q)
dz[:] = (z.T*self.mask.sum(axis=1)).T - self.mask.dot(mu)
dmu[:] = (mu.T*self.mask.sum(axis=0)).T - self.mask.T.dot(z)
def infer_shape(self, in_shape):
assert len(in_shape) == 3
assert len(in_shape[0]) == 2
input_shape = in_shape[0]
label_shape = (input_shape[0], self.num_centers)
mu_shape = (self.num_centers, input_shape[1])
out_shape = (input_shape[0], self.num_centers)
return [input_shape, mu_shape, label_shape], [out_shape]
def list_arguments(self):
return ['data', 'mu', 'label']
def setup(self, X, num_centers, alpha, save_to='dec_model'):
sep = X.shape[0]*9/10
X_train = X[:sep]
X_val = X[sep:]
ae_model = AutoEncoderModel(self.xpu, [X.shape[1],500,500,2000,10], pt_dropout=0.2)
if not os.path.exists(save_to+'_pt.arg'):
ae_model.layerwise_pretrain(X_train, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
ae_model.finetune(X_train, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
ae_model.save(save_to+'_pt.arg')
logging.log(logging.INFO, "Autoencoder Training error: %f"%ae_model.eval(X_train))
logging.log(logging.INFO, "Autoencoder Validation error: %f"%ae_model.eval(X_val))
else:
ae_model.load(save_to+'_pt.arg')
self.ae_model = ae_model
self.dec_op = DECModel.DECLoss(num_centers, alpha)
label = mx.sym.Variable('label')
self.feature = self.ae_model.encoder
self.loss = self.dec_op(data=self.ae_model.encoder, label=label, name='dec')
self.args.update({k:v for k,v in self.ae_model.args.items() if k in self.ae_model.encoder.list_arguments()})
self.args['dec_mu'] = mx.nd.empty((num_centers, self.ae_model.dims[-1]), ctx=self.xpu)
self.args_grad.update({k: mx.nd.empty(v.shape, ctx=self.xpu) for k,v in self.args.items()})
self.args_mult.update({k: k.endswith('bias') and 2.0 or 1.0 for k in self.args})
self.num_centers = num_centers
def cluster(self, X, y=None, update_interval=None):
N = X.shape[0]
if not update_interval:
update_interval = N
batch_size = 256
test_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False,
last_batch_handle='pad')
args = {k: mx.nd.array(v.asnumpy(), ctx=self.xpu) for k, v in self.args.items()}
z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0]
kmeans = KMeans(self.num_centers, n_init=20)
kmeans.fit(z)
args['dec_mu'][:] = kmeans.cluster_centers_
solver = Solver('sgd', momentum=0.9, wd=0.0, learning_rate=0.01)
def ce(label, pred):
return np.sum(label*np.log(label/(pred+0.000001)))/label.shape[0]
solver.set_metric(mx.metric.CustomMetric(ce))
label_buff = np.zeros((X.shape[0], self.num_centers))
train_iter = mx.io.NDArrayIter({'data': X}, {'label': label_buff}, batch_size=batch_size,
shuffle=False, last_batch_handle='roll_over')
self.y_pred = np.zeros((X.shape[0]))
def refresh(i):
if i%update_interval == 0:
z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0]
p = np.zeros((z.shape[0], self.num_centers))
self.dec_op.forward([z, args['dec_mu'].asnumpy()], [p])
y_pred = p.argmax(axis=1)
print(np.std(np.bincount(y_pred)), np.bincount(y_pred))
print(np.std(np.bincount(y.astype(np.int))), np.bincount(y.astype(np.int)))
if y is not None:
print(cluster_acc(y_pred, y)[0])
weight = 1.0/p.sum(axis=0)
weight *= self.num_centers/weight.sum()
p = (p**2)*weight
train_iter.data_list[1][:] = (p.T/p.sum(axis=1)).T
print(np.sum(y_pred != self.y_pred), 0.001*y_pred.shape[0])
if np.sum(y_pred != self.y_pred) < 0.001*y_pred.shape[0]:
self.y_pred = y_pred
return True
self.y_pred = y_pred
solver.set_iter_start_callback(refresh)
solver.set_monitor(Monitor(50))
solver.solve(self.xpu, self.loss, args, self.args_grad, None,
train_iter, 0, 1000000000, {}, False)
self.end_args = args
if y is not None:
return cluster_acc(self.y_pred, y)[0]
else:
return -1
def mnist_exp(xpu):
X, Y = data.get_mnist()
dec_model = DECModel(xpu, X, 10, 1.0, 'data/mnist')
acc = []
for i in [10*(2**j) for j in range(9)]:
acc.append(dec_model.cluster(X, Y, i))
logging.log(logging.INFO, 'Clustering Acc: %f at update interval: %d'%(acc[-1], i))
logging.info(str(acc))
logging.info('Best Clustering ACC: %f at update_interval: %d'%(np.max(acc), 10*(2**np.argmax(acc))))
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
mnist_exp(mx.gpu(0))