blob: 62ae5b2cd09742025dac626aa2b02babfcd0d22c [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import numpy as np
import argparse
import os
from singa import device
from singa import tensor
from singa import autograd
from singa import opt
def load_data(path):
    '''
    Loads the train/test splits stored in an ``.npz`` archive.
    Args
        path: location of an archive readable by ``np.load`` that holds
            the arrays ``x_train``, ``y_train``, ``x_test`` and ``y_test``.
    Return
        The nested tuple ``((x_train, y_train), (x_test, y_test))``.
    Raises
        KeyError: if one of the four arrays is missing from the archive.
    '''
    # The context manager closes the archive even when a lookup below fails,
    # which a trailing f.close() call would not do.
    with np.load(path) as archive:
        x_train, y_train = archive['x_train'], archive['y_train']
        x_test, y_test = archive['x_test'], archive['y_test']
    return (x_train, y_train), (x_test, y_test)
def to_categorical(y, num_classes):
    '''
    Converts a class vector (integers) to binary class matrix.
    Args
        y: class vector to be converted into a matrix
            (integers from 0 to num_classes - 1). Inputs with extra
            dimensions, e.g. a column vector of shape (n, 1), are
            flattened before encoding.
        num_classes: total number of classes.
    Return
        A float32 binary matrix with one row per label and
        num_classes columns; each row holds a single 1 at the
        label's index.
    '''
    # Flatten first: indexing with an (n, 1) label array would broadcast
    # against np.arange(n) and mark every (row, label) combination.
    y = np.array(y, dtype='int').ravel()
    n = y.shape[0]
    # Allocate directly as float32 instead of float64 followed by a copy.
    categorical = np.zeros((n, num_classes), dtype=np.float32)
    categorical[np.arange(n), y] = 1
    return categorical
def preprocess(data):
    '''
    Converts raw image data into float32 network input.
    Args
        data: numpy array of images; it is not modified.
            NOTE(review): presumably (n, height, width) pixel values
            in 0..255 - confirm against the dataset.
    Return
        A float32 copy of the data divided by 255, with a new axis
        of length 1 inserted at position 1.
    '''
    pixels = data.astype(np.float32)
    # Divide in place so no second float32 buffer is allocated.
    np.divide(pixels, 255, out=pixels)
    return np.expand_dims(pixels, axis=1)
def accuracy(pred, target):
    '''
    Computes the fraction of rows whose arg-max class matches.
    Args
        pred: 2-D array of per-class scores, one row per sample.
        target: 2-D array of the same layout holding the reference
            scores (e.g. one-hot labels).
    Return
        Number of rows where the arg-max of pred equals the arg-max
        of target, divided by the number of rows in target.
    '''
    predicted = np.argmax(pred, axis=1)
    expected = np.argmax(target, axis=1)
    num_correct = np.sum(predicted == expected, dtype='int')
    return num_correct / float(expected.shape[0])
if __name__ == '__main__':
    # Script entry point: parse the command line, load and preprocess the
    # dataset, build a small CNN from autograd layers and train it with SGD.
    parser = argparse.ArgumentParser(description='Train CNN over MNIST')
    parser.add_argument('file_path', type=str, help='the dataset path')
    parser.add_argument('--use_cpu', action='store_true')
    args = parser.parse_args()

    # Validate explicitly rather than with `assert`, which is removed when
    # Python runs with -O and would let a bad path reach np.load.
    if not os.path.exists(args.file_path):
        parser.error(
            'Pls download the MNIST dataset from https://s3.amazonaws.com/img-datasets/mnist.npz')

    if args.use_cpu:
        print('Using CPU')
        dev = device.get_default_device()
    else:
        print('Using GPU')
        dev = device.create_cuda_gpu()

    train, test = load_data(args.file_path)

    batch_size = 100
    num_classes = 10
    epochs = 1

    sgd = opt.SGD(lr=0.01)

    x_train = preprocess(train[0])
    y_train = to_categorical(train[1], num_classes)

    # The test split is preprocessed and reported below but is not used
    # further in this section.
    x_test = preprocess(test[0])
    y_test = to_categorical(test[1], num_classes)

    # Derive the number of full mini-batches from the data instead of
    # hard-coding it; this is 600 for a 60000-sample training set.
    batch_number = x_train.shape[0] // batch_size

    print('the shape of training data is', x_train.shape)
    print('the shape of training label is', y_train.shape)
    print('the shape of testing data is', x_test.shape)
    print('the shape of testing label is', y_test.shape)

    # operations initialization
    conv1 = autograd.Conv2d(1, 32, 3, padding=1, bias=False)
    bn1 = autograd.BatchNorm2d(32)
    # Two parallel convolutions whose outputs are concatenated in forward().
    conv21 = autograd.Conv2d(32, 16, 3, padding=1)
    conv22 = autograd.Conv2d(32, 16, 3, padding=1)
    bn2 = autograd.BatchNorm2d(32)
    # NOTE(review): 32 * 28 * 28 assumes 28x28 inputs whose spatial size is
    # preserved by every padded conv/pool above - confirm.
    linear = autograd.Linear(32 * 28 * 28, 10)
    pooling1 = autograd.MaxPool2d(3, 1, padding=1)
    pooling2 = autograd.AvgPool2d(3, 1, padding=1)

    def forward(x, t):
        '''
        Runs the network on one mini-batch.
        Args
            x: input tensor for the batch.
            t: target tensor for the batch.
        Return
            The tuple (loss, y): the softmax cross-entropy of the
            network output against t, and the output of the final
            linear layer.
        '''
        y = conv1(x)
        y = autograd.relu(y)
        y = bn1(y)
        y = pooling1(y)

        # Branch into two convolutions and join them along axis 1.
        y1 = conv21(y)
        y2 = conv22(y)
        y = autograd.cat((y1, y2), 1)
        # The same bn2 layer object is applied before and after the ReLU.
        y = bn2(y)
        y = autograd.relu(y)
        y = bn2(y)
        y = pooling2(y)

        y = autograd.flatten(y)
        y = linear(y)
        loss = autograd.softmax_cross_entropy(y, t)
        return loss, y

    autograd.training = True
    for epoch in range(epochs):
        for i in range(batch_number):
            start = i * batch_size
            stop = (1 + i) * batch_size

            inputs = tensor.Tensor(device=dev, data=x_train[start:stop],
                                   stores_grad=False)
            targets = tensor.Tensor(device=dev, data=y_train[start:stop],
                                    requires_grad=False, stores_grad=False)

            loss, y = forward(inputs, targets)

            accuracy_rate = accuracy(tensor.to_numpy(y),
                                     tensor.to_numpy(targets))
            # Report progress every fifth mini-batch.
            if (i % 5 == 0):
                print('accuracy is:', accuracy_rate, 'loss is:',
                      tensor.to_numpy(loss)[0])

            # Apply the SGD update to each (parameter, gradient) pair
            # yielded by the backward pass.
            for p, gp in autograd.backward(loss):
                sgd.update(p, gp)

            # NOTE(review): the original indentation was lost; step() is
            # placed once per mini-batch, after all updates - confirm.
            sgd.step()