blob: 78da144e54bf528eaf717f2f204debb16eeedd6d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Support level2 operator test cases.
"""
import sys
import numpy as np
import pytest
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import autotvm, relay, te
from tvm.contrib import utils, cudnn
from tvm.ir.module import IRModule
from tvm.relay import transform
from tvm.relay.testing import run_infer_type
from tvm.topi.cuda.conv3d_winograd import _infer_tile_size
executor_kind = tvm.testing.parameter("graph", "vm")
@tvm.testing.uses_gpu
def test_conv1d_infer_type():
# symbolic in batch dimension
n, c, w = te.var("n"), 10, 224
x = relay.var("x", relay.ty.TensorType((n, c, w), "float32"))
w = relay.var("w")
y = relay.nn.conv1d(x, w, kernel_size=3, padding=(1, 1), channels=2)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 224), "float32")
assert yy.args[1].checked_type == relay.TensorType((2, 10, 3), "float32")
# infer by shape of w, mixed precision
n, c, w = te.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3), "int8"))
y = relay.nn.conv1d(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 222), "int32")
# infer shape in case of different dtypes for input and weight.
n, c, w = te.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3), "int8"))
y = relay.nn.conv1d(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 222), "int32")
# Infer with NWC
n, c, w = 4, 32, 224
x = relay.var("x", relay.TensorType((n, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv1d(
x, wt, kernel_size=3, padding=(1, 1), channels=16, data_layout="NWC", out_dtype="int32"
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, w, 16), "int32")
@tvm.testing.uses_gpu
def test_conv1d_run():
def run_test_conv1d(
dtype,
out_dtype,
scale,
dshape,
kshape,
padding=(1, 1),
fref=None,
dilation=1,
except_targets=None,
**attrs,
):
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", dtype=dtype)
y = relay.nn.conv1d(x, w, padding=padding, dilation=dilation, **attrs)
func = relay.Function([x, w], y)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
ref_res = tvm.topi.testing.conv1d_ncw_python(
data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, dilation
)
for target, dev in tvm.testing.enabled_targets():
if target in except_targets:
continue
dev = tvm.device(target, 0)
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
# normal conv1d
dshape = (1, 3, 224)
kshape = (10, 3, 3)
run_test_conv1d(
"float32", "float32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=3
)
# mixed precision
run_test_conv1d("int8", "int32", 1, dshape, kshape, padding=(1, 1), channels=10, kernel_size=3)
# dilated conv2d
dshape = (1, 3, 18)
kshape = (10, 3, 3)
run_test_conv1d(
"float32",
"float32",
1,
dshape,
kshape,
padding=(1, 1),
channels=10,
kernel_size=3,
dilation=3,
)
@tvm.testing.uses_gpu
def test_conv2d_infer_type():
# symbolic in batch dimension
n, c, h, w = te.size_var("n"), 10, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = relay.var("w")
y = relay.nn.conv2d(x, w, kernel_size=(3, 3), padding=(1, 1), channels=2)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 224, 224), "float32")
assert yy.args[1].checked_type == relay.TensorType((2, 10, 3, 3), "float32")
# infer by shape of w, mixed precision
n, c, h, w = te.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
y = relay.nn.conv2d(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 222, 222), "int32")
# infer shape in case of different dtypes for input and weight.
n, c, h, w = te.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3), "int8"))
y = relay.nn.conv2d(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 222, 222), "int32")
# Infer with a different layout
n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n // 4, c // 4, h, w, 4, 4), "int8"))
wt = relay.var("w")
y = relay.nn.conv2d(
x,
wt,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
data_layout="NCHW4n4c",
kernel_layout="OIHW4o4i",
out_dtype="int32",
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 4, 4), "int32")
assert yy.args[1].checked_type == relay.TensorType((4, 8, 3, 3, 4, 4), "int8")
# Infer with NHWC
n, c, h, w = 4, 32, 224, 224
x = relay.var("x", relay.TensorType((n, h, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv2d(
x,
wt,
kernel_size=(3, 3),
padding=(1, 1),
channels=16,
data_layout="NHWC",
out_dtype="int32",
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, h, w, 16), "int32")
class TestConv2D:
config = {
"group1": dict(
dtype="float32",
out_dtype="float32",
scale=1,
dshape=(1, 32, 18, 18),
kshape=(32, 4, 3, 3),
padding=(1, 1),
channels=32,
groups=8,
kernel_size=(3, 3),
dilation=(1, 1),
),
"group2": dict(
dtype="float32",
out_dtype="float32",
scale=1,
dshape=(1, 32, 18, 18),
kshape=(64, 1, 3, 3),
padding=(1, 1),
channels=64,
groups=32,
kernel_size=(3, 3),
dilation=(1, 1),
),
"normal": dict(
dtype="float32",
out_dtype="float32",
scale=1,
dshape=(1, 3, 224, 224),
kshape=(10, 3, 3, 3),
padding=(1, 1),
channels=10,
groups=1,
kernel_size=(3, 3),
dilation=(1, 1),
),
"mixed_precision_int8_int32_case1": dict(
dtype="int8",
out_dtype="int32",
scale=1,
dshape=(1, 3, 224, 224),
kshape=(10, 3, 3, 3),
padding=(1, 1),
channels=10,
groups=1,
kernel_size=(3, 3),
dilation=(1, 1),
),
"mixed_precision_int8_int32_case2": dict(
dtype="int8",
out_dtype="int32",
scale=1,
dshape=(1, 3, 224, 224),
kshape=(10, 3, 1, 3),
padding=(0, 1),
channels=10,
groups=1,
kernel_size=(1, 3),
dilation=(1, 1),
),
"dilated": dict(
dtype="float32",
out_dtype="float32",
scale=1,
dshape=(1, 3, 18, 18),
kshape=(10, 3, 3, 3),
padding=(1, 1),
channels=10,
groups=1,
kernel_size=(3, 3),
dilation=(3, 3),
),
}
# TODO(Lunderberg): Make a cleaner utility for this type of
# parametrization. It would be much nicer to have the fixture
# name come from the dictionaries themselves, rather than needing
# to be re-packed into tuples.
(
dtype,
out_dtype,
scale,
dshape,
kshape,
padding,
channels,
groups,
kernel_size,
dilation,
) = tvm.testing.parameters(
*[
[
d[p]
for p in [
"dtype",
"out_dtype",
"scale",
"dshape",
"kshape",
"padding",
"channels",
"groups",
"kernel_size",
"dilation",
]
]
for d in config.values()
],
ids=config.keys(),
)
def test_run(
self,
target,
dev,
dtype,
out_dtype,
scale,
dshape,
kshape,
padding,
groups,
dilation,
channels,
kernel_size,
):
target = tvm.target.Target(target)
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
y = relay.nn.conv2d(
x,
w,
padding=padding,
dilation=dilation,
groups=groups,
channels=channels,
kernel_size=kernel_size,
)
func = relay.Function([x, w], y)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
ref_res = tvm.topi.testing.conv2d_nchw_python(
data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding, groups=groups
)
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-4, atol=1e-4)
def test_compile_depthwise_conv2d_arm_cpu():
dtype = "float32"
out_dtype = "float32"
scale = 1
dshape = (1, 512, 32, 32)
kshape = (512, 1, 3, 3)
padding = (1, 1)
channels = 512
groups = 512
kernel_size = (3, 3)
dilation = (1, 1)
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
y = relay.nn.conv2d(
x,
w,
padding=padding,
dilation=dilation,
groups=groups,
channels=channels,
kernel_size=kernel_size,
)
func = relay.Function([x, w], y)
mod = tvm.IRModule()
mod["main"] = func
test_schedule = '{"i": ["llvm -device=arm_cpu", "depthwise_conv2d_nchw_spatial_pack.arm_cpu", \
[["TENSOR", [1, 512, 32, 32], "float32"], \
["TENSOR", [512, 1, 3, 3], "float32"], \
[1, 1], [1, 1], [1, 1], "float32"], {}, \
["depthwise_conv2d_nchw_spatial_pack.arm_cpu", [1, 512, 32, 32, "float32"], \
[512, 1, 3, 3, "float32"], [1, 1], [1, 1], [1, 1], "float32"], \
{"i": 743640, "t": "", "c": null, \
"e": [["tile_co", "sp", [32, 16]], ["tile_oh", "sp", [8, 1]], \
["tile_ow", "sp", [1, 8]], \
["reorder_0", "re", [0, 1, 2, 3, 4, 5, 8, 6, 7]], \
["reorder_1", "re", [0, 1, 2, 3, 6, 4, 5]], \
["ann_reduce", "an", ["unroll", "none"]], \
["ann_spatial", "an", ["unroll", "unroll", "vec"]], \
["data_pad_inline", "ot", 4], ["data_vec_inline", "ot", 1], \
["conv_inline", "ot", 0]]}], "r": [[0.0002933163], \
0, 3.1976189613342285, 1570811630.6058347], "v": 0.1}'
temp = utils.tempdir()
with open(temp.relpath("temp.log"), "w") as log_file:
log_file.write(test_schedule)
with autotvm.apply_history_best(temp.relpath("temp.log")):
with tvm.transform.PassContext(opt_level=3):
print("Compiling...")
graph_json, mod, params = tvm.relay.build(mod, target="llvm -device=arm_cpu")
@tvm.testing.uses_gpu
def test_conv2d_winograd():
class WinogradFallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = autotvm.task.space.FallbackConfigEntity()
cfg.is_fallback = False
cfg.cost = 0.1 if "winograd" in workload[0] else 1
cfg["tile_b"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg["tile_rc"] = autotvm.task.space.SplitEntity([-1, 1])
cfg["auto_unroll_max_step"] = autotvm.task.space.OtherOptionEntity(1500)
cfg["unroll_explicit"] = autotvm.task.space.OtherOptionEntity(1)
self.memory[key] = cfg
return cfg
def run_test_conv2d_cuda(
dtype, out_dtype, scale, dshape, kshape, padding=(1, 1), groups=1, dilation=(1, 1), **attrs
):
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
y = relay.nn.conv2d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
func = relay.Function([x, w], y)
mod = tvm.IRModule()
mod["main"] = func
mod = relay.transform.InferType()(mod)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
ref_res = tvm.topi.testing.conv2d_nchw_python(
data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, groups=groups
)
with WinogradFallback(), tvm.transform.PassContext(opt_level=3):
for target, dev in tvm.testing.enabled_targets():
if target != "cuda":
continue
dev = tvm.device(target, 0)
params = {"w": tvm.nd.array(kernel)}
graph, lib, params = relay.build_module.build(mod, target=target, params=params)
module = tvm.contrib.graph_executor.create(graph, lib, dev)
module.set_input("x", tvm.nd.array(data))
module.set_input(**params)
module.run()
op_res1 = module.get_output(0)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-3, atol=1e-3)
# normal winograd: stride 1, padding 1, kernel 3x3
dshape = (1, 80, 73, 73)
kshape = (192, 80, 3, 3)
run_test_conv2d_cuda(
"float32", "float32", 1, dshape, kshape, padding=(1, 1), channels=192, kernel_size=(3, 3)
)
# extended winograd: stride 1, padding N, kernel 3x3
run_test_conv2d_cuda(
"float32", "float32", 1, dshape, kshape, padding=(0, 0), channels=192, kernel_size=(3, 3)
)
run_test_conv2d_cuda(
"float32", "float32", 1, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(3, 3)
)
# extended winograd: stride 1, padding N, kernel NxN
kshape = (192, 80, 7, 7)
run_test_conv2d_cuda(
"float32", "float32", 1, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(7, 7)
)
@tvm.testing.uses_gpu
def test_conv3d_infer_type():
# symbolic in batch dimension
n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
w = relay.var("w")
y = relay.nn.conv3d(x, w, kernel_size=(3, 3, 3), padding=(1, 1, 1), channels=2)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 224, 224, 224), "float32")
assert yy.args[1].checked_type == relay.TensorType((2, 10, 3, 3, 3), "float32")
# infer by shape of w, mixed precision
n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
y = relay.nn.conv3d(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 222, 222, 222), "int32")
# infer shape in case of different dtypes for input and weight.
n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
y = relay.nn.conv3d(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 222, 222, 222), "int32")
# Infer with NDHWC
n, c, d, h, w = 4, 32, 224, 224, 224
x = relay.var("x", relay.TensorType((n, d, h, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv3d(
x,
wt,
kernel_size=(3, 3, 3),
padding=(1, 1, 1),
channels=16,
data_layout="NDHWC",
out_dtype="int32",
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, d, h, w, 16), "int32")
# Infer with groups
x = relay.var("x", relay.TensorType((1, 16, 224, 224, 224), "float32"))
w = relay.var("w", relay.TensorType((4, 4, 1, 1, 1), "float32"))
y = relay.nn.conv3d(x, w, groups=4, kernel_size=(1, 1, 1), channels=4)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 224), "float32")
@tvm.testing.uses_gpu
def test_conv3d_run():
def run_test_conv3d(
dtype,
out_dtype,
scale,
dshape,
kshape,
padding=(1, 1, 1),
fref=None,
groups=1,
dilation=(1, 1, 1),
except_targets=None,
**attrs,
):
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", dtype=dtype)
y = relay.nn.conv3d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
func = relay.Function([x, w], y)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
if fref is None:
ref_res = tvm.topi.testing.conv3d_ncdhw_python(
data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding, groups=groups
)
else:
ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
for target, dev in tvm.testing.enabled_targets():
if target in except_targets:
continue
dev = tvm.device(target, 0)
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
# normal conv3d
dshape = (1, 3, 5, 224, 224)
kshape = (10, 3, 3, 3, 3)
run_test_conv3d(
"float32",
"float32",
1,
dshape,
kshape,
padding=(1, 1, 1),
channels=10,
kernel_size=(3, 3, 3),
)
@tvm.testing.uses_gpu
def test_conv3d_ndhwc_run():
def run_test_conv3d(
dtype,
out_dtype,
scale,
dshape,
kshape,
padding=(1, 1, 1),
fref=None,
groups=1,
dilation=(1, 1, 1),
except_targets=None,
**attrs,
):
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", dtype=dtype)
y = relay.nn.conv3d(
x,
w,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NDHWC",
kernel_layout="DHWIO",
**attrs,
)
func = relay.Function([x, w], y)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
dkernel = tvm.topi.testing.dilate_python(kernel, (1, 1) + dilation)
if fref is None:
ref_res = tvm.topi.testing.conv3d_ndhwc_python(
data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding
)
else:
ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
for target, dev in tvm.testing.enabled_targets():
if target in except_targets:
continue
dev = tvm.device(target, 0)
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
# normal conv3d
dshape = (1, 5, 224, 224, 6)
kshape = (3, 3, 3, 6, 10)
run_test_conv3d(
"float32",
"float32",
1,
dshape,
kshape,
padding=(1, 1, 1),
channels=10,
kernel_size=(3, 3, 3),
except_targets=["cuda"],
)
@tvm.testing.uses_gpu
def test_conv3d_winograd():
class WinogradFallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = autotvm.task.space.FallbackConfigEntity()
cfg.is_fallback = False
cfg.cost = 0.1 if "winograd" in workload[0] else 1
cfg["tile_b"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg["tile_y"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg["tile_x"] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg["tile_rc"] = autotvm.task.space.SplitEntity([-1, 1])
cfg["auto_unroll_max_step"] = autotvm.task.space.OtherOptionEntity(0)
cfg["unroll_explicit"] = autotvm.task.space.OtherOptionEntity(1)
self.memory[key] = cfg
return cfg
def run_test_conv3d_cuda(
dtype,
out_dtype,
scale,
dshape,
kshape,
padding=(1, 1, 1),
groups=1,
dilation=(1, 1, 1),
prepack=False,
**attrs,
):
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
if prepack:
tile_size = _infer_tile_size(np.zeros(shape=dshape), np.zeros(shape=kshape))
w_packed = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size)
y = relay.nn.contrib_conv3d_winograd_without_weight_transform(
x,
w_packed,
tile_size,
padding=padding,
dilation=dilation,
groups=groups,
channels=kshape[0],
**attrs,
)
else:
y = relay.nn.conv3d(x, w, padding=padding, dilation=dilation, groups=groups, **attrs)
func = relay.Function([x, w], y)
mod = tvm.IRModule()
mod["main"] = func
mod = relay.transform.InferType()(mod)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
ref_res = tvm.topi.testing.conv3d_ncdhw_python(
data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, groups=groups
)
with WinogradFallback(), tvm.transform.PassContext(opt_level=3):
for target, dev in tvm.testing.enabled_targets():
if target != "cuda":
continue
dev = tvm.device(target, 0)
params = {"w": tvm.nd.array(kernel)}
graph, lib, params = relay.build_module.build(mod, target=target, params=params)
module = tvm.contrib.graph_executor.create(graph, lib, dev)
module.set_input("x", tvm.nd.array(data))
module.set_input(**params)
module.run()
op_res1 = module.get_output(0)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-3, atol=1e-3)
# normal winograd: stride 1, padding 1, kernel 3x3x3
dshape = (1, 32, 16, 16, 16)
kshape = (64, 32, 3, 3, 3)
run_test_conv3d_cuda(
"float32", "float32", 1, dshape, kshape, padding=(1, 1, 1), kernel_size=(3, 3, 3)
)
# Without depth transform using 1x3x3 kernel.
kshape = (64, 32, 1, 3, 3)
run_test_conv3d_cuda(
"float32", "float32", 1, dshape, kshape, padding=(0, 1, 1), kernel_size=(1, 3, 3)
)
# extended winograd: stride 1, padding N, kernel NxNxN
dshape = (1, 61, 20, 20, 20)
kshape = (120, 61, 5, 5, 5)
run_test_conv3d_cuda(
"float32",
"float32",
1,
dshape,
kshape,
padding=(2, 2, 2),
channels=120,
kernel_size=(5, 5, 5),
)
# Without depth transform
kshape = (120, 61, 1, 5, 5)
run_test_conv3d_cuda(
"float32",
"float32",
1,
dshape,
kshape,
padding=(0, 2, 2),
channels=120,
kernel_size=(1, 5, 5),
)
@tvm.testing.uses_gpu
def test_conv3d_transpose_infer_type():
# symbolic in batch dimension
n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
w = relay.var("w")
y = relay.nn.conv3d_transpose(x, w, kernel_size=(3, 3, 3), padding=(1, 1, 1), channels=2)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 2, 224, 224, 224), "float32")
assert yy.args[1].checked_type == relay.TensorType((10, 2, 3, 3, 3), "float32")
# infer by shape of w, mixed precision
n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8"))
y = relay.nn.conv3d_transpose(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 12, 226, 226, 226), "int32")
# infer shape in case of different dtypes for input and weight.
n, c, d, h, w = te.size_var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
w = relay.var("w", relay.TensorType((10, 12, 3, 3, 3), "int8"))
y = relay.nn.conv3d_transpose(x, w, out_dtype="int32")
assert 'out_dtype="int32"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 12, 226, 226, 226), "int32")
@tvm.testing.uses_gpu
def test_conv3d_transpose_ncdhw_run():
dshape = (1, 3, 24, 24, 24)
kshape = (3, 4, 2, 2, 2)
x = relay.var("x", shape=dshape)
w = relay.var("w")
y = relay.nn.conv3d_transpose(
x, w, channels=4, kernel_size=(2, 2, 2), strides=(1, 1, 1), padding=(1, 1, 1)
)
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape).astype(dtype)
kernel = np.random.uniform(size=kshape).astype(dtype)
ref_res = tvm.topi.testing.conv3d_transpose_ncdhw_python(data, kernel, 1, 1, 0)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
def test_compile_depthwise_conv3d():
dshape = [1, 16, 10, 10, 10]
wshape = [16, 2, 1, 1, 1]
params = {}
data = relay.var("data", shape=dshape, dtype="float32")
kernel = relay.const(tvm.nd.array(np.ones(shape=wshape).astype(dtype="float32")))
mod = tvm.IRModule()
res = relay.nn.conv3d(
data,
kernel,
kernel_size=[1, 1, 1],
padding=[0] * 3,
channels=32,
groups=16,
data_layout="NCDHW",
kernel_layout="OIDHW",
)
func = relay.Function([data], res)
mod = tvm.IRModule.from_expr(func)
target = "llvm"
_ = relay.build(mod, tvm.target.Target(target, host=target))
@tvm.testing.uses_gpu
def test_conv2d_transpose_infer_type():
# symbolic in batch dimension
n, c, h, w = te.size_var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
w = relay.var("w", relay.IncompleteType())
y = relay.nn.conv2d_transpose(x, w, kernel_size=(3, 3), padding=(1, 1), channels=15)
assert "channels=15" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 15, 10, 12), "float32")
assert yy.args[1].checked_type == relay.TensorType((10, 15, 3, 3), "float32")
# infer by shape of w, mixed precision
n, h, w, c = te.size_var("n"), 10, 10, 12
x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
w = relay.var("w", relay.TensorType((12, 11, 5, 5), "float32"))
y = relay.nn.conv2d_transpose(x, w, output_padding=(1, 1), channels=11, data_layout="NHWC")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 15, 15, 11), "float32")
@tvm.testing.uses_gpu
def test_conv2d_transpose_nchw_run():
k_layouts = {"OIHW": (10, 3, 3, 3), "IOHW": (3, 10, 3, 3)}
output_padding = (1, 1)
for k_layout, kshape in k_layouts.items():
dshape = (1, 3, 18, 18)
x = relay.var("x", shape=dshape)
w = relay.var("w")
y = relay.nn.conv2d_transpose(
x,
w,
channels=10,
kernel_size=(3, 3),
strides=(2, 2),
padding=(1, 1),
output_padding=output_padding,
kernel_layout=k_layout,
data_layout="NCHW",
)
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape).astype(dtype)
kernel = np.random.uniform(size=kshape).astype(dtype)
if k_layout != "IOHW":
# Must be OIHW so switch
kernel_iohw = np.transpose(kernel, [1, 0, 2, 3])
else:
kernel_iohw = kernel
ref_res = tvm.topi.testing.conv2d_transpose_nchw_python(
data, kernel_iohw, 2, 1, output_padding
)
enabled_targets = tvm.testing.enabled_targets()
if cudnn.exists() and k_layout == "IOHW":
enabled_targets.append(("cuda -libs=cudnn", tvm.cuda(0)))
for target, dev in enabled_targets:
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_conv2d_transpose_nhwc_run():
dshape_nhwc = (1, 18, 18, 3)
kshape_hwoi = (3, 3, 10, 3)
x = relay.var("x", shape=dshape_nhwc)
w = relay.var("w")
y = relay.nn.conv2d_transpose(
x,
w,
channels=10,
kernel_size=(3, 3),
strides=(2, 2),
padding=(1, 1),
output_padding=(1, 1),
data_layout="NHWC",
kernel_layout="HWOI",
)
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape_nhwc).astype(dtype)
kernel = np.random.uniform(size=kshape_hwoi).astype(dtype)
ref_res = tvm.topi.testing.conv2d_transpose_nhwc_python(
data, kernel, "HWOI", 2, 1, output_padding=(1, 1)
)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_conv2d_transpose_nhwc_cudnn():
if not cudnn.exists():
return
dshape_nhwc = (1, 18, 18, 3)
kshape_ihwo = (3, 3, 3, 10)
x = relay.var("x", shape=dshape_nhwc)
w = relay.var("w", shape=kshape_ihwo)
y = relay.nn.conv2d_transpose(
x,
w,
channels=10,
kernel_size=(3, 3),
strides=(2, 2),
padding=(1, 1),
output_padding=(1, 1),
data_layout="NHWC",
kernel_layout="IHWO",
)
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape_nhwc).astype(dtype)
kernel = np.random.uniform(size=kshape_ihwo).astype(dtype)
ref_res = tvm.topi.testing.conv2d_transpose_nhwc_python(
data, np.transpose(kernel, [1, 2, 3, 0]), "HWOI", 2, 1, output_padding=(1, 1)
)
target = "cuda -libs=cudnn"
dev = tvm.cuda(0)
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_conv1d_transpose_ncw_run():
dshape = (1, 3, 18)
kshape = (3, 10, 3)
oshape = (1, 10, 36)
x = relay.var("x", shape=dshape)
w = relay.var("w")
y = relay.nn.conv1d_transpose(
x, w, channels=10, kernel_size=(3,), strides=(2,), padding=(1,), output_padding=(1,)
)
func = relay.Function([x, w], y)
dtype = "float32"
data = np.random.uniform(size=dshape).astype(dtype)
kernel = np.random.uniform(size=kshape).astype(dtype)
ref_res = tvm.topi.testing.conv1d_transpose_ncw_python(data, kernel, 2, 1, output_padding=(1,))
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data, kernel
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_upsampling_infer_type():
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
scale = tvm.tir.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
'method="BINLINEAR"' in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(
n,
c,
tvm.tir.Cast("int32", te.round(h * scale)),
tvm.tir.Cast("int32", te.round(w * scale)),
),
"float32",
)
n, c = te.size_var("n"), te.size_var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 200), "float32"))
y = relay.nn.upsampling(x, scale_h=2, scale_w=2, layout="NCHW", method="bilinear")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 400), "float32")
@tvm.testing.uses_gpu
def test_upsampling3d_infer_type():
n, c, d, h, w = (
te.size_var("n"),
te.size_var("c"),
te.size_var("d"),
te.size_var("h"),
te.size_var("w"),
)
scale = tvm.tir.const(2.0, "float64")
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
y = relay.nn.upsampling3d(
x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear"
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(
n,
c,
tvm.tir.Cast("int32", te.round(d * scale)),
tvm.tir.Cast("int32", te.round(h * scale)),
tvm.tir.Cast("int32", te.round(w * scale)),
),
"float32",
)
n, c = te.size_var("n"), te.size_var("c")
x = relay.var("x", relay.TensorType((n, c, 100, 100, 200), "float32"))
y = relay.nn.upsampling3d(
x, scale_d=2, scale_h=2, scale_w=2, layout="NCDHW", method="trilinear"
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 200, 200, 400), "float32")
def _test_global_pool2d(opfunc, reffunc):
n, c, h, w = te.size_var("n"), te.size_var("c"), 224, 224
x = relay.var("x", relay.TensorType((n, h, w, c), "float32"))
y = opfunc(x, layout="NHWC")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 1, 1, c), "float32")
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = opfunc(x)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 1, 1), "float32")
# test execution
dtype = "float32"
dshape = (1, 1024, 7, 7)
x = relay.var("x", shape=dshape)
y = opfunc(x)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = reffunc(data, axis=(2, 3), keepdims=True)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_pool2d():
def _test_pool2d(opfunc, pool_type, pool_size=2, strides=2, dilation=1, padding=0):
n, c, h, w = te.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
y = opfunc(x, pool_size=(1, 1))
assert "pool_size=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 10, 224, 224), "float32")
# test execution
dtype = "float32"
dshape = (1, 3, 28, 28)
x = relay.var("x", shape=dshape)
y = opfunc(x, pool_size=pool_size, strides=strides, dilation=dilation, padding=padding)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.poolnd_python(
data,
[pool_size, pool_size],
[strides, strides],
[dilation, dilation],
[padding, padding],
[padding, padding],
pool_type,
count_include_pad=False,
ceil_mode=False,
)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
def _test_pool2d_int(opfunc, reffunc, dtype):
n, c, h, w = te.size_var("n"), 10, 224, 224
x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
y = opfunc(x, pool_size=(1, 1))
assert "pool_size=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 10, 224, 224), dtype)
# test execution
dshape = (1, 3, 28, 28)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
y = opfunc(x, pool_size=(2, 2), strides=(2, 2), padding=(0, 0))
func = relay.Function([x], y)
data = np.random.randint(low=-128, high=128, size=dshape)
ref_res = reffunc(data.reshape(1, 3, 14, 2, 14, 2), axis=(3, 5)).astype(dtype)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_pool2d(relay.nn.max_pool2d, "max")
_test_pool2d(relay.nn.max_pool2d, "max", pool_size=2, strides=2, padding=0)
_test_pool2d(relay.nn.max_pool2d, "max", pool_size=2, strides=2, padding=0, dilation=2)
_test_pool2d(relay.nn.avg_pool2d, "avg")
_test_pool2d(relay.nn.avg_pool2d, "avg", pool_size=2, strides=2, padding=0)
_test_pool2d(relay.nn.avg_pool2d, "avg", pool_size=2, strides=2, padding=0, dilation=2)
_test_pool2d_int(relay.nn.avg_pool2d, np.mean, "int64")
_test_pool2d_int(relay.nn.avg_pool2d, np.mean, "float16")
_test_global_pool2d(relay.nn.global_max_pool2d, np.max)
_test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
def _test_global_pool1d(opfunc, reffunc):
n, c, w = te.size_var("n"), te.size_var("c"), 224
x = relay.var("x", relay.TensorType((n, w, c), "float32"))
y = opfunc(x, layout="NWC")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 1, c), "float32")
n, c, w = te.size_var("n"), te.size_var("c"), te.size_var("w")
x = relay.var("x", relay.TensorType((n, c, w), "float32"))
y = opfunc(x)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, 1), "float32")
# test execution
dtype = "float32"
dshape = (1, 1024, 7)
x = relay.var("x", shape=dshape)
y = opfunc(x)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = reffunc(data, axis=(2,), keepdims=True)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_pool1d():
def _test_pool1d(
opfunc, pool_type, pool_size=2, strides=2, dilation=1, padding=0, dtype="float32"
):
n, c, w = te.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "float32"))
y = opfunc(x, pool_size=(1,))
assert "pool_size=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 10, 224), "float32")
# test execution
dshape = (1, 3, 32)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
pool_type = "max" if "max" in str(opfunc) else "avg"
y = opfunc(x, pool_size=pool_size, strides=strides, dilation=dilation, padding=padding)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.poolnd_python(
data,
[pool_size],
[strides],
[dilation],
[padding],
[padding],
pool_type,
count_include_pad=False,
ceil_mode=False,
)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_pool1d(relay.nn.max_pool1d, "max")
_test_pool1d(relay.nn.max_pool1d, "max", dtype="int32")
_test_pool1d(relay.nn.max_pool1d, "max", pool_size=2, strides=2, padding=0)
_test_pool1d(relay.nn.max_pool1d, "max", pool_size=2, strides=2, padding=0, dilation=2)
_test_pool1d(relay.nn.avg_pool1d, "avg")
_test_pool1d(relay.nn.avg_pool1d, "avg", dtype="int64")
_test_pool1d(relay.nn.avg_pool1d, "avg", pool_size=2, strides=2, padding=0)
_test_pool1d(relay.nn.avg_pool1d, "avg", pool_size=2, strides=2, padding=0, dilation=2)
_test_global_pool1d(relay.nn.global_max_pool1d, np.max)
_test_global_pool1d(relay.nn.global_avg_pool1d, np.mean)
@tvm.testing.uses_gpu
def test_pool3d():
def _test_pool3d(
opfunc,
pool_type,
pool_size=2,
strides=2,
dilation=1,
padding=[0, 0, 0, 0, 0, 0],
dtype="float32",
):
n, c, d, h, w = te.size_var("n"), 10, 5, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "float32"))
y = opfunc(x, pool_size=(1, 1, 1))
assert "pool_size=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 10, 5, 224, 224), "float32")
# test execution
dtype = "float32"
dshape = (1, 3, 32, 32, 32)
for shape_dtype in ["int32", "int64"]:
x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in dshape], dtype=dtype)
pool_type = "max" if "max" in str(opfunc) else "avg"
y = opfunc(
x,
pool_size=pool_size,
strides=strides,
padding=padding,
dilation=dilation,
)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = tvm.topi.testing.poolnd_python(
data,
[pool_size, pool_size, pool_size],
[strides, strides, strides],
[dilation, dilation, dilation],
padding[:3],
padding[3:],
pool_type,
count_include_pad=False,
ceil_mode=False,
)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_pool3d(relay.nn.max_pool3d, "max")
_test_pool3d(relay.nn.max_pool3d, "max", dtype="int32")
_test_pool3d(relay.nn.max_pool3d, "max", padding=(2, 0, 0, 2, 0, 0))
_test_pool3d(relay.nn.max_pool3d, "max", padding=(0, 3, 0, 0, 3, 0))
_test_pool3d(relay.nn.max_pool3d, "max", padding=(0, 0, 4, 0, 0, 4))
_test_pool3d(relay.nn.max_pool3d, "max", pool_size=2, strides=2)
_test_pool3d(relay.nn.max_pool3d, "max", pool_size=2, strides=2, dilation=2)
_test_pool3d(relay.nn.avg_pool3d, "avg")
_test_pool3d(relay.nn.avg_pool3d, "avg", dtype="int32")
_test_pool3d(relay.nn.avg_pool3d, "avg", padding=(2, 0, 0, 2, 0, 0))
_test_pool3d(relay.nn.avg_pool3d, "avg", padding=(0, 3, 0, 0, 3, 0))
_test_pool3d(relay.nn.avg_pool3d, "avg", padding=(0, 0, 4, 0, 0, 4))
_test_pool3d(relay.nn.avg_pool3d, "avg", pool_size=2, strides=2)
_test_pool3d(relay.nn.avg_pool3d, "avg", pool_size=2, strides=2, dilation=2)
@tvm.testing.uses_gpu
def test_avg_pool2d_no_count_pad():
kh, kw = (4, 4)
sh, sw = (2, 2)
ph, pw = (2, 2)
n = 1
(ic, ih, iw) = (3, 28, 28)
(oc, oh, ow) = (3, 15, 15)
dshape = (n, ic, ih, iw)
x = relay.var("x", shape=dshape)
y = relay.nn.avg_pool2d(
x, pool_size=(kh, kw), strides=(sw, sw), padding=(ph, pw), count_include_pad=False
)
func = relay.Function([x], y)
dtype = "float32"
a_np = np.random.uniform(low=0.001, size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih + 2 * ph, iw + 2 * pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih + ph)), (range(pw, iw + pw)))
pad_np[np.ix_(*no_zero)] = a_np
b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
for i in range(oh):
for j in range(ow):
pad_count = np.sum(
pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw] > 0, axis=(2, 3)
)
b_np[:, :, i, j] = np.sum(
pad_np[:, :, i * sh : i * sh + kh, j * sw : j * sw + kw], axis=(2, 3)
) / np.maximum(pad_count, 1)
ref_res = np.maximum(b_np, 0.0)
data = a_np
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_flatten_infer_type(executor_kind):
d1, d2, d3, d4 = te.size_var("d1"), te.size_var("d2"), te.size_var("d3"), te.size_var("d4")
x = relay.var("x", relay.TensorType((d1, d2, d3, d4), "float32"))
y = relay.nn.batch_flatten(x)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((d1, ((d2 * d3) * d4)), "float32")
x = relay.var("x", relay.TensorType((3, 2, 4, 3), "float32"))
y = relay.nn.batch_flatten(x)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((3, 24), "float32")
x = relay.var("x", relay.TensorType((d1, 2, d3, 3), "float32"))
y = relay.nn.batch_flatten(x)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((d1, ((2 * d3) * 3)), "float32")
shape = (1, 5, 10, 10)
o_shape = (1, 500)
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
z = relay.nn.batch_flatten(x)
yy = run_infer_type(z)
assert yy.checked_type == relay.TensorType(o_shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = x_data.flatten().reshape(o_shape)
for target, dev in tvm.testing.enabled_targets():
op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(func)(
x_data
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
@tvm.testing.uses_gpu
def test_pad_infer_type():
# entirely concrete cases
n, c, h, w = 1, 2, 3, 4
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((3, 6, 9, 12), "float32")
n, c, h, w = 4, 6, 3, 5
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(t, ((-1, -1), (2, -2), (0, -3), (4, 4)), pad_mode="reflect")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((2, 6, 0, 13), "float32")
# some symbolic values
n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(t, ((1, 1), (2, 2), (3, 3), (4, 4)))
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(t, ((-1, -1), (-2, -2), (1, -3), (4, 4)))
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n + (-2), c + (-4), h + (-2), w + 8), "float32")
# dealing with dynamic vals
n, c, h, w = te.size_var("n"), 2, 3, te.size_var("w")
t = relay.var("t", relay.TensorType((n, c, h, w), "float32"))
y = relay.nn.pad(
t, ((1, 1), (2, 2), (3, 3), (4, 4)), pad_value=relay.var("pad_value", "float32")
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n + 2, 6, 9, w + 8), "float32")
def _get_numpy_pad(dshape, data, pad, pad_value=0):
mod_pad = []
for axis, (pad_x, pad_y) in enumerate(pad):
indices = range(dshape[axis])
if pad_x < 0:
indices = indices[abs(pad_x) :]
pad_x = 0
if pad_y < 0:
indices = indices[:pad_y]
pad_y = 0
data = np.take(data, indices, axis)
mod_pad.append((pad_x, pad_y))
return np.pad(data, tuple(mod_pad), "constant", constant_values=pad_value)
@tvm.testing.uses_gpu
def test_pad_run():
def _test_run(dtype):
dshape_list = [(4, 10, 7, 7), (4, 6, 3, 5)]
pad_list = [((1, 1), (2, 2), (3, 3), (4, 4)), ((-1, -1), (2, -2), (0, -2), (4, 4))]
for dshape, pad in zip(dshape_list, pad_list):
x = relay.var("x", shape=dshape)
y = relay.nn.pad(x, pad)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = _get_numpy_pad(dshape, data, pad)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_run("float32")
_test_run("int32")
@tvm.testing.uses_gpu
def test_pad_run_dynamic_pad_value():
def _test_run(dtype):
dshape = (4, 6, 3, 5)
pad = ((-1, -1), (2, -2), (0, -2), (4, 4))
data = relay.var("data", shape=dshape, dtype=dtype)
pad_value = relay.var("pad_value", dtype)
pad_data = relay.nn.pad(data, pad, pad_value=pad_value)
f = relay.Function([data, pad_value], pad_data)
data_arr = np.random.uniform(-10, 10, size=dshape).astype(dtype)
pad_value_arr = 2.0
ref_res = _get_numpy_pad(dshape, data_arr, pad, pad_value=pad_value_arr)
for target, dev in tvm.testing.enabled_targets():
result = relay.create_executor(kind="graph", device=dev, target=target).evaluate(f)(
data_arr, pad_value_arr
)
tvm.testing.assert_allclose(result.numpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_run("float32")
_test_run("int32")
def test_pad_value_in_array():
A = relay.var("A", shape=(32, 32), dtype="int8")
# Extract pad value from an array
p0 = relay.Constant(tvm.nd.array(np.array([2], dtype="int8")))
p1 = relay.nn.pad(A, pad_value=p0, pad_width=((1, 1), (1, 1)))
func = relay.Function(relay.analysis.free_vars(p1), p1)
mod = tvm.IRModule.from_expr(func)
target = "llvm"
lib = relay.build(
mod,
tvm.target.Target(target, host=target),
runtime=relay.backend.Runtime("cpp"),
executor=relay.backend.Executor("aot", {"unpacked-api": False, "interface-api": "packed"}),
)
@tvm.testing.uses_gpu
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_lrn(executor_kind, dtype):
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
x = relay.var("x", shape=(n, c, h, w), dtype=dtype)
y = relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=0.00001, beta=0.75)
"alpha=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w), dtype)
shape = (1, 5, 10, 10)
x = relay.var("x", relay.TensorType(shape, dtype))
size = 5
axis = 1
bias = 0.5
alpha = 0.00001
beta = 0.75
z = relay.nn.lrn(x, size=size, axis=axis, bias=bias, alpha=alpha, beta=beta)
yy = run_infer_type(z)
assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = tvm.topi.testing.lrn_python(x_data, size, axis, bias, alpha, beta)
for target, dev in tvm.testing.enabled_targets():
op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(func)(
x_data
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
@tvm.testing.uses_gpu
def test_l2_normalize(executor_kind):
n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w")
x = relay.var("x", shape=(n, c, h, w))
y = relay.nn.l2_normalize(x, eps=0.001, axis=[1])
"axis=" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, c, h, w))
shape = (1, 5, 10, 10)
dtype = "float32"
x = relay.var("x", relay.TensorType(shape, dtype))
eps = 0.001
axis = 1
z = relay.nn.l2_normalize(x, eps=0.001, axis=[axis])
yy = run_infer_type(z)
assert yy.checked_type == relay.TensorType(shape, dtype)
func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=shape).astype(dtype)
ref_res = tvm.topi.testing.l2_normalize_python(x_data, eps, axis)
for target, dev in tvm.testing.enabled_targets():
op_res = relay.create_executor(executor_kind, device=dev, target=target).evaluate(func)(
x_data
)
tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
def batch_flatten(data):
shape = data.shape
target_dim = 1
for i in range(len(shape) - 1):
target_dim = target_dim * shape[i + 1]
return np.reshape(data, (shape[0], target_dim))
@tvm.testing.uses_gpu
def test_batch_flatten():
t1 = relay.TensorType((5, 10, 5))
x = relay.Var("x", t1)
func = relay.Function([x], relay.nn.batch_flatten(x))
data = np.random.rand(5, 10, 5).astype(t1.dtype)
ref_res = batch_flatten(data)
for target, dev in tvm.testing.enabled_targets():
op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
np.testing.assert_allclose(op_res.numpy(), ref_res, rtol=0.01)
def _test_upsampling(layout, method, align_corners=False):
n, c, h, w = te.size_var("n"), 16, 32, 32
scale_h = 2.0
scale_w = 2.0
dtype = "float32"
def get_shape():
if layout == "NCHW":
return (c, h, w), (c, int(round(h * scale_h)), int(round(w * scale_w)))
else:
return (h, w, c), (int(round(h * scale_h)), int(round(w * scale_w)), c)
ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling(
x,
scale_h=scale_h,
scale_w=scale_w,
layout=layout,
method=method,
align_corners=align_corners,
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape
x = relay.var("x", shape=dshape)
y = relay.nn.upsampling(
x,
scale_h=scale_h,
scale_w=scale_w,
layout=layout,
method=method,
align_corners=align_corners,
)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref = tvm.topi.testing.resize2d_python(
data,
(scale_h, scale_w),
layout,
method[2:] if method[0:2] == "bi" else method,
"align_corners" if align_corners else "asymmetric",
)
for target, dev in tvm.testing.enabled_targets():
out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_upsampling():
_test_upsampling("NCHW", "nearest_neighbor")
_test_upsampling("NCHW", "bilinear", True)
_test_upsampling("NHWC", "nearest_neighbor")
_test_upsampling("NHWC", "bilinear", True)
def _test_upsampling3d(layout, method, coordinate_transformation_mode="half_pixel"):
n, c, d, h, w = te.size_var("n"), 8, 16, 16, 16
scale_d = 2.0
scale_h = 2.0
scale_w = 2.0
dtype = "float32"
def get_shape():
if layout == "NCDHW":
return (c, d, h, w), (
c,
int(round(d * scale_d)),
int(round(h * scale_h)),
int(round(w * scale_w)),
)
else:
return (d, h, w, c), (
int(round(d * scale_d)),
int(round(h * scale_h)),
int(round(w * scale_w)),
c,
)
ishape, oshape = get_shape()
x = relay.var("x", relay.TensorType((n,) + ishape, dtype))
y = relay.nn.upsampling3d(
x,
scale_d=scale_d,
scale_h=scale_h,
scale_w=scale_w,
layout=layout,
method=method,
coordinate_transformation_mode=coordinate_transformation_mode,
)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n,) + oshape, dtype)
dshape = (1,) + ishape
x = relay.var("x", shape=dshape)
y = relay.nn.upsampling3d(
x,
scale_d=scale_d,
scale_h=scale_h,
scale_w=scale_w,
layout=layout,
method=method,
coordinate_transformation_mode=coordinate_transformation_mode,
)
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
ref = tvm.topi.testing.resize3d_python(
data,
(scale_d, scale_h, scale_w),
layout,
method[3:] if method[0:3] == "tri" else method,
coordinate_transformation_mode,
)
for target, dev in tvm.testing.enabled_targets():
out = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data)
tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-5, atol=1e-5)
@tvm.testing.uses_gpu
def test_upsampling3d():
_test_upsampling3d("NCDHW", "nearest_neighbor", "asymmetric")
_test_upsampling3d("NCDHW", "trilinear", "align_corners")
_test_upsampling3d("NDHWC", "nearest_neighbor", "asymmetric")
_test_upsampling3d("NDHWC", "trilinear", "align_corners")
@tvm.testing.requires_x86
@pytest.mark.skipif(tvm.target.codegen.llvm_version_major() < 8, reason="Requires LLVM 8")
class TestConv2DInt8Intrinsics:
supported_targets = [
"llvm -mcpu=nehalem",
"llvm -mcpu=core-avx2",
"llvm -mcpu=skylake-avx512",
"llvm -mcpu=cascadelake",
]
unsupported_targets = [
"llvm -mcpu=x86-64",
]
data_layout, kernel_layout = tvm.testing.parameters(
("NCHW", "OIHW"),
# TODO(@anijain2305, @icemelon9): disable conv2d_int8 for NHWC data layout.
# Re-enable this after adding conv2d_NCHWc_int8 support for NHWC.
# ("NHWC", "HWIO"),
)
input_channels, output_channels = tvm.testing.parameters(
# Sweep the input channels to check int8 robustness
# Input channels should be a multiple of 4 internally.
(1, 16),
(4, 16),
(6, 16),
# Sweep the output channels to check int8 robustness
# Output channels should be a multiple of 16 internally.
(8, 4),
(8, 16),
(8, 20),
# Check that both non-divisible oc and ic work
(17, 29),
)
@tvm.testing.fixture
def fast_int8_intrinsic(self, target):
if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target:
return "pmaddubs"
elif "cascadelake" in target:
return "vpdpbusd"
else:
assert False, "Target should be Nehalem or core-avx2 or Skylake or Cascadelake"
@tvm.testing.fixture
def assembly(
self,
target,
dtypes,
input_channels,
output_channels,
data_layout,
kernel_layout,
):
if (
input_channels == 17
and output_channels == 29
and target == "llvm -mcpu=x86-64"
and tvm.target.codegen.llvm_version_major() in [16, 17]
):
pytest.skip(
"Non divisible dims does not produce vectorized code when 15 < LLVM Version < 18."
)
input_dtype, weight_dtype, output_dtype = dtypes
image_size = (64, 64)
kernel_size = (3, 3)
batch_size = 1
h, w = image_size
if data_layout == "NCHW":
data_shape = (batch_size, input_channels, *image_size)
elif data_layout == "NHWC":
data_shape = (batch_size, *image_size, input_channels)
else:
raise ValueError(f"Unsupported data layout: {data_layout}")
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
if kernel_layout == "OIHW":
kernel_shape = (output_channels, input_channels, *kernel_size)
elif kernel_layout == "HWIO":
kernel_shape = (*kernel_size, input_channels, output_channels)
else:
raise ValueError("Not supported")
weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(
x,
weight,
kernel_size=kernel_size,
channels=output_channels,
padding=(0, 0, 0, 1),
dilation=(1, 1),
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=output_dtype,
)
func = relay.Function([x, weight], y)
wdata = np.random.rand(*kernel_shape) * 10
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)
return lib.get_source("asm")
# Ensure that code uses the fast int8 instructions when available.
@tvm.testing.parametrize_targets(*supported_targets)
@pytest.mark.parametrize(
"dtypes",
[
# compile conv2d for x86 (skylake, cascadelake) and test
# assembly contains *pmadd* instructions
("uint8", "int8", "int32"),
# Check that int8 x int8 goes through legalization so that
# fast instructions can be picked up.
("int8", "int8", "int32"),
],
)
def test_uses_intrinsic(
self,
fast_int8_intrinsic,
assembly,
):
assert fast_int8_intrinsic in assembly
# For datatypes that don't have HW support, ensure that code is
# generated without the fast int8 intrinsic.
@tvm.testing.parametrize_targets(*supported_targets)
@pytest.mark.parametrize("dtypes", [("uint8", "uint8", "int32")])
def test_no_intrinsic(
self,
fast_int8_intrinsic,
assembly,
):
assert fast_int8_intrinsic not in assembly
# Check that a vectorized instruction is generated for older Intel
# generations, because we default to NCHWc layout.
@tvm.testing.parametrize_targets(*unsupported_targets)
@pytest.mark.parametrize("dtypes", [("uint8", "int8", "int32")])
def test_uses_vectorized_instruction(self, assembly):
assert "pmulhw" in assembly or "pmaddwd" in assembly
assert "paddd" in assembly
@tvm.testing.uses_gpu
def test_depthwise_conv2d_int8():
input_dtype = "uint8"
weight_dtype = "int8"
output_dtype = "int32"
data_shape = (1, 64, 56, 56)
x = relay.var("x", relay.TensorType(data_shape, input_dtype))
kernel_shape = (64, 1, 3, 3)
weight = relay.var("weight", relay.TensorType(kernel_shape, weight_dtype))
y = relay.nn.conv2d(
x,
weight,
kernel_size=(3, 3),
groups=64,
padding=(1, 1),
dilation=(1, 1),
out_dtype=output_dtype,
)
func = relay.Function([x, weight], y)
wdata = np.random.rand(*kernel_shape) * 10
parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))}
targets = [
"llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512",
"llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake",
]
llvm_version = tvm.target.codegen.llvm_version_major()
for target in targets:
if llvm_version >= 8:
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(func, target, params=parameters)
@tvm.testing.uses_gpu
def test_bitserial_conv2d_infer_type():
# Basic shape test with ambiguous batch.
n, c, h, w = te.size_var("n"), 32, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "int16"))
w = relay.var("w", relay.ty.TensorType((32, 32, 3, 3), "int16"))
y = relay.nn.bitserial_conv2d(x, w, kernel_size=(3, 3), padding=(0, 0), channels=32)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, 32, 222, 222), "int16")
@tvm.testing.uses_gpu
def test_bitpack_infer_type():
# Test axis packing shape inference.
o, i, h, w = 32, 32, 128, 128
x = relay.var("x", relay.ty.TensorType((o, i, h, w), "int16"))
y = relay.nn.bitpack(x, bit_axis=4, pack_axis=1, pack_type="uint16", bits=1)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((32, 2, 128, 128, 1), "uint16")
# TODO(@jwfromm): Need to add bitserial_conv2d & bitpack run test cases
@tvm.testing.uses_gpu
def test_correlation():
def _test_correlation(
data_shape,
kernel_size,
max_displacement,
stride1,
stride2,
padding,
is_multiply,
dtype="float32",
):
data1 = relay.var("data1", relay.ty.TensorType(data_shape, dtype))
data2 = relay.var("data2", relay.ty.TensorType(data_shape, dtype))
y = relay.nn.correlation(
data1,
data2,
kernel_size,
max_displacement,
stride1,
stride2,
padding,
is_multiply,
"NCHW",
)
yy = run_infer_type(y)
padded_height = data_shape[2] + 2 * padding
padded_width = data_shape[3] + 2 * padding
border_size = (kernel_size - 1) // 2 + max_displacement
displacement_radius = max_displacement // stride2
out_channel = ((2 * displacement_radius) + 1) ** 2
out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1
out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1
assert yy.checked_type == relay.TensorType(
(data_shape[0], out_channel, out_height, out_width), dtype
)
func = relay.Function([data1, data2], y)
data1_np = np.random.uniform(size=data_shape).astype(dtype)
data2_np = np.random.uniform(size=data_shape).astype(dtype)
ref_res = tvm.topi.testing.correlation_nchw_python(
data1_np,
data2_np,
kernel_size,
max_displacement,
stride1,
stride2,
padding,
is_multiply,
)
for target, dev in tvm.testing.enabled_targets():
op_res1 = relay.create_executor("graph", device=dev, target=target).evaluate(func)(
data1_np, data2_np
)
tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5, atol=1e-5)
_test_correlation(
(1, 3, 10, 10),
kernel_size=1,
max_displacement=4,
stride1=1,
stride2=1,
padding=4,
is_multiply=True,
)
_test_correlation(
(1, 3, 10, 10),
kernel_size=1,
max_displacement=5,
stride1=1,
stride2=1,
padding=5,
is_multiply=True,
)
_test_correlation(
(5, 1, 4, 4),
kernel_size=3,
max_displacement=1,
stride1=2,
stride2=1,
padding=2,
is_multiply=True,
)
_test_correlation(
(5, 1, 6, 4),
kernel_size=3,
max_displacement=1,
stride1=2,
stride2=2,
padding=2,
is_multiply=False,
)
_test_correlation(
(5, 1, 11, 11),
kernel_size=5,
max_displacement=1,
stride1=1,
stride2=1,
padding=2,
is_multiply=False,
)
@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_conv2d_rocm_sdot4():
d_shape = (1, 64, 56, 56)
w_shape = (64, 64, 3, 3)
padding = (1, 1)
strides = (1, 1)
data_dtype = "int8"
weight_dtype = "int8"
out_dtype = "int32"
data = relay.var("data", shape=d_shape, dtype=data_dtype)
weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
out_channel = w_shape[0]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=out_dtype,
)
mod = tvm.IRModule.from_expr(conv2d)
data_np = np.random.uniform(1, 10, d_shape).astype("int8")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("int8")
target = "rocm -mattr=+dotprod"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})
asm = lib.lib.imported_modules[0].get_source("asm")
assert "v_dot4_i32_i8" in asm
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("data", data_np)
runtime.run()
out = runtime.get_output(0).numpy()
ref = tvm.topi.testing.conv2d_nchw_python(
data_np.astype("int32"), weight_np.astype("int32"), strides, padding
)
np.testing.assert_equal(out, ref)
def np_float2tvm_bf16(arr):
"""Convert a numpy array of float to a TVM array
of bf16"""
orig = arr.view("<u4")
bias = np.bitwise_and(np.right_shift(orig, 16), 1) + 0x7FFF
nparr = np.right_shift(orig + bias, 16).astype("uint16")
return tvm.nd.empty(nparr.shape, "bfloat16").copyfrom(nparr)
def np_bf162np_float(arr):
"""Convert a numpy array of bf16 (uint16) to a numpy array
of float"""
u32 = np.left_shift(arr.astype("uint32"), 16)
return u32.view("<f4")
@tvm.testing.requires_x86
def test_conv2d_nchw_dnnl():
if not tvm.get_global_func("tvm.contrib.dnnl.conv2d", allow_missing=True):
print(
"skip because extern dnnl function is not available, \
built with dnnl=ON"
)
return
d_shape = (1, 64, 56, 56)
w_shape = (64, 64, 3, 3)
padding = (1, 1)
strides = (1, 1)
def get_subgraph(dtype):
data = relay.var("data", shape=d_shape, dtype=dtype)
weight = relay.var("weight", shape=w_shape, dtype=dtype)
out_channel = w_shape[0]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=dtype,
)
return conv2d
for t in ["float32", "bfloat16"]:
mod = tvm.IRModule.from_expr(get_subgraph(t))
data_np = np.random.uniform(1, 10, d_shape).astype("float32")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
ref = tvm.topi.testing.conv2d_nchw_python(data_np, weight_np, strides, padding)
if t == "bfloat16":
data_np = np_float2tvm_bf16(data_np)
weight_np = np_float2tvm_bf16(weight_np)
target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("data", data_np)
runtime.run()
out = runtime.get_output(0).numpy()
if t == "bfloat16":
out = np_bf162np_float(out)
np.testing.assert_allclose(out, ref, rtol=1e-2)
else:
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
@tvm.testing.requires_x86
def test_conv2d_nhwc_dnnl():
if not tvm.get_global_func("tvm.contrib.dnnl.conv2d", allow_missing=True):
print(
"skip because extern dnnl function is not available, \
built with dnnl=ON"
)
return
d_shape = (1, 56, 56, 64)
w_shape = (3, 3, 64, 64)
padding = (1, 1)
strides = (1, 1)
def get_subgraph(dtype):
data = relay.var("data", shape=d_shape, dtype=dtype)
weight = relay.var("weight", shape=w_shape, dtype=dtype)
out_channel = w_shape[3]
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[:2],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=dtype,
data_layout="NHWC",
kernel_layout="HWIO",
)
return conv2d
for t in ["float32", "bfloat16"]:
mod = tvm.IRModule.from_expr(get_subgraph(t))
data_np = np.random.uniform(1, 10, d_shape).astype("float32")
weight_np = np.random.uniform(1, 10, size=w_shape).astype("float32")
ref = tvm.topi.testing.conv2d_nhwc_python(data_np, weight_np, strides, padding)
if t == "bfloat16":
data_np = np_float2tvm_bf16(data_np)
weight_np = np_float2tvm_bf16(weight_np)
target = "llvm -mcpu=skylake-avx512 -libs=dnnl"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, params={"weight": weight_np})
dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
runtime.set_input("data", data_np)
runtime.run()
out = runtime.get_output(0).numpy()
if t == "bfloat16":
out = np_bf162np_float(out)
np.testing.assert_allclose(out, ref, rtol=1e-2)
else:
np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instrs):
def get_conv2d_nchw(
d_shape,
w_shape,
data_dtype,
):
out_dtype = "int32"
strides = (1, 1)
padding = (1, 1)
data = relay.var("data", shape=d_shape, dtype=data_dtype)
weight = relay.var("weight", shape=w_shape, dtype="int8")
out_channel = w_shape[0]
return relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=w_shape[2:],
channels=out_channel,
padding=padding,
strides=strides,
out_dtype=out_dtype,
)
I, O, H, W = 64, 64, 56, 56
kH = kW = 3
data_shape = (1, I, H, W)
weight_shape = (O, I, kH, kW)
bias_shape = (1, weight_shape[0], 1, 1)
bias = relay.var("bias", shape=bias_shape, dtype="int32")
bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32")
weight_np = np.random.uniform(-32, 32, size=weight_shape).astype("int8")
conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype)
bias_add = relay.add(conv2d, bias)
mod = tvm.IRModule.from_expr(bias_add)
if data_dtype == "uint8":
data_np = np.random.uniform(0, 64, size=data_shape).astype("uint8")
else:
data_np = np.random.uniform(-32, 32, size=data_shape).astype("int8")
params = {"weight": weight_np, "bias": bias_np}
ref = (
relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm")
.evaluate()(*[data_np, weight_np, bias_np])
.numpy()
)
dev = tvm.cpu(0)
with tvm.transform.PassContext(
opt_level=3,
):
lib = relay.build(mod, target=target, params=params)
for dot_product_instr in dot_product_instrs:
assert dot_product_instr in lib.lib.get_source("asm")
rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
rt_mod.set_input("data", data_np)
rt_mod.run()
out = rt_mod.get_output(0).numpy()
np.testing.assert_equal(out, ref)
@tvm.testing.requires_arm_dot
def test_conv2d_int8_alter_dtype_arm():
_test_conv2d_int8_alter_dtype(
"uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod", ["sdot"]
)
@tvm.testing.requires_x86_vnni
def test_conv2d_int8_alter_dtype_vnni():
_test_conv2d_int8_alter_dtype("int8", "llvm -mcpu=cascadelake", ["vpdpbusd"])
@tvm.testing.requires_x86_avx512
def test_conv2d_int8_alter_dtype_avx512():
_test_conv2d_int8_alter_dtype(
"int8", "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"]
)
if __name__ == "__main__":
tvm.testing.main()