blob: 47796874e7e1f4b4e45b8ef69031076f491ff6e2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tempfile

import mxnet as mx
@mx.util.use_np
def test_save():
    """Round-trip a hybridized nested model through ``Block.save``/``Block.load``.

    Builds a plain ``Block`` container that holds another container (wrapping a
    ``Dense`` layer) followed by a second ``Dense`` layer. The test hybridizes it
    and runs a forward pass, then saves it. It loads the saved state into a
    freshly built, uninitialized copy and checks that both models produce the
    same output for the same input.
    """
    class MyBlock(mx.gluon.Block):
        """Sequential container: each child's output is fed into the next child."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # ``forward`` calls each ``self._children`` entry to obtain the
            # child block (they are weak references in MXNet 2.x), so this list
            # holds strong references that keep the children alive.
            self.layers = []

        def add(self, block):
            """Append ``block`` and register it as a child of this container."""
            self.layers.append(block)
            self.register_child(block)

        def forward(self, x, *args):
            """Apply the children in registration order to ``(x, *args)``.

            Returns the last child's output. A single-element result is returned
            unwrapped.
            """
            out = (x,) + args
            for ref in self._children.values():
                # ``ref()`` dereferences the registered entry to the child block.
                out = ref()(*out)
                # A child usually returns a single array. Re-wrap it so the next
                # child receives it as ONE argument. Otherwise ``*out`` would
                # unpack the array along axis 0, which breaks for batch size > 1
                # and silently changes the input shape for batch size 1.
                if not isinstance(out, (tuple, list)):
                    out = (out,)
            return out[0] if len(out) == 1 else out

    def create_net():
        """Build ``MyBlock[MyBlock[Dense(10)], Dense(10)]``, uninitialized."""
        inside = MyBlock()
        inside.add(mx.gluon.nn.Dense(10))
        net = MyBlock()
        net.add(inside)
        net.add(mx.gluon.nn.Dense(10))
        return net

    # Create and initialize the model.
    net1 = create_net()
    net1.initialize()
    # Hybridize the hybridizable children, i.e. the Dense layers.
    net1.hybridize()
    # Use a non-zero input with batch size > 1. With an all-zero input every
    # Dense output equals its bias, so the final comparison could not detect
    # wrongly loaded weights.
    x = mx.np.ones((2, 10))
    out1 = net1(x)
    assert out1.shape == (2, 10)

    # Save into a private directory so the test leaves nothing behind in the
    # working directory and cannot collide with concurrent runs. Everything
    # that touches the saved files stays inside the ``with`` block.
    with tempfile.TemporaryDirectory() as tmpdir:
        prefix = os.path.join(tmpdir, 'MyModel')
        # Save the hybridized model.
        net1.save(prefix)
        # Create a new, uninitialized model and load the saved state into it.
        net2 = create_net()
        net2.load(prefix)
        net2.hybridize()
        # Run inference again and compare with the original model.
        out2 = net2(x)
    mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy())