blob: b3d8a143e16fd22e615ebdef592c8308d8b93705 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from common import xfail_when_nonstandard_decimal_separator
import mxnet as mx
from mxnet import nd, autograd, gluon
from mxnet.test_utils import default_device
@mx.util.use_np
class RoundSTENET(gluon.HybridBlock):
    """Minimal block computing ``round_ste(w * x) * w`` for the STE tests.

    ``w`` is a 30-element parameter whose entries are all initialised to
    ``w_init``. The block also carries the reference formulas that the
    test driver compares the network output and the gradient of ``w``
    against.
    """

    def __init__(self, w_init, **kwargs):
        super(RoundSTENET, self).__init__(**kwargs)
        self.w = gluon.Parameter('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write')

    # NOTE: the two reference helpers below are plain instance methods.
    # They declare ``self`` and are invoked on a network instance with two
    # arguments (``net.expected_grads(in_data, w_init)``), which is not
    # compatible with a real ``@staticmethod`` binding.
    def expected_grads(self, in_data, w_init):
        """Return the reference gradient for ``w``.

        With the rounding step treated as identity in the backward pass,
        d/dw [round(w * x) * w] = round(w * x) + w * x.
        """
        return mx.np.round(in_data * w_init) + (in_data * w_init)

    def expected_output(self, in_data, w_init):
        """Return the reference forward result ``round(x * w) * w``."""
        return mx.np.round(in_data * w_init) * w_init

    def forward(self, x):
        # Simple forward function: round_ste(w*x)*w
        out = self.w.data(x.device) * x
        out = mx.npx.round_ste(out)
        # Uncomment to see how test fails with a plain (non-STE) round
        # out = mx.np.round(out)
        out = out * self.w.data(x.device)
        return out
@mx.util.use_np
class SignSTENET(gluon.HybridBlock):
    """Minimal block computing ``sign_ste(w * x) * w`` for the STE tests.

    ``w`` is a 30-element parameter whose entries are all initialised to
    ``w_init``. The block also carries the reference formulas that the
    test driver compares the network output and the gradient of ``w``
    against.
    """

    def __init__(self, w_init, **kwargs):
        super(SignSTENET, self).__init__(**kwargs)
        self.w = gluon.Parameter('w', shape=30, init=mx.initializer.Constant(w_init), grad_req='write')

    # NOTE: the two reference helpers below are plain instance methods.
    # They declare ``self`` and are invoked on a network instance with two
    # arguments (``net.expected_grads(in_data, w_init)``), which is not
    # compatible with a real ``@staticmethod`` binding.
    def expected_grads(self, in_data, w_init):
        """Return the reference gradient for ``w``.

        With the sign step treated as identity in the backward pass,
        d/dw [sign(w * x) * w] = sign(w * x) + w * x.
        """
        return mx.np.sign(in_data * w_init) + (in_data * w_init)

    def expected_output(self, in_data, w_init):
        """Return the reference forward result ``sign(x * w) * w``."""
        return mx.np.sign(in_data * w_init) * w_init

    def forward(self, x):
        # Simple forward function: sign_ste(w*x)*w
        out = self.w.data(x.device) * x
        out = mx.npx.sign_ste(out)
        # Uncomment to see how test fails with a plain (non-STE) sign
        # out = mx.np.sign(out)
        out = out * self.w.data(x.device)
        return out
def check_ste(net_type_str, w_init, hybridize, in_data, device=None):
    """Build an STE test network and verify its output and ``w`` gradient.

    Parameters
    ----------
    net_type_str : str
        Name of the network class to instantiate; it is resolved with
        ``eval`` and called as ``cls(w_init=w_init)``.
    w_init : float
        Constant used to initialise the network weight and to compute the
        reference values.
    hybridize : bool
        Whether to call ``hybridize()`` on the network before running it.
    in_data : ndarray
        Input fed to the network; it is moved to ``device`` first.
    device : optional
        Target device; ``default_device()`` is used when falsy.

    Raises
    ------
    AssertionError
        If the forward output or the gradient of ``w`` differs from the
        network's own ``expected_output`` / ``expected_grads``.
    """
    if not device:
        device = default_device()

    # Only class names hard-coded by the tests in this file reach this eval.
    net_cls = eval(net_type_str)
    net = net_cls(w_init=w_init)
    if hybridize:
        net.hybridize()

    # Init
    net.initialize(mx.init.Constant([w_init]), device=device)

    # Test:
    in_data = in_data.to_device(device)

    # The same forward/backward check is performed twice in a row.
    # NOTE(review): presumably the repeat confirms that gradients are
    # overwritten (grad_req='write') rather than accumulated -- confirm.
    for _ in range(2):
        with mx.autograd.record():
            out = net(in_data)
        assert all(out == net.expected_output(in_data, w_init)), \
            "{} output is {!s}, but expected {!s}".format(
                net_type_str, out, net.expected_output(in_data, w_init))

        out.backward()
        assert all(net.w.grad() == net.expected_grads(in_data, w_init)), \
            "{} w grads are {!s} but expected {!s}".format(
                net_type_str, net.w.grad(), net.expected_grads(in_data, w_init))
@xfail_when_nonstandard_decimal_separator
def test_contrib_round_ste():
    """Check ``round_ste`` on random data, on the .5 tie case and on zeros."""

    def run_both_modes(data, weight):
        # Hybridized first, then imperative.
        for hybridized in (True, False):
            check_ste(net_type_str="RoundSTENET", w_init=weight, hybridize=hybridized, in_data=data)

    # Test with random data. The bound 10 is an arbitrary choice; the input
    # has 30 elements like the `w` parameter of the network.
    random_input = mx.np.random.uniform(-10, 10, size=30)
    random_weight = float(mx.np.random.uniform(-10, 10, size=1).item())
    run_both_modes(random_input, random_weight)

    # Test 1.5 (verifies that .5 rounds the same as in round)
    run_both_modes(mx.np.array([1.5]*30), 1.)

    # Test 0
    run_both_modes(mx.np.array([0]*30), 0.)
@xfail_when_nonstandard_decimal_separator
def test_contrib_sign_ste():
    """Check ``sign_ste`` on random data and on an all-zero input/weight."""

    def run_both_modes(data, weight):
        # Hybridized first, then imperative.
        for hybridized in (True, False):
            check_ste(net_type_str="SignSTENET", w_init=weight, hybridize=hybridized, in_data=data)

    # Test with random data. The bound 10 is an arbitrary choice; the input
    # has 30 elements like the `w` parameter of the network.
    random_input = mx.np.random.uniform(-10, 10, size=30)
    random_weight = float(mx.np.random.uniform(-10, 10, size=1).item())
    run_both_modes(random_input, random_weight)

    # Test 0
    run_both_modes(mx.np.array([0]*30), 0.)