blob: b442aaf8138c3a0f99e708eb446190007b856a71 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test the TVM bridge; it is only enabled when TVM is available."""
import logging
import mxnet as mx
import numpy as np
import unittest
def test_tvm_bridge():
    """Run a TVM-compiled kernel through the MXNet engine and check it.

    For each target/dtype pair, builds ``zz = (x + y) * scale`` with TVM,
    wraps it with ``tvm.contrib.mxnet.to_mxnet_func`` and compares the
    output written into an MXNet NDArray against a NumPy reference.

    Raises:
        unittest.SkipTest: if ``tvm``, ``tvm.contrib.mxnet`` or ``topi``
            cannot be imported, so the runner reports a skip instead of a
            silent pass.
    """
    # only enable test if TVM is available
    try:
        import tvm
        import tvm.contrib.mxnet
        import topi
    except ImportError:
        raise unittest.SkipTest(
            "TVM bridge test skipped because TVM is missing...")

    def check(target, dtype):
        """Build, bridge and verify ``(x + y) * scale`` for one target/dtype."""
        shape = (20,)
        scale = tvm.te.var("scale", dtype="float32")
        x = tvm.te.placeholder(shape, dtype=dtype)
        y = tvm.te.placeholder(shape, dtype=dtype)
        z = tvm.te.compute(shape, lambda i: x[i] + y[i])
        zz = tvm.te.compute(shape, lambda *i: z(*i) * scale.astype(dtype))
        ctx = mx.gpu(0) if target == "cuda" else mx.cpu(0)
        target = tvm.target.create(target)

        # build the function
        with target:
            s = topi.generic.schedule_injective(zz)
            f = tvm.build(s, [x, y, zz, scale])

        # Get an MXNet-callable version. const_loc=[0, 1] names argument
        # positions 0 and 1 (x and y); presumably these are flagged as
        # read-only for MXNet's dependency tracking -- see to_mxnet_func docs.
        mxf = tvm.contrib.mxnet.to_mxnet_func(f, const_loc=[0, 1])

        # Sample from [0, 5) rather than the default [0, 1): casting [0, 1)
        # to an integer dtype truncates to all zeros, which made the integer
        # cases vacuous. The largest result, (4 + 4) * 10 = 80, fits int8.
        xx = mx.nd.uniform(low=0, high=5, shape=shape, ctx=ctx).astype(dtype)
        yy = mx.nd.uniform(low=0, high=5, shape=shape, ctx=ctx).astype(dtype)
        # Allocate the output directly in the target dtype (no float32 copy);
        # named `out` so it does not shadow the TVM tensor `zz` above.
        out = mx.nd.empty(shape, ctx=ctx, dtype=dtype)
        # invoke mxf: this runs in mxnet engine
        mxf(xx, yy, out, 10.0)
        np.testing.assert_allclose(
            out.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)

    # Only exercise CUDA when TVM has the CUDA runtime and MXNet sees a GPU;
    # otherwise allocating on mx.gpu(0) / building for "cuda" would fail on
    # CPU-only hosts.
    targets = ["llvm"]
    if tvm.runtime.enabled("cuda") and mx.context.num_gpus() > 0:
        targets.append("cuda")
    for tgt in targets:
        for dtype in ["int8", "uint8", "int64", "float32", "float64"]:
            check(tgt, dtype)