blob: 07c0fb843be8ef3881e3ef221383d1e9981b2d9e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
from __future__ import print_function
import numpy as np
import mxnet as mx
from numpy.testing import assert_allclose
def assert_mx_allclose(A, B, **kwds):
return assert_allclose(A.asnumpy(), B.asnumpy(), **kwds)
def test_krprod_one_input():
A = mx.nd.arange(1,9).reshape((2,4))
out = mx.nd.khatri_rao(A)
assert_mx_allclose(out, A, rtol=1e-12)
def test_krprod_two_inputs():
A = mx.nd.arange(1,7).reshape((3,2))
B = mx.nd.arange(1,3).reshape((1,2))
out = mx.nd.khatri_rao(A, B)
expected = mx.nd.array([[1,4],[3,8],[5,12]])
assert_mx_allclose(out, expected, rtol=1e-12)
A = mx.nd.arange(1,7).reshape((3,2))
B = mx.nd.arange(1,9).reshape((4,2))
out = mx.nd.khatri_rao(A, B)
expected = mx.nd.array([[1,4],[3,8],[5,12],[7,16],[3,8],[9,16],[15,24],
[21,32],[5,12],[15,24],[25,36],[35,48]])
assert_mx_allclose(out, expected, rtol=1e-12)
def test_krprod_three_inputs():
A = mx.nd.arange(1,7).reshape((3,2))
B = mx.nd.arange(1,3).reshape((1,2))
C = mx.nd.arange(1,5).reshape((2,2))
out = mx.nd.khatri_rao(A, B, C)
expected = mx.nd.array([[1,8],[3,16],[3,16],[9,32],[5,24],[15,48]])
assert_mx_allclose(out, expected, rtol=1e-12)
out_AB = mx.nd.khatri_rao(A, B)
out = mx.nd.khatri_rao(out_AB, C)
assert_mx_allclose(out, expected, rtol=1e-12)
out_BC = mx.nd.khatri_rao(B, C)
out = mx.nd.khatri_rao(A, out_BC)
assert_mx_allclose(out, expected, rtol=1e-12)