blob: ac6545abb1de4f0760578aa4dc0de5d56d86efb0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
from __future__ import print_function
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path = [os.path.join(curr_path, "../autoencoder")] + sys.path
import mxnet as mx
import numpy as np
import data
from scipy.spatial.distance import cdist
from sklearn.cluster import KMeans
import model
from autoencoder import AutoEncoderModel
from solver import Solver, Monitor
import logging
def cluster_acc(Y_pred, Y):
from sklearn.utils.linear_assignment_ import linear_assignment
assert Y_pred.size == Y.size
D = max(Y_pred.max(), Y.max())+1
w = np.zeros((D,D), dtype=np.int64)
for i in range(Y_pred.size):
w[Y_pred[i], int(Y[i])] += 1
ind = linear_assignment(w.max() - w)
return sum([w[i,j] for i,j in ind])*1.0/Y_pred.size, w
class DECModel(model.MXModel):
class DECLoss(mx.operator.NumpyOp):
def __init__(self, num_centers, alpha):
super(DECModel.DECLoss, self).__init__(need_top_grad=False)
self.num_centers = num_centers
self.alpha = alpha
def forward(self, in_data, out_data):
z = in_data[0]
mu = in_data[1]
q = out_data[0]
self.mask = 1.0/(1.0+cdist(z, mu)**2/self.alpha)
q[:] = self.mask**((self.alpha+1.0)/2.0)
q[:] = (q.T/q.sum(axis=1)).T
def backward(self, out_grad, in_data, out_data, in_grad):
q = out_data[0]
z = in_data[0]
mu = in_data[1]
p = in_data[2]
dz = in_grad[0]
dmu = in_grad[1]
self.mask *= (self.alpha+1.0)/self.alpha*(p-q)
dz[:] = (z.T*self.mask.sum(axis=1)).T - self.mask.dot(mu)
dmu[:] = (mu.T*self.mask.sum(axis=0)).T - self.mask.T.dot(z)
def infer_shape(self, in_shape):
assert len(in_shape) == 3
assert len(in_shape[0]) == 2
input_shape = in_shape[0]
label_shape = (input_shape[0], self.num_centers)
mu_shape = (self.num_centers, input_shape[1])
out_shape = (input_shape[0], self.num_centers)
return [input_shape, mu_shape, label_shape], [out_shape]
def list_arguments(self):
return ['data', 'mu', 'label']
def setup(self, X, num_centers, alpha, save_to='dec_model'):
sep = X.shape[0]*9/10
X_train = X[:sep]
X_val = X[sep:]
ae_model = AutoEncoderModel(self.xpu, [X.shape[1],500,500,2000,10], pt_dropout=0.2)
if not os.path.exists(save_to+'_pt.arg'):
ae_model.layerwise_pretrain(X_train, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
ae_model.finetune(X_train, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
ae_model.save(save_to+'_pt.arg')
logging.log(logging.INFO, "Autoencoder Training error: %f"%ae_model.eval(X_train))
logging.log(logging.INFO, "Autoencoder Validation error: %f"%ae_model.eval(X_val))
else:
ae_model.load(save_to+'_pt.arg')
self.ae_model = ae_model
self.dec_op = DECModel.DECLoss(num_centers, alpha)
label = mx.sym.Variable('label')
self.feature = self.ae_model.encoder
self.loss = self.dec_op(data=self.ae_model.encoder, label=label, name='dec')
self.args.update({k:v for k,v in self.ae_model.args.items() if k in self.ae_model.encoder.list_arguments()})
self.args['dec_mu'] = mx.nd.empty((num_centers, self.ae_model.dims[-1]), ctx=self.xpu)
self.args_grad.update({k: mx.nd.empty(v.shape, ctx=self.xpu) for k,v in self.args.items()})
self.args_mult.update({k: k.endswith('bias') and 2.0 or 1.0 for k in self.args})
self.num_centers = num_centers
def cluster(self, X, y=None, update_interval=None):
N = X.shape[0]
if not update_interval:
update_interval = N
batch_size = 256
test_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=False,
last_batch_handle='pad')
args = {k: mx.nd.array(v.asnumpy(), ctx=self.xpu) for k, v in self.args.items()}
z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0]
kmeans = KMeans(self.num_centers, n_init=20)
kmeans.fit(z)
args['dec_mu'][:] = kmeans.cluster_centers_
solver = Solver('sgd', momentum=0.9, wd=0.0, learning_rate=0.01)
def ce(label, pred):
return np.sum(label*np.log(label/(pred+0.000001)))/label.shape[0]
solver.set_metric(mx.metric.CustomMetric(ce))
label_buff = np.zeros((X.shape[0], self.num_centers))
train_iter = mx.io.NDArrayIter({'data': X}, {'label': label_buff}, batch_size=batch_size,
shuffle=False, last_batch_handle='roll_over')
self.y_pred = np.zeros((X.shape[0]))
def refresh(i):
if i%update_interval == 0:
z = list(model.extract_feature(self.feature, args, None, test_iter, N, self.xpu).values())[0]
p = np.zeros((z.shape[0], self.num_centers))
self.dec_op.forward([z, args['dec_mu'].asnumpy()], [p])
y_pred = p.argmax(axis=1)
print(np.std(np.bincount(y_pred)), np.bincount(y_pred))
print(np.std(np.bincount(y.astype(np.int))), np.bincount(y.astype(np.int)))
if y is not None:
print(cluster_acc(y_pred, y)[0])
weight = 1.0/p.sum(axis=0)
weight *= self.num_centers/weight.sum()
p = (p**2)*weight
train_iter.data_list[1][:] = (p.T/p.sum(axis=1)).T
print(np.sum(y_pred != self.y_pred), 0.001*y_pred.shape[0])
if np.sum(y_pred != self.y_pred) < 0.001*y_pred.shape[0]:
self.y_pred = y_pred
return True
self.y_pred = y_pred
solver.set_iter_start_callback(refresh)
solver.set_monitor(Monitor(50))
solver.solve(self.xpu, self.loss, args, self.args_grad, None,
train_iter, 0, 1000000000, {}, False)
self.end_args = args
if y is not None:
return cluster_acc(self.y_pred, y)[0]
else:
return -1
def mnist_exp(xpu):
X, Y = data.get_mnist()
dec_model = DECModel(xpu, X, 10, 1.0, 'data/mnist')
acc = []
for i in [10*(2**j) for j in range(9)]:
acc.append(dec_model.cluster(X, Y, i))
logging.log(logging.INFO, 'Clustering Acc: %f at update interval: %d'%(acc[-1], i))
logging.info(str(acc))
logging.info('Best Clustering ACC: %f at update_interval: %d'%(np.max(acc), 10*(2**np.argmax(acc))))
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
mnist_exp(mx.gpu(0))