Fix parameter initialization (#6728) * fix * fix parameters initialization * refactor tutorial * fix * fix
diff --git a/docs/tutorials/basic/foo.md b/docs/tutorials/basic/foo.md deleted file mode 100644 index 84b1427..0000000 --- a/docs/tutorials/basic/foo.md +++ /dev/null
@@ -1,291 +0,0 @@ -# Foo - High-level Interface - -Foo package is a high-level interface for MXNet designed to be easy to use while -keeping most of the flexibility of low level API. Foo supports both imperative -and symbolic programming, making it easy to train complex models imperatively -in Python and then deploy with symbolic graph in C++ and Scala. - -This tutorial covers four topics: -- MXNet NDArray as a replacement of numpy for asynchronous scientific computing -across CPU and GPU. -- Automatic differentiation with NDArray. -- Define and train neural network models with Foo's imperative API. -- [TODO] Save trained models as symbolic graph for easy production deployment. - -## Setup -First, let's import MXNet and Foo: - -```python -from __future__ import print_function -import numpy as np -import mxnet as mx -``` - -## NDArray - -### Creating NDArray - -NDArray is similar to numpy's ndarray, but supports asynchronous operations -and GPU. There are many ways to create NDArray. - -Construct from (nested) list: -```python -x = mx.nd.array([[1, 2, 3], [4, 5, 6]]) -print(x) -``` - -Construct from numpy array: -```python -x_numpy = np.ones((2, 3)) -x = mx.nd.array(x_numpy) -print(x) -``` - -Array construction routines: -```python -# create an 2x3 array of ones -x = mx.nd.ones((2, 3)) -print(x) -# create an 2x3 array of zeros -x = mx.nd.zeros((2, 3)) -print(x) -# create an 1d-array of 0 to 5 and reshape to 2x3 -x = mx.nd.arange(6).reshape((2, 3)) -print(x) -``` - -You can convert any NDArray to numpy array with `.asnumpy()`: -```python -z = x.asnumpy() -print(z) -``` - -### NDArray Operations - -NDArray supports a wide range of operations. 
Simple operations can be called -with python syntax: - -```python -x = mx.nd.array([[1, 2], [3, 4]]) -y = mx.nd.array([[4, 3], [2, 1]]) -print(x + y) -``` - -You can also call operators from the `mxnet.ndarray` (or `mx.nd` for short) name space: - -```python -z = mx.nd.add(x, y) -print(z) -``` - -You can also pass additional flags to operators: - -```python -z = mx.nd.sum(x, axis=0) -print('axis=0:', z) -z = mx.nd.sum(x, axis=1) -print('axis=1:', z) -``` - -By default operators create new NDArrays for return value. You can specify `out` -to use a pre-allocated buffer: - -```python -z = mx.nd.empty((2, 2)) -mx.nd.add(x, y, out=z) -print(x) -``` - -### Using GPU - -Each NDArray lives on a `Context`. MXNet supports `mx.cpu()` for CPU and `mx.gpu(0)`, -`mx.gpu(1)`, etc for GPU. You can specify context when creating NDArray: - -```python -# creates on CPU (the default). -# Replace mx.cpu() with mx.gpu(0) if you have a GPU. -x = mx.nd.zeros((2, 2), ctx=mx.cpu()) -print(x) -x = mx.nd.array([[1, 2], [3, 4]], ctx=mx.cpu()) -print(x) -``` - -You can copy arrays between devices with `.copyto()`: - -```python -# Copy x to cpu. Replace with mx.gpu(0) if you have GPU. -y = x.copyto(mx.cpu()) -# Copy x to another NDArray, possibly on another Context. -x.copyto(y) -print(y) -``` - -See the [NDArray tutorial](ndarray.md) for a more detailed introduction to -NDArray API. - -## Automatic Differentiation - -MXNet supports automatic differentiation with the `autograd` package. -`autograd` allows you to differentiate a network of NDArray operations. -This is call define-by-run, i.e., the network is defined on-the-fly by -running forward computation. You can define exotic network structures -and differentiate them, and each iteration can have a totally different -network structure. 
- -```python -form mxnet import autograd -from mxnet.autograd import train_section -``` - -To use `autograd`, we must first mark variables that require gradient and -attach gradient buffers to them: - -```python -x = mx.nd.array([[1, 2], [3, 4]]) -dx = mx.nd.zeros_like(x) -x.attach_grad(dx) -``` - -Now we can define the network while running forward computation by wrapping -it inside a `train_section` (operations out of `train_section` does not define -a graph and cannot be differentiated): - -```python -with train_section(): - y = x * 2 - z = y * x -``` - -Let's backprop with `z.backward()`, which is equivalent to -`z.backward(mx.nd.ones_like(z))`. When z has more than one entry, `z.backward()` -is equivalent to `mx.nd.sum(z).backward()`: - -```python -z.backward() -print(x.grad) -``` - -## Neural Network and Layers - -Neural networks (and other machine learning models) can be defined and trained -with `foo.nn` and `foo.rnn` package. A typical training script has the following -steps: - -- Define network -- Initialize parameters -- Loop over inputs -- Forward input through network to get output -- Compute loss with output and label -- Backprop gradient -- Update parameters with gradient descent. - - -### Define Network - -`foo.nn.Layer` is the basic building block of models. You can define networks by -composing and inheriting `Layer`: - -```python -import mxnet.foo as foo -from mxnet.foo import nn - -class Net(nn.Layer): - def __init__(self, **kwargs): - super(Net, self).__init__(**kwargs) - with self.name_scope: - # layers created in name_scope will inherit name space - # from parent layer. 
- self.conv1 = nn.Conv2D(6, kernel_size=5) - self.pool1 = nn.Pool2D(kernel_size=2) - self.conv2 = nn.Conv2D(16, kernel_size=5) - self.pool2 = nn.Pool2D(kernel_size=2) - self.fc1 = nn.Dense(120) - self.fc2 = nn.Dense(84) - self.fc3 = nn.Dense(10) - - def forward(self, F, x): - x = self.pool1(F.relu(self.conv1(x))) - x = self.pool2(F.relu(self.conv2(x))) - # 0 means copy over size from corresponding dimension. - # -1 means infer size from the rest of dimensions. - x = x.reshape((0, -1)) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x -``` - -### Initialize Parameters - -A network must be created and initialized before it can be used: - -```python -net = Net() -# Initialize on CPU. Replace with `mx.gpu(0)`, or `[mx.gpu(0), mx.gpu(1)]`, -# etc to use one or more GPUs. -net.all_params().initialize(mx.init.Xavier(), ctx=mx.cpu()) -``` - -Note that because we didn't specify input size to layers in Net's constructor, -the shape of parameters cannot be determined at this point. Actual initialization -is deferred to the first forward pass, i.e. if you access `net.fc1.weight.data()` -now an exception will be raised. - -You can actually initialize the weights by running a forward pass: - -```python -data = mx.nd.random_normal(shape=(10, 1, 32, 32)) # dummy data -output = net(data) -``` - -Or you can specify input size when creating layers, i.e. `nn.Dense(84, in_units=120)` -instead of `nn.Dense(84)`. - -### Loss Functions - -Loss functions take (output, label) pairs and compute a scalar loss for each sample -in the mini-batch. The scalars measure how far each output is from the label. - -There are many predefined loss functions in `foo.loss`. Here we use -`softmax_cross_entropy_loss` for digit classification. 
- -To compute loss and backprop for one iteration, we do: - -```python -label = mx.nd.arange(10) # dummy label -with train_section(): - output = net(data) - loss = foo.loss.softmax_cross_entropy_loss(output, label) - loss.backward() -print('loss:', loss) -print('grad:', net.fc1.weight.grad()) -``` - -### Updating the weights - -Now that gradient is computed, we just need to update the weights. This is usually -done with formulas like `weight = weight - learning_rate * grad / batch_size`. -Note we divide gradient by batch_size because gradient is aggregated over the -entire batch. For example, - -```python -lr = 0.01 -for p in net.all_params().values(): - p.data()[:] -= lr / data.shape[0] * p.grad() -``` - -But sometimes you want more fancy updating rules like momentum and Adam, and since -this is a commonly used functionality, foo provide a `Trainer` class for it: - -```python -trainer = foo.Trainer(net.all_params(), 'sgd', {'learning_rate': 0.01}) - -with train_section(): - output = net(data) - loss = foo.loss.softmax_cross_entropy_loss(output, label) - loss.backward() - -# do the update. Trainer needs to know the batch size of data to normalize -# the gradient by 1/batch_size. -trainer.step(data.shape[0]) -```
diff --git a/docs/tutorials/foo/autograd.md b/docs/tutorials/foo/autograd.md new file mode 100644 index 0000000..5d1d615 --- /dev/null +++ b/docs/tutorials/foo/autograd.md
@@ -0,0 +1,41 @@ +# Automatic differentiation +
+MXNet supports automatic differentiation with the `autograd` package.
+`autograd` allows you to differentiate a graph of NDArray operations
+with the chain rule.
+This is called define-by-run, i.e., the network is defined on-the-fly by
+running forward computation. You can define exotic network structures
+and differentiate them, and each iteration can have a totally different
+network structure.
+
+```python
+import mxnet as mx
+from mxnet import autograd
+```
+
+To use `autograd`, we must first mark variables that require gradient and
+attach gradient buffers to them:
+
+```python
+x = mx.nd.array([[1, 2], [3, 4]])
+x.attach_grad()
+```
+
+Now we can define the network while running forward computation by wrapping
+it inside a `train_section` (operations outside of a `train_section` do not
+define a graph and cannot be differentiated):
+
+```python
+with autograd.train_section():
+    y = x * 2
+    z = y * x
+```
+
+Let's backprop with `z.backward()`, which is equivalent to
+`z.backward(mx.nd.ones_like(z))`. When z has more than one entry, `z.backward()`
+is equivalent to `mx.nd.sum(z).backward()`:
+
+```python
+z.backward()
+print(x.grad)
+```
diff --git a/docs/tutorials/foo/foo.md b/docs/tutorials/foo/foo.md new file mode 100644 index 0000000..c454e34 --- /dev/null +++ b/docs/tutorials/foo/foo.md
@@ -0,0 +1,136 @@ +# Foo - Neural network building blocks + +Foo package is a high-level interface for MXNet designed to be easy to use while +keeping most of the flexibility of low level API. Foo supports both imperative +and symbolic programming, making it easy to train complex models imperatively +in Python and then deploy with symbolic graph in C++ and Scala. + + +```python +# import dependencies +from __future__ import print_function +import numpy as np +import mxnet as mx +import mxnet.foo as foo +from mxnet.foo import nn +``` + +Neural networks (and other machine learning models) can be defined and trained +with `foo.nn` and `foo.rnn` package. A typical training script has the following +steps: + +- Define network +- Initialize parameters +- Loop over inputs +- Forward input through network to get output +- Compute loss with output and label +- Backprop gradient +- Update parameters with gradient descent. + + +## Define Network + +`foo.nn.Layer` is the basic building block of models. You can define networks by +composing and inheriting `Layer`: + +```python +class Net(nn.Layer): + def __init__(self, **kwargs): + super(Net, self).__init__(**kwargs) + with self.name_scope: + # layers created in name_scope will inherit name space + # from parent layer. + self.conv1 = nn.Conv2D(6, kernel_size=5) + self.pool1 = nn.Pool2D(kernel_size=2) + self.conv2 = nn.Conv2D(16, kernel_size=5) + self.pool2 = nn.Pool2D(kernel_size=2) + self.fc1 = nn.Dense(120) + self.fc2 = nn.Dense(84) + self.fc3 = nn.Dense(10) + + def forward(self, F, x): + x = self.pool1(F.relu(self.conv1(x))) + x = self.pool2(F.relu(self.conv2(x))) + # 0 means copy over size from corresponding dimension. + # -1 means infer size from the rest of dimensions. + x = x.reshape((0, -1)) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x +``` + +## Initialize Parameters + +A network must be created and initialized before it can be used: + +```python +net = Net() +# Initialize on CPU. 
Replace with `mx.gpu(0)`, or `[mx.gpu(0), mx.gpu(1)]`, +# etc to use one or more GPUs. +net.all_params().initialize(mx.init.Xavier(), ctx=mx.cpu()) +``` + +Note that because we didn't specify input size to layers in Net's constructor, +the shape of parameters cannot be determined at this point. Actual initialization +is deferred to the first forward pass, i.e. if you access `net.fc1.weight.data()` +now an exception will be raised. + +You can actually initialize the weights by running a forward pass: + +```python +data = mx.nd.random_normal(shape=(10, 1, 32, 32)) # dummy data +output = net(data) +``` + +Or you can specify input size when creating layers, i.e. `nn.Dense(84, in_units=120)` +instead of `nn.Dense(84)`. + +## Loss Functions + +Loss functions take (output, label) pairs and compute a scalar loss for each sample +in the mini-batch. The scalars measure how far each output is from the label. + +There are many predefined loss functions in `foo.loss`. Here we use +`softmax_cross_entropy_loss` for digit classification. + +To compute loss and backprop for one iteration, we do: + +```python +label = mx.nd.arange(10) # dummy label +with train_section(): + output = net(data) + loss = foo.loss.softmax_cross_entropy_loss(output, label) + loss.backward() +print('loss:', loss) +print('grad:', net.fc1.weight.grad()) +``` + +## Updating the weights + +Now that gradient is computed, we just need to update the weights. This is usually +done with formulas like `weight = weight - learning_rate * grad / batch_size`. +Note we divide gradient by batch_size because gradient is aggregated over the +entire batch. 
For example,
+
+```python
+lr = 0.01
+for p in net.all_params().values():
+    p.data()[:] -= lr / data.shape[0] * p.grad()
+```
+
+But sometimes you want more fancy updating rules like momentum and Adam, and since
+this is a commonly used functionality, foo provides a `Trainer` class for it:
+
+```python
+trainer = foo.Trainer(net.all_params(), 'sgd', {'learning_rate': 0.01})
+
+with train_section():
+    output = net(data)
+    loss = foo.loss.softmax_cross_entropy_loss(output, label)
+    loss.backward()
+
+# do the update. Trainer needs to know the batch size of data to normalize
+# the gradient by 1/batch_size.
+trainer.step(data.shape[0])
+```
diff --git a/docs/tutorials/foo/ndarray.md b/docs/tutorials/foo/ndarray.md new file mode 100644 index 0000000..bc5d9c4 --- /dev/null +++ b/docs/tutorials/foo/ndarray.md
@@ -0,0 +1,152 @@ +# NDArray - Scientific computing on CPU and GPU + +NDArray is a tensor data structure similar to numpy's multi-dimensional array. +In addition, it supports asynchronous computation on CPU and GPU. + +First, let's import MXNet: + +```python +from __future__ import print_function +import numpy as np +import mxnet as mx +``` + +## Creating NDArray + +There are many ways to create NDArray. + +Construct from (nested) list: +```python +x = mx.nd.array([[1, 2, 3], [4, 5, 6]]) +print(x) +``` + +Construct from numpy array: +```python +x_numpy = np.ones((2, 3)) +x = mx.nd.array(x_numpy) +print(x) +``` + +Array construction routines: +```python +# create an 2x3 array of ones +x = mx.nd.ones((2, 3)) +print(x) +# create an 2x3 array of zeros +x = mx.nd.zeros((2, 3)) +print(x) +# create an 1d-array of 0 to 5 and reshape to 2x3 +x = mx.nd.arange(6).reshape((2, 3)) +print(x) +``` + +You can convert an NDArray to numpy array to retrieve its data with `.asnumpy()`: +```python +z = x.asnumpy() +print(z) +``` + +## Basic attributes + +NDArray has some basic attributes that you often want to query: + +**NDArray.shape**: The dimensions of the array. It is a tuple of integers +indicating the length of the array along each axis. For a matrix with `n` rows +and `m` columns, its `shape` will be `(n, m)`. + +```python +print('x.shape:', x.shape) +``` + +**NDArray.dtype**: A `numpy` _type_ object describing the type of array +elements. + +```python +print('x.dtype:', x.dtype) +``` + +**NDArray.size**: the total number of components in the array - equals to the +product of the components of its `shape` + +```python +print('x.size:', x.size) +``` + +**NDArray.context**: The device on which this array is stored, e.g. `mx.cpu()` +or `mx.gpu(1)`. + +```python +print('x.context:', x.context) +``` + +## NDArray Operations + +NDArray supports a wide range of operations. 
Simple operations can be called +with python syntax: + +```python +x = mx.nd.array([[1, 2], [3, 4]]) +y = mx.nd.array([[4, 3], [2, 1]]) +print(x + y) +``` + +You can also call operators from the `mxnet.ndarray` (or `mx.nd` for short) name space: + +```python +z = mx.nd.add(x, y) +print(z) +``` + +You can also pass additional flags to operators: + +```python +z = mx.nd.sum(x, axis=0) +print('axis=0:', z) +z = mx.nd.sum(x, axis=1) +print('axis=1:', z) +``` + +By default operators create new NDArrays for return value. You can specify `out` +to use a pre-allocated buffer: + +```python +z = mx.nd.empty((2, 2)) +mx.nd.add(x, y, out=z) +print(x) +``` + +## Using GPU + +Each NDArray lives on a `Context`. MXNet supports `mx.cpu()` for CPU and `mx.gpu(0)`, +`mx.gpu(1)`, etc for GPU. You can specify context when creating NDArray: + +```python +# creates on CPU (the default). +# Replace mx.cpu() with mx.gpu(0) if you have a GPU. +x = mx.nd.zeros((2, 2), ctx=mx.cpu()) +print(x) +``` + +```python +x = mx.nd.array([[1, 2], [3, 4]], ctx=mx.cpu()) +print(x) +``` + +You can copy arrays between devices with `.copyto()`: + +```python +# Copy x to cpu. Replace with mx.gpu(0) if you have GPU. +y = x.copyto(mx.cpu()) +print(y) +``` + +```python +# Copy x to another NDArray, possibly on another Context. +y = mx.nd.zeros_like(x) +x.copyto(y) +print(y) +``` + +See the [Advanced NDArray tutorial](../basic/ndarray.md) for a more detailed +introduction to NDArray API.
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index dc56cb1..cb8a2ec 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md
@@ -4,13 +4,23 @@ ## Python -### Basics +### Basics - High-level interface ```eval_rst .. toctree:: :maxdepth: 1 - basic/foo + foo/ndarray + foo/autograd + foo/foo +``` + +### Advanced -- Low-level interface + +```eval_rst +.. toctree:: + :maxdepth: 1 + basic/ndarray basic/symbol basic/module
diff --git a/python/mxnet/foo/parameter.py b/python/mxnet/foo/parameter.py index 50c9c61..28d28f4 100644 --- a/python/mxnet/foo/parameter.py +++ b/python/mxnet/foo/parameter.py
@@ -116,6 +116,36 @@
         self._defered_init = (init, ctx, default_init)
         self._finish_deferred_init()
 
+    def _load_init(self, data, ctx):
+        """(Re)init by loading from data."""
+        if self.shape:
+            for i, j in zip(self.shape, data.shape):
+                assert i == 0 or i == j, \
+                    "Failed loading Parameter %s from saved params: " \
+                    "shape incompatible expected %s vs saved %s"%(
+                        self.name, str(self.shape), str(data.shape))
+        if self.dtype:
+            assert np.dtype(self.dtype).type == data.dtype, \
+                "Failed loading Parameter %s from saved params: " \
+                "dtype incompatible expected %s vs saved %s"%(
+                    self.name, str(self.dtype), str(data.dtype))
+        if isinstance(ctx, Context):
+            ctx = [ctx]
+        if self._data is None:
+            if self._defered_init:
+                assert set(ctx) == set(self._defered_init[1]), \
+                    "Failed to load Parameter %s on %s because it was " \
+                    "previously initialized on %s."%(
+                        self.name, str(ctx), str(self.list_ctx()))
+            self._init_impl(data, ctx)
+        else:
+            assert set(ctx) == set(self.list_ctx()), \
+                "Failed to load Parameter %s on %s because it was " \
+                "previously initialized on %s."%(
+                    self.name, str(ctx), str(self.list_ctx()))
+            self.set_data(data)
+        self._defered_init = ()
+
     def _finish_deferred_init(self):
         """Finish deferred initialization."""
         if not self._defered_init:
@@ -129,27 +159,30 @@
                     self.name, str(self.shape))
 
             with autograd.test_section():
-                data = ndarray.zeros(shape=self.shape, dtype=self.dtype, ctx=ctx[0])
+                data = ndarray.zeros(shape=self.shape, dtype=self.dtype,
+                                     ctx=context.cpu())
                 if init is None:
                     init = self.init
                 initializer.create(default_init)(
-                    initializer.InitDesc(self.name, {'__init__': init}),
-                    data)
+                    initializer.InitDesc(self.name, {'__init__': init}), data)
 
-            self._data = OrderedDict()
-            self._data[ctx[0]] = data
-            for i in ctx[1:]:
-                self._data[i] = data.copyto(i)
+            self._init_impl(data, ctx)
 
-            if self.grad_req == 'null':
-                self._grad = None
-                return
+    def _init_impl(self, data, ctx):
+        """Set data and grad."""
+        self._data = OrderedDict()
+        for i in ctx:
+            self._data[i] = data.copyto(i)
 
-            self._grad = OrderedDict()
-            for i in ctx:
-                self._grad[i] = ndarray.zeros_like(self._data[i])
+        if self.grad_req == 'null':
+            self._grad = None
+            return
 
-            autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
+        self._grad = OrderedDict()
+        for i in ctx:
+            self._grad[i] = ndarray.zeros_like(self._data[i])
+
+        autograd.mark_variables(self.list_data(), self.list_grad(), self.grad_req)
 
     def set_data(self, data):
         """Set this parameter's value on all contexts to data."""
@@ -365,7 +398,7 @@
                 arg_dict[param.name] = weight
         ndarray.save(filename, arg_dict)
 
-    def load(self, filename, allow_missing=False, ignore_extra=False):
+    def load(self, filename, ctx, allow_missing=False, ignore_extra=False):
         arg_dict = ndarray.load(filename)
         if not allow_missing:
             for name in self.keys():
@@ -377,4 +410,4 @@
                     "Parameter %s loaded from file %s is not present in ParameterDict"%(
                         name, filename)
                 continue
-            self[name].set_data(arg_dict[name])
+            self[name]._load_init(arg_dict[name], ctx)
diff --git a/python/mxnet/foo/trainer.py b/python/mxnet/foo/trainer.py index 514dfbd..8d22983 100644 --- a/python/mxnet/foo/trainer.py +++ b/python/mxnet/foo/trainer.py
@@ -68,6 +68,7 @@ kvstore.pull(i, param_arrays, priority=-i) if update_on_kvstore: kvstore.set_optimizer(self._optimizer) + self._kv_initialized = True def step(self, batch_size, ignore_stale_grad=False): """Make one step of parameter update. Should be called after
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index a70b81b..5e08fcf 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py
@@ -933,7 +933,7 @@ return self return self.copyto(context) - def set_grad(self, grad_req='write'): + def attach_grad(self, grad_req='write'): """Attach a gradient buffer to this NDArray, so that `backward` can compute gradient with respect to it.
diff --git a/tests/python/train/test_autograd.py b/tests/python/train/test_autograd.py new file mode 100644 index 0000000..25cd505 --- /dev/null +++ b/tests/python/train/test_autograd.py
@@ -0,0 +1,90 @@ +# pylint: skip-file +from __future__ import print_function + +import mxnet as mx +from mxnet import foo +from mxnet.foo import nn +import numpy as np +import logging +from common import get_data +from mxnet.contrib import autograd as ag +logging.basicConfig(level=logging.DEBUG) + +# define network + +def get_net(): + net = nn.Sequential() + net.add(nn.Dense(128, activation='relu', prefix='fc1_')) + net.add(nn.Dense(64, activation='relu', prefix='fc2_')) + net.add(nn.Dense(10, prefix='fc3_')) + return net + +get_data.GetMNIST_ubyte() + +batch_size = 100 +train_data = mx.io.MNISTIter( + image="data/train-images-idx3-ubyte", + label="data/train-labels-idx1-ubyte", + data_shape=(784,), + label_name='sm_label', + batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10) +val_data = mx.io.MNISTIter( + image="data/t10k-images-idx3-ubyte", + label="data/t10k-labels-idx1-ubyte", + data_shape=(784,), + label_name='sm_label', + batch_size=batch_size, shuffle=True, flat=True, silent=False) + +def score(net, ctx): + metric = mx.metric.Accuracy() + val_data.reset() + for batch in val_data: + data = foo.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0) + label = foo.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0) + outputs = [] + for x in data: + outputs.append(net(x)) + metric.update(label, outputs) + return metric.get()[1] + +def train(net, epoch, ctx): + net.all_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx) + trainer = foo.Trainer(net.all_params(), 'sgd', {'learning_rate': 0.5}) + metric = mx.metric.Accuracy() + + for i in range(epoch): + train_data.reset() + for batch in train_data: + data = foo.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0) + label = foo.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0) + outputs = [] + with ag.train_section(): + for x, y in zip(data, label): + z = net(x) + loss = foo.loss.softmax_cross_entropy_loss(z, y) + ag.compute_gradient([loss]) + outputs.append(z) 
+ metric.update(label, outputs) + trainer.step(batch.data[0].shape[0]) + name, acc = metric.get() + metric.reset() + print('training acc at epoch %d: %s=%f'%(i, name, acc)) + + +def test_autograd(): + net1 = get_net() + train(net1, 5, [mx.cpu(0), mx.cpu(1)]) + acc1 = score(net1, [mx.cpu(0)]) + acc2 = score(net1, [mx.cpu(0), mx.cpu(1)]) + assert acc1 > 0.95 + assert abs(acc1 - acc2) < 0.01 + net1.all_params().save('mnist.params') + + net2 = get_net() + net2.all_params().load('mnist.params', ctx=[mx.cpu(0)]) + acc3 = score(net2, [mx.cpu(0)]) + assert abs(acc3 - acc1) < 0.0001 + + +if __name__ == '__main__': + test_autograd()
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index eb73a12..abcaef4 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py
@@ -234,10 +234,10 @@ "differentiating the same graph twice without retain_graph should fail") -def test_set_grad(): +def test_attach_grad(): x = mx.nd.zeros((10,)) assert x.grad is None - x.set_grad() + x.attach_grad() with train_section(): y = x * 2 assert y.grad is None
diff --git a/tests/python/unittest/test_nn.py b/tests/python/unittest/test_nn.py index bd9eca6..8bb490c 100644 --- a/tests/python/unittest/test_nn.py +++ b/tests/python/unittest/test_nn.py
@@ -18,9 +18,9 @@ params = foo.ParameterDict('net_') params.get('weight', shape=(10, 10)) assert list(params.keys()) == ['net_weight'] - params.initialize() + params.initialize(ctx=mx.cpu()) params.save('test.params') - params.load('test.params') + params.load('test.params', mx.cpu()) def test_parameter_sharing():