blob: 7b3eae10fc5f0e9f7a068dd61e5af294c76eed63 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
from __future__ import print_function
import copy
import os
import tempfile

import numpy as np
import pytest

import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.test_utils import *
def make_subgraph(subg, *args):
    """Return ``subg`` unchanged after serializing it with ``tojson()``.

    The JSON text is discarded; the call is kept so that a symbol which
    cannot be serialized makes the caller fail here. ``args`` (the input
    symbols the subgraph would be wired to) are accepted for call-site
    compatibility but are not used.
    """
    subg.tojson()
    return subg
@pytest.mark.serial
def test_make_subgraph():
    """Check that symbols passed through ``make_subgraph`` match the originals.

    For every graph builder in ``make_subgraphs`` and every storage type in
    ``stypes``, the original symbol and its ``make_subgraph`` counterpart are
    bound to the same inputs. Both are run forward in training mode and
    back-propagated with the same random head gradients. Outputs and argument
    gradients must agree within rtol=1e-3 / atol=1e-4.
    """
    def make_subgraph1(stype):
        # Elementwise graph (a * b) * 2, built once directly (d) and once
        # with a * b routed through make_subgraph (y).
        a = mx.symbol.Variable(name='a', stype=stype)
        b = mx.symbol.Variable(name='b', stype=stype)
        c = a * b
        d = c * 2
        a1 = mx.symbol.Variable(name='a', stype=stype)
        b1 = mx.symbol.Variable(name='b', stype=stype)
        y = make_subgraph(c, a1, b1)
        y = y * 2
        s = (10, 10)
        a_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
                            ctx=default_device()).tostype(stype)
        b_arr = mx.nd.array(np.random.normal(-0.1, 0.1, size=s),
                            ctx=default_device()).tostype(stype)
        return (d, y, {'a': a_arr, 'b': b_arr}, {})

    def create_weights(shapes, names):
        # Return ({name: all-ones NDArray}, {name: Variable}) for each
        # (shape, name) pair; shapes and names must be the same length.
        assert len(shapes) == len(names)
        nd_dict = {}
        sym_dict = {}
        for shape, name in zip(shapes, names):
            sym_dict[name] = mx.symbol.Variable(name)
            nd_dict[name] = mx.nd.array(np.ones(shape), ctx=default_device())
        return (nd_dict, sym_dict)

    def make_subgraph_weight(orig, shape, stype):
        # Build (orig, subgraph, arg arrays, aux arrays) for a graph whose
        # input is 'data' of the given shape. Every argument after the
        # first is treated as a weight; this assumes list_arguments()[0] is
        # 'data' (TODO confirm for new graphs added here).
        arg_shapes, _out_shapes, aux_shapes = orig.infer_shape(data=shape)
        weight_shapes = arg_shapes[1:]
        weight_names = orig.list_arguments()[1:]
        weight_dict, weight_sym_dict = create_weights(weight_shapes, weight_names)
        aux_dict, aux_sym_dict = create_weights(aux_shapes, orig.list_auxiliary_states())
        input_dict = copy.deepcopy(weight_sym_dict)
        input_dict.update(aux_sym_dict)
        input_dict['data'] = mx.symbol.Variable('data', stype=stype)
        # Feed the subgraph its inputs in the same order as the original.
        input_list = []
        for name in orig.list_inputs():
            assert name in input_dict
            input_list.append(input_dict[name])
        subg = make_subgraph(orig, *input_list)
        arr = mx.nd.random.uniform(-1, 1, shape=shape, ctx=default_device()).tostype(stype)
        arg_dict = weight_dict
        arg_dict['data'] = arr
        return (orig, subg, arg_dict, aux_dict)

    def make_subgraph2(stype, out_mean_var):
        # A single BatchNorm, optionally also emitting batch mean/var.
        data = mx.symbol.Variable('data', stype=stype)
        orig = mx.symbol.BatchNorm(data, fix_gamma=False,
                                   output_mean_var=out_mean_var, name="batchnorm")
        s = (10, 10)
        return make_subgraph_weight(orig, s, stype)

    def make_subgraph3(stype):
        # Two Convolution+BatchNorm branches on the same input, summed.
        data = mx.symbol.Variable('data', stype=stype)
        conv1 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
        bn1 = mx.symbol.BatchNorm(conv1, fix_gamma=False, output_mean_var=False)
        conv2 = mx.symbol.Convolution(data=data, kernel=(3, 3), num_filter=16, no_bias=True)
        bn2 = mx.symbol.BatchNorm(conv2, fix_gamma=False, output_mean_var=False)
        orig = bn1 + bn2
        s = (1, 3, 32, 32)
        return make_subgraph_weight(orig, s, stype)

    def make_subgraph4(stype):
        # Full ResNet-18 graph, obtained by exporting a hybridized Gluon model.
        model = get_model('resnet18_v1')
        model.hybridize()
        model.initialize()
        s = (1, 3, 32, 32)
        # One forward pass before export, so the hybridized block has a
        # graph to export.
        model(mx.np.random.normal(size=s))
        # Export into a temporary directory so the test leaves no
        # resnet18-* files behind in the working directory.
        with tempfile.TemporaryDirectory() as tmpdir:
            prefix = os.path.join(tmpdir, 'resnet18')
            model.export(prefix)
            orig = mx.sym.load(prefix + '-symbol.json')
        return make_subgraph_weight(orig, s, stype)

    make_subgraphs = [make_subgraph1,
                      lambda stype: make_subgraph2(stype, False),
                      lambda stype: make_subgraph2(stype, True),
                      make_subgraph3, make_subgraph4]
    stypes = ['default', 'row_sparse']
    for make_subg in make_subgraphs:
        for stype in stypes:
            orig, subg, inputs, aux_states = make_subg(stype)
            all_inputs = copy.deepcopy(inputs)
            all_inputs.update(aux_states)
            # all_inputs (args + aux) is passed both as args and as
            # aux_states; both accept a name-keyed dict.
            # Each executor gets its own gradient buffers.
            args_grad = {key: mx.nd.empty(shape=arr.shape) for key, arr in all_inputs.items()}
            e1 = orig._bind(ctx=default_device(), args=all_inputs, args_grad=args_grad,
                            aux_states=all_inputs)
            args_grad = {key: mx.nd.empty(shape=arr.shape) for key, arr in all_inputs.items()}
            e2 = subg._bind(ctx=default_device(), args=all_inputs, args_grad=args_grad,
                            aux_states=all_inputs)
            e1.forward(is_train=True)
            e2.forward(is_train=True)
            # Check counts first so zip() cannot silently drop extra outputs.
            assert len(e1.outputs) == len(e2.outputs)
            for out1, out2 in zip(e1.outputs, e2.outputs):
                assert_almost_equal(out1.asnumpy(), out2.asnumpy(),
                                    rtol=0.001, atol=0.0001)
            out_grads = [mx.nd.random.uniform(-1, 1, shape=out.shape, ctx=default_device())
                         for out in e1.outputs]
            e1.backward(out_grads)
            e2.backward(out_grads)
            assert len(e1.grad_arrays) == len(e2.grad_arrays)
            for grad1, grad2 in zip(e1.grad_arrays, e2.grad_arrays):
                assert_almost_equal(grad1.asnumpy(), grad2.asnumpy(),
                                    rtol=0.001, atol=0.0001)
@pytest.mark.serial
def test_subgraph_with_customOp():
    """Run two separately registered Custom ops that share one implementation.

    ``MyAdd`` adds 1 to its single input in forward and passes the head
    gradient through unchanged in backward. It is registered under the op
    types 'MyAdd1' and 'MyAdd2'. Each op is bound on CPU to an all-zero
    (100, 100) input, and its output is checked to be all ones.
    """
    class MyAdd(mx.operator.CustomOp):
        def forward(self, is_train, req, in_data, out_data, aux):
            self.assign(out_data[0], req[0], in_data[0] + 1)

        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
            self.assign(in_grad[0], req[0], out_grad[0])

    class _MyAddProp(mx.operator.CustomOpProp):
        # Shared properties for both registrations: input 'data',
        # output 'output' of the same shape, no aux states.
        def __init__(self):
            super(_MyAddProp, self).__init__(need_top_grad=True)

        def list_arguments(self):
            return ['data']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            # inputs, outputs, aux
            return [in_shape[0]], [in_shape[0]], []

        def create_operator(self, ctx, shapes, dtypes):
            return MyAdd()

    @mx.operator.register('MyAdd1')
    class MyAdd1Prop(_MyAddProp):
        pass

    @mx.operator.register('MyAdd2')
    class MyAdd2Prop(_MyAddProp):
        pass

    inp = mx.nd.zeros(shape=(100, 100))
    a = mx.symbol.Variable('a')
    b = mx.symbol.Custom(data=a, op_type='MyAdd1')
    c = mx.symbol.Custom(data=a, op_type='MyAdd2')
    expected = np.ones((100, 100))
    for sym in (b, c):
        exe = sym._bind(mx.cpu(), {'a': inp})
        exe.forward()
        # MyAdd computes in_data + 1, and the input is all zeros.
        assert_almost_equal(exe.outputs[0].asnumpy(), expected)
    mx.nd.waitall()