blob: e67e1fc1a3bc94f931339f3b9c6663246031f6ae [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import mxnet as mx
from common import models
import pytest
# NOTE(review): reset_np() appears to restore MXNet's default (non-NumPy)
# shape/array semantics before these symbol tests run — confirm against
# mxnet.numpy_extension docs; individual tests below opt back into NumPy
# semantics locally via `with mx.np_shape(...)`.
mx.npx.reset_np()
def test_mlp2_infer_shape():
    """Shape inference on the 2-layer MLP yields the expected argument
    and output shapes for a (100, 100) data batch."""
    sym = models.mlp2()
    batch_shape = (100, 100)
    arg_shapes, out_shapes, _ = sym.infer_shape(data=batch_shape)
    inferred = {name: shp for name, shp in zip(sym.list_arguments(), arg_shapes)}
    # a single network output of (batch, 10)
    assert len(out_shapes) == 1
    assert out_shapes[0] == (100, 10)
    expected = {
        'fc1_weight': (1000, 100),
        'fc1_bias': (1000,),
        'fc2_weight': (10, 1000),
        'fc2_bias': (10,),
    }
    for name, shape in expected.items():
        assert inferred[name] == shape
def test_mlp2_infer_error():
    """Shape inference must raise MXNetError when the user-supplied
    fc1_weight shape conflicts with the data shape."""
    out = models.mlp2()
    # fc1_weight should be (1000, 100) for 100-dim input; (1, 100) conflicts
    weight_shape = (1, 100)
    data_shape = (100, 100)
    with pytest.raises(mx.MXNetError):
        # return value is irrelevant here — only the raised error matters,
        # so don't unpack it into unused locals
        out.infer_shape(data=data_shape, fc1_weight=weight_shape)
def test_backward_infer():
    """Backward shape inference deduces the weight's shape from the
    FullyConnected layer through an identity shape constraint."""
    weight = mx.sym.Variable("weight")
    shift = mx.sym.Variable("wshift", shape=(1,))
    data = mx.sym.Variable("data")
    # broadcast add alone cannot deduce the weight shape correctly
    shifted = mx.sym.broadcast_add(weight, shift)
    # shape constraint tying the result back to `weight` — this is what
    # enables backward shape inference
    shifted = mx.symbol._internal._identity_with_attr_like_rhs(shifted, weight)
    net = mx.sym.FullyConnected(data=data, weight=shifted, num_hidden=11, no_bias=True)
    arg_shapes, _, _ = net.infer_shape(data=(7, 100))
    inferred = dict(zip(net.list_arguments(), arg_shapes))
    assert inferred['weight'] == (11, 100)
def test_incomplete_infer_elewise():
    """Partially-known operand shapes of an elementwise add complete
    each other."""
    lhs = mx.sym.Variable('a', shape=(0, 10))   # first dim unknown
    rhs = mx.sym.Variable('b', shape=(12, 0))   # second dim unknown
    total = lhs + rhs
    shapes, _, _ = total.infer_shape()
    inferred = dict(zip(total.list_arguments(), shapes))
    assert inferred['a'] == (12, 10)
    assert inferred['b'] == (12, 10)
def test_incomplete_infer_mlp():
    """Unknown dims flow both forward (through FullyConnected) and
    backward (from the addend) during inference."""
    a = mx.sym.Variable('a', shape=(0, 10))     # batch dim unknown
    hidden = mx.sym.FullyConnected(data=a, num_hidden=21)
    c = mx.sym.Variable('c', shape=(5, 0))      # feature dim unknown
    summed = hidden + c
    shapes, _, _ = summed.infer_shape()
    inferred = dict(zip(summed.list_arguments(), shapes))
    assert inferred['a'] == (5, 10)
    assert inferred['c'] == (5, 21)
def test_incomplete_infer_slicechannel():
    """SliceChannel propagates known dims of one output slice back to the
    unknown dims of its input, with and without squeeze_axis."""
    # squeeze_axis=True: each slice drops the sliced axis
    a = mx.sym.Variable('a', shape=(0, 10))
    slices = mx.sym.SliceChannel(data=a, num_outputs=10, axis=1, squeeze_axis=True)
    c = mx.sym.Variable('c', shape=(5,))
    out = slices[1] + c
    shapes, _, _ = out.infer_shape()
    inferred = dict(zip(out.list_arguments(), shapes))
    assert inferred['a'] == (5, 10)
    # squeeze_axis=False: slices retain rank, other dims still back-propagate
    a = mx.sym.Variable('a', shape=(0, 15, 0))
    slices = mx.sym.SliceChannel(data=a, num_outputs=3, squeeze_axis=False)
    c = mx.sym.Variable('c', shape=(3, 5, 2))
    out = slices[1] + c
    shapes, _, _ = out.infer_shape()
    inferred = dict(zip(out.list_arguments(), shapes))
    assert inferred['a'] == (3, 15, 2)
def test_incomplete_infer_convolution():
    """A same-padded 3x3 Convolution lets the addend's full shape pin
    down every unknown input dimension."""
    a = mx.sym.Variable('a', shape=(0, 10, 0, 0))   # only channels known
    conv = mx.sym.Convolution(data=a, num_filter=21, kernel=(3, 3), dilate=(1, 1), pad=(1, 1))
    c = mx.sym.Variable('c', shape=(5, 21, 32, 32))
    out = conv + c
    shapes, _, _ = out.infer_shape()
    inferred = dict(zip(out.list_arguments(), shapes))
    assert inferred['a'] == (5, 10, 32, 32)
def test_incomplete_infer_concat():
    """Concat's combined width plus a partially-shaped addend resolve
    every unknown dimension across all three variables."""
    a = mx.sym.Variable('a', shape=(0, 10))
    b = mx.sym.Variable('b', shape=(0, 5))
    joined = mx.sym.Concat(a, b, num_args=2, dim=1)
    d = mx.sym.Variable('d', shape=(2, 0))
    out = d + joined
    shapes, _, _ = out.infer_shape()
    inferred = dict(zip(out.list_arguments(), shapes))
    assert inferred['a'] == (2, 10)
    assert inferred['b'] == (2, 5)
    assert inferred['d'] == (2, 15)
def test_fc_infer_type():
    """Type inference on FullyConnected propagates the default real
    dtype to the output and to both learned arguments."""
    real_t = mx.base.mx_real_t
    data = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000)
    arg_types, out_types, _ = net.infer_type(data=real_t)
    inferred = dict(zip(net.list_arguments(), arg_types))
    # single output, same dtype as the input
    assert len(out_types) == 1
    assert out_types[0] == real_t
    for name in ('fc1_weight', 'fc1_bias'):
        assert inferred[name] == real_t
def test_shape_completely_unknown():
    """With no shape hints at all, partial inference reports the unknown
    shape as () in legacy mode and as None under NumPy shape semantics."""
    sym = mx.sym.sin(mx.sym.var("data"))
    args, outs, _ = sym.infer_shape_partial()
    assert args[0] == ()
    assert outs[0] == ()
    with mx.np_shape():
        sym = mx.sym.sin(mx.sym.var("data"))
        args, outs, _ = sym.infer_shape_partial()
        assert args[0] is None
        assert outs[0] is None
def test_dot_partial_shape():
    """dot() with an unknown lhs batch dim still infers the trailing
    dims; the unknown marker is 0 (legacy) or -1 (np_shape)."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    prod = mx.sym.dot(lhs, rhs)
    # first dim of lhs unknown
    _, out, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(4, 5))
    assert out == [(0, 3, 5)]
    with mx.np_shape(True):
        _, out, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(4, 5))
        assert out == [(-1, 3, 5)]
def test_batch_dot_partial_shape():
    """batch_dot tolerates an unknown batch dim, but an unknown inner
    (reduction) dim leaves the whole output shape unknown."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    prod = mx.sym.batch_dot(lhs, rhs)
    # batch dim unknown on both operands: still resolvable
    _, out, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert out == [(0, 3, 5)]
    # rhs second dim unknown: output cannot be inferred
    _, out, _ = prod.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
    assert out == [()]
    with mx.np_shape(True):
        _, out, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert out == [(-1, 3, 5)]
        _, out, _ = prod.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5))
        assert out == [None]
def test_embedding_partial_shape():
    """Embedding lookup with an unknown batch size still infers the
    trailing (seq, output_dim) dims of the result."""
    tokens = mx.sym.Variable("x")
    table = mx.sym.Variable("w")
    emb = mx.sym.Embedding(data=tokens, weight=table, input_dim=100, output_dim=10)
    _, out, _ = emb.infer_shape_partial(x=(0, 5), w=(100, 10))
    assert out == [(0, 5, 10)]
    with mx.np_shape(True):
        _, out, _ = emb.infer_shape_partial(x=(-1, 5), w=(100, 10))
        assert out == [(-1, 5, 10)]
def test_transpose_partial_shape():
    """Transposing NCHW -> NHWC with an unknown batch dim permutes the
    known dims and carries the unknown marker through."""
    perm = [0, 3, 2, 1]
    x = mx.sym.Variable("x")
    flipped = mx.sym.transpose(x, axes=perm)
    _, out, _ = flipped.infer_shape_partial(x=(0, 3, 224, 224))
    assert out == [(0, 224, 224, 3)]
    with mx.np_shape(True):
        _, out, _ = flipped.infer_shape_partial(x=(-1, 3, 224, 224))
        assert out == [(-1, 224, 224, 3)]
def test_pick_partial_shape():
    """pick() along axis 1 with an unknown batch dim infers the reduced
    output shape, preserving the unknown marker."""
    data = mx.sym.Variable("x")
    idx = mx.sym.Variable("index")
    picked = mx.sym.pick(data, idx, axis=1)
    # batch size unknown
    _, out, _ = picked.infer_shape_partial(x=(0, 3, 3), index=(0, 3,))
    assert out == [(0, 3)]
    with mx.np_shape(True):
        _, out, _ = picked.infer_shape_partial(x=(-1, 3, 3), index=(-1, 3,))
        assert out == [(-1, 3)]
def test_where_partial_shape():
    """where() cannot infer its output while the condition shape is
    partially unknown — the result stays fully unknown."""
    x = mx.sym.Variable("x")
    y = mx.sym.Variable("y")
    cond = mx.sym.Variable("cond")
    selected = mx.sym.where(cond, x, y)
    # condition must be fully known for the output shape to resolve
    _, out, _ = selected.infer_shape_partial(cond=(0, 2), x=(0, 2), y=(0, 2))
    assert out == [()]
    _, out, _ = selected.infer_shape_partial(cond=(0,), x=(2, 2), y=(2, 2))
    assert out == [()]
    with mx.np_shape(True):
        _, out, _ = selected.infer_shape_partial(cond=(-1, 2), x=(-1, 2), y=(-1, 2))
        assert out == [None]
        _, out, _ = selected.infer_shape_partial(cond=(-1,), x=(2, 2), y=(2, 2))
        assert out == [None]