# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import logging
import numpy as np
import mxnet as mx
class Softmax(mx.operator.CustomOp):
    def __init__(self):
        super(Softmax, self).__init__()
        # Each thread processes a row (a sample in the batch).
        fwd_src = r"""
        template<class DType>
        __global__ void fwd(const DType* x, DType* y, const int row_size, const int req) {
            const int offset = row_size * threadIdx.x;
            DType max = x[offset];
            for(int i = 1; i < row_size; ++i) {
                if(max < x[offset + i]) {
                    max = x[offset + i];
                }
            }
            DType sum = 0;
            for(int i = 0; i < row_size; ++i) {
                sum += exp(x[offset + i] - max);
            }
            switch(req) {
                case 1:
                    for(int i = 0; i < row_size; ++i) {
                        y[offset + i] = exp(x[offset + i] - max) / sum;
                    }
                    break;
                case 2:
                    for(int i = 0; i < row_size; ++i) {
                        y[offset + i] += exp(x[offset + i] - max) / sum;
                    }
                    break;
            }
        }
        """
        # Each block processes a row and each thread in a block calculates an element of `dx`.
        bwd_src = r"""
        template<class DType>
        __global__ void bwd(const DType* l, const DType* y, DType* dx, const int req) {
            const int z = static_cast<int>(l[blockIdx.x]);
            const int i = threadIdx.x + blockDim.x * blockIdx.x;
            if(req == 1) {
                dx[i] = threadIdx.x == z ? y[i] - 1 : y[i];
            } else {
                dx[i] += threadIdx.x == z ? y[i] - 1 : y[i];
            }
        }
        """
        fwd_kernel_mod = mx.rtc.CudaModule(fwd_src, exports=["fwd<float>", "fwd<double>"])
        bwd_kernel_mod = mx.rtc.CudaModule(bwd_src, exports=["bwd<float>", "bwd<double>"])

        fwd_kernel_float_signature = "const float*, float*, const int, const int"
        self.fwd_float_kernel = fwd_kernel_mod.get_kernel("fwd<float>", fwd_kernel_float_signature)

        bwd_kernel_float_signature = "const float*, const float*, float*, const int"
        self.bwd_float_kernel = bwd_kernel_mod.get_kernel("bwd<float>", bwd_kernel_float_signature)

        fwd_kernel_double_signature = "const double*, double*, const int, const int"
        self.fwd_double_kernel = fwd_kernel_mod.get_kernel("fwd<double>", fwd_kernel_double_signature)

        bwd_kernel_double_signature = "const double*, const double*, double*, const int"
        self.bwd_double_kernel = bwd_kernel_mod.get_kernel("bwd<double>", bwd_kernel_double_signature)
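        # The signature strings list the kernel parameter types in order; the
        # output pointers are left non-const so that they match the kernel
        # declarations above.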

    def forward(self, is_train, req, in_data, out_data, aux):
        if req[0] == "null":
            return
        x = in_data[0]  # input
        y = out_data[0]  # output
        if y.dtype == np.float64:
            # args, ctx, grid_shape, block_shape, shared_mem = 0
            self.fwd_double_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))
        else:
            # args, ctx, grid_shape, block_shape, shared_mem = 0
            self.fwd_float_kernel.launch((x, y, x.shape[1], self._reqCode(req[0])), mx.gpu(0), (1, 1, 1), (x.shape[0], 1, 1))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        if req[0] == "null":
            return
        l = in_data[1]  # label
        y = out_data[0]  # output from the forward pass
        dx = in_grad[0]  # the storage for the gradient
        if dx.dtype == np.float64:
            # args, ctx, grid_shape, block_shape, shared_mem = 0
            self.bwd_double_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))
        else:
            # args, ctx, grid_shape, block_shape, shared_mem = 0
            self.bwd_float_kernel.launch((l, y, dx, self._reqCode(req[0])), mx.gpu(0), (y.shape[0], 1, 1), (y.shape[1], 1, 1))

    def _reqCode(self, req):
        if req == "write":
            return 1
        elif req == "add":
            return 2
        elif req == "null":
            return 0
        else:
            raise ValueError("Invalid value of `req`: {}".format(req))


@mx.operator.register("softmax")
class SoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(SoftmaxProp, self).__init__(need_top_grad=False)
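        # need_top_grad=False because the backward kernel derives the loss
        # gradient from the label itself; no gradient from the layer above
        # is needed.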

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape], []

    def infer_type(self, in_type):
        return in_type, [in_type[0]], []

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return Softmax()
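

# A quick sanity check of the registered op (a sketch, not required for
# training): it assumes a GPU at mx.gpu(0), since the kernels are CUDA-only,
# and uses toy shapes.  Each row of the output should sum to one.
_x = mx.nd.random.uniform(shape=(4, 10), ctx=mx.gpu(0))
_l = mx.nd.array([0, 1, 2, 3], ctx=mx.gpu(0))
_out = mx.nd.Custom(_x, _l, op_type='softmax')
assert np.allclose(_out.asnumpy().sum(axis=1), 1.0, atol=1e-5)
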
# define mlp
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
#mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
mlp = mx.symbol.Custom(data=fc3, name='softmax', op_type='softmax')
# data
train, val = mx.test_utils.get_mnist_iterator(batch_size=100, input_shape=(784,))
# train
logging.basicConfig(level=logging.DEBUG)
context = mx.gpu(0)
mod = mx.mod.Module(mlp, context=context)
mod.fit(
    train_data=train,
    eval_data=val,
    optimizer='sgd',
    optimizer_params={'learning_rate': 0.1, 'momentum': 0.9, 'wd': 0.00001},
    num_epoch=10,
    batch_end_callback=mx.callback.Speedometer(100, 100)
)
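
# Optionally report the final validation accuracy of the trained module;
# `score` re-runs the module over `val` with the built-in accuracy metric.
print(mod.score(val, 'acc'))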