Rename call_cached to invoke and split the NDArray::Reshape size check by autograd training mode
diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py
index fdf7dcb..a678e17 100644
--- a/python/mxnet/_ctypes/ndarray.py
+++ b/python/mxnet/_ctypes/ndarray.py
@@ -81,7 +81,7 @@
for i in range(num_output.value)]
-def call_cached(cached_op, args, out=None, name=None): # pylint: disable=unused-argument
+def invoke(cached_op, args, out=None, name=None): # pylint: disable=unused-argument
"""ctypes implementation of imperative invoke wrapper"""
if out is not None:
original_output = out
diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py
index d9cae1b..2ffa1a9 100644
--- a/python/mxnet/_ctypes/symbol.py
+++ b/python/mxnet/_ctypes/symbol.py
@@ -102,7 +102,7 @@
_symbol_cls = cls
-def call_cached(cached_op, args, name=None):
+def invoke(cached_op, args, name=None):
"""Call cached symbolic operator"""
ret = SymbolHandle()
hint = cached_op.op.lower()
diff --git a/python/mxnet/cython/ndarray.pyx b/python/mxnet/cython/ndarray.pyx
index e99549d..24e37b5 100644
--- a/python/mxnet/cython/ndarray.pyx
+++ b/python/mxnet/cython/ndarray.pyx
@@ -61,7 +61,7 @@
return nd
-def call_cached(cached_op, args, out=None, name=None):
+def invoke(cached_op, args, out=None, name=None):
- """ctypes implementation of imperative invoke wrapper"""
+ """Cython implementation of imperative invoke wrapper"""
cdef vector[NDArrayHandle] ndvars
cdef vector[NDArrayHandle] output_vars
diff --git a/python/mxnet/cython/symbol.pyx b/python/mxnet/cython/symbol.pyx
index cf7bda4..e8787fb 100644
--- a/python/mxnet/cython/symbol.pyx
+++ b/python/mxnet/cython/symbol.pyx
@@ -79,7 +79,7 @@
return sym
-def call_cached(cached_op, args, name=None):
+def invoke(cached_op, args, name=None):
cdef SymbolHandle ret
cdef vector[SymbolHandle] sym_args
hint = cached_op.op.lower()
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 15b6f7c..c5d6754 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -32,18 +32,18 @@
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class
- from ._ctypes.ndarray import call_cached, CachedOp, _imperative_invoke
+ from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
elif _sys.version_info >= (3, 0):
from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
- from ._cy3.ndarray import call_cached, CachedOp, _imperative_invoke
+ from ._cy3.ndarray import invoke, CachedOp
else:
from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
- from ._cy2.ndarray import call_cached, CachedOp, _imperative_invoke
+ from ._cy2.ndarray import invoke, CachedOp
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke
- from ._ctypes.ndarray import call_cached, CachedOp, _imperative_invoke
+ from ._ctypes.ndarray import invoke, CachedOp
# pylint: enable=unused-import
# pylint: disable= no-member
diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 6e2f2e6..16cbeae 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -29,18 +29,18 @@
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from ._ctypes.symbol import SymbolBase, _set_symbol_class
- from ._ctypes.symbol import CachedOp, call_cached, _symbol_creator # pylint: disable=unused-import
+ from ._ctypes.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import
elif _sys.version_info >= (3, 0):
from ._cy3.symbol import SymbolBase, _set_symbol_class
- from ._cy3.symbol import CachedOp, call_cached, _symbol_creator # pylint: disable=unused-import
+ from ._cy3.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import
else:
from ._cy2.symbol import SymbolBase, _set_symbol_class
- from ._cy2.symbol import CachedOp, call_cached, _symbol_creator # pylint: disable=unused-import
+ from ._cy2.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from ._ctypes.symbol import SymbolBase, _set_symbol_class
- from ._ctypes.symbol import CachedOp, call_cached, _symbol_creator # pylint: disable=unused-import
+ from ._ctypes.symbol import CachedOp, invoke, _symbol_creator # pylint: disable=unused-import
_GRAD_REQ_MAP = {'null': 0, 'write': 1, 'add': 3}
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 0a1e196..025624c 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -26,11 +26,12 @@
NDArray NDArray::Reshape(const TShape &shape) const {
using namespace autograd;
- CHECK_GE(shape_.Size(), shape.Size())
- << "NDArray.Reshape: target shape size is different from current shape";
- NDArray ret = *this;
- ret.shape_ = shape;
if (AutogradRuntime::Get()->IsTraining()) {
+ CHECK_EQ(shape_.Size(), shape.Size())
+ << "NDArray.Reshape: target shape must have the same size as "
+ << "current shape when in train_section.";
+ NDArray ret = *this;
+ ret.shape_ = shape;
// fake a Reshape op
ret.entry_.clear();
const nnvm::Op* op = nnvm::Op::Get("Reshape");
@@ -47,6 +48,10 @@
op, attrs, &inputs, &outputs);
return outputs[0];
} else {
+ CHECK_GE(shape_.Size(), shape.Size())
+ << "NDArray.Reshape: target shape size is larger than current shape";
+ NDArray ret = *this;
+ ret.shape_ = shape;
return ret;
}
}
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 5082f96..2be95a9 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -632,9 +632,9 @@
data = mx.nd.ones((3, 4, 10, 10))
weight = mx.nd.ones((10, 4, 3, 3))
bias = mx.nd.ones((10,))
- o1 = mx.nd.call_cached(op, [data, weight, bias])
+ o1 = mx.nd.invoke(op, [data, weight, bias])
bias[:] = 2
- o2 = mx.nd.call_cached(op, [data, weight, bias])
+ o2 = mx.nd.invoke(op, [data, weight, bias])
assert_almost_equal(o2.asnumpy(), o1.asnumpy()+1)
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 532657d..28fc8a4 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -230,12 +230,12 @@
data = mx.sym.var('data')
weight = mx.sym.var('weight')
bias = mx.sym.var('bias')
- out = mx.sym.call_cached(op, [data, weight, bias], 'conv')
+ out = mx.sym.invoke(op, [data, weight, bias], 'conv')
assert out.list_arguments() == ['data', 'weight', 'bias']
assert out.list_outputs() == ['conv_output']
with mx.name.Prefix('test_'):
- assert mx.sym.call_cached(op, [data, weight, bias]).name == 'test_convolution0'
- assert mx.sym.call_cached(op, [data, weight, bias]).name == 'test_convolution1'
+ assert mx.sym.invoke(op, [data, weight, bias]).name == 'test_convolution0'
+ assert mx.sym.invoke(op, [data, weight, bias]).name == 'test_convolution1'
if __name__ == '__main__':