blob: 0a8be3df11a05c5005e890303f2c604740948e09 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test translate from relax."""
import torch
from torch.nn import Module
import numpy as np
import tvm.testing
from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate
from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
from tvm.contrib.msc.core.frontend import translate as core_translate
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
def verify_model(torch_model, input_info, opt_config=None):
    """Check that an MSC round trip of a torch model's Relax IR preserves its outputs.

    The torch model is translated to a Relax module (``orig_mod``). That module is
    translated to MSC with ``core_translate.from_relax`` and code-generated back to
    Relax (``rt_mod``). Both modules are legalized, compiled for ``llvm``, and run on
    the CPU with the same random inputs; every output must match.

    Parameters
    ----------
    torch_model: torch.nn.Module
        The model under test.
    input_info: list
        One ``(shape, dtype)`` spec per input. It is passed to the translator and
        also used to generate the random runtime arguments.
    opt_config: dict, optional
        Forwarded unchanged to ``core_translate.from_relax``.
    """
    orig_mod, _ = torch_translate.from_torch(torch_model, input_info, as_msc=False)
    target = "llvm"
    dev = tvm.cpu()
    # Both executions share these inputs, so any difference comes from the round trip.
    args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info]

    def _tvm_runtime_to_np(obj):
        """Recursively convert VM results into numpy arrays / python containers."""
        if isinstance(obj, tvm.runtime.Tensor):
            return obj.numpy()
        if isinstance(obj, tvm.runtime.ShapeTuple):
            return np.array(obj, dtype="int64")
        if isinstance(obj, (list, tvm.ir.container.Array)):
            return [_tvm_runtime_to_np(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_tvm_runtime_to_np(item) for item in obj)
        return obj

    def _run_relax(relax_mod):
        """Legalize, compile and execute ``relax_mod["main"]`` on ``args``."""
        relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
        relax_exec = tvm.compile(relax_mod, target)
        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
        res = vm_runner["main"](*args)
        return _tvm_runtime_to_np(res)

    rt_mod = tvm_codegen.to_relax(
        *core_translate.from_relax(orig_mod, opt_config=opt_config),
        codegen_config={"explicit_name": False},
    )
    orig_output = _run_relax(orig_mod)
    rt_output = _run_relax(rt_mod)
    if not isinstance(orig_output, (list, tuple)):
        orig_output = [orig_output]
    if not isinstance(rt_output, (list, tuple)):
        rt_output = [rt_output]
    # zip() stops at the shorter sequence. Without this check, a round trip that
    # adds or drops outputs would still pass.
    assert len(orig_output) == len(rt_output), (
        f"output count mismatch: original {len(orig_output)} vs "
        f"round-tripped {len(rt_output)}"
    )
    for o_out, r_out in zip(orig_output, rt_output):
        tvm.testing.assert_allclose(o_out, r_out)
def test_conv1d():
    """Round-trip a Conv1d layer, first with a bias term and then without one."""

    class Conv1dLayer(Module):
        def __init__(self, bias):
            super().__init__()
            self.conv = torch.nn.Conv1d(3, 6, 7, bias=bias)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10], "float32")]
    for use_bias in (True, False):
        verify_model(Conv1dLayer(use_bias), input_info)
def test_conv2d():
    """Round-trip a Conv2d layer, first with a bias term and then without one."""

    class Conv2dLayer(Module):
        def __init__(self, bias):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=bias)

        def forward(self, data):
            return self.conv(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for use_bias in (True, False):
        verify_model(Conv2dLayer(use_bias), input_info)
def test_linear():
    """Round-trip nn.Linear (with and without bias) and a plain torch.matmul."""

    class DenseLayer(Module):
        def __init__(self, bias):
            super().__init__()
            self.linear = torch.nn.Linear(10, 7, bias=bias)

        def forward(self, data):
            return self.linear(data)

    class MatMul(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    dense_info = [([1, 3, 10, 10], "float32")]
    for use_bias in (True, False):
        verify_model(DenseLayer(use_bias), dense_info)
    matmul_info = [([10, 10], "float32") for _ in range(2)]
    verify_model(MatMul(), matmul_info)
def test_bmm():
    """Round-trip a batched matrix multiply (torch.bmm)."""

    class BatchMatMul(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    lhs_spec = ((4, 128, 256), "float32")
    rhs_spec = ((4, 256, 512), "float32")
    verify_model(BatchMatMul(), [lhs_spec, rhs_spec])
def test_baddbmm():
    """Round-trip torch.baddbmm with default and with explicit alpha/beta scaling."""

    class BAddBMMDefault(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    class BAddBMMScaled(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    input_info = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    for model_cls in (BAddBMMDefault, BAddBMMScaled):
        verify_model(model_cls(), input_info)
def test_relu():
    """Round-trip ReLU both as an nn.Module and as the functional op."""

    class ReLUModule(Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, data):
            return self.relu(data)

    class ReLUFunctional(Module):
        def forward(self, data):
            return torch.nn.functional.relu(data)

    input_info = [([10, 10], "float32")]
    for model_cls in (ReLUModule, ReLUFunctional):
        verify_model(model_cls(), input_info)
def test_relu6():
    """Round-trip the clipped ReLU6 activation module."""

    class ClippedReLU(Module):
        def __init__(self):
            super().__init__()
            self.relu6 = torch.nn.ReLU6()

        def forward(self, data):
            return self.relu6(data)

    verify_model(ClippedReLU(), [([10, 10], "float32")])
def test_maxpool2d():
    """Round-trip MaxPool2d with plain, dilated, and padded/strided windows."""

    class MaxPool(Module):
        def __init__(self, **pool_kwargs):
            super().__init__()
            self.pool = torch.nn.MaxPool2d(**pool_kwargs)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    pool_configs = [
        {"kernel_size": [1, 1]},
        {"kernel_size": [2, 2], "dilation": [2, 3]},
        {"kernel_size": [4, 4], "padding": 2, "stride": 2},
    ]
    for config in pool_configs:
        verify_model(MaxPool(**config), input_info)
def test_avgpool2d():
    """Round-trip AvgPool2d with a trivial window and a padded ceil-mode window."""

    class AvgPool(Module):
        def __init__(self, **pool_kwargs):
            super().__init__()
            self.pool = torch.nn.AvgPool2d(**pool_kwargs)

        def forward(self, data):
            return self.pool(data)

    input_info = [([1, 3, 10, 10], "float32")]
    pool_configs = [
        {"kernel_size": [1, 1]},
        {"kernel_size": [4, 4], "stride": 2, "padding": 2, "ceil_mode": True},
    ]
    for config in pool_configs:
        verify_model(AvgPool(**config), input_info)
def test_adaptive_avgpool2d():
    """Round-trip AdaptiveAvgPool2d whose target size equals the input size."""

    class AdaptivePool(Module):
        def __init__(self):
            super().__init__()
            self.pool = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.pool(data)

    verify_model(AdaptivePool(), [([1, 3, 10, 10], "float32")])
def test_flatten():
    """Round-trip flattening of dims 2..-1, once wrapped in a module and once bare."""

    class FlattenWrapper(Module):
        def __init__(self):
            super().__init__()
            self.f = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.f(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for make_model in (FlattenWrapper, lambda: torch.nn.Flatten(2, -1)):
        verify_model(make_model(), input_info)
def test_batchnorm2d():
    """Round-trip BatchNorm2d over a 3-channel input."""

    class BatchNormLayer(Module):
        def __init__(self):
            super().__init__()
            self.batchnorm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.batchnorm(data)

    verify_model(BatchNormLayer(), [([1, 3, 10, 10], "float32")])
def test_embedding():
    """Round-trip an Embedding lookup with 1-D and 2-D int64 index tensors."""

    class EmbeddingLookup(Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.embedding(data)

    for index_shape in ([4], [4, 5]):
        verify_model(EmbeddingLookup(), [(index_shape, "int64")])
def test_dropout():
    """Round-trip dropout as an nn.Module and as torch.dropout(train=True)."""

    class DropoutModule(Module):
        def __init__(self):
            super().__init__()
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, data):
            return self.dropout(data)

    class DropoutFunctional(Module):
        def forward(self, data):
            return torch.dropout(data, 0.5, train=True)

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (DropoutModule, DropoutFunctional):
        verify_model(model_cls(), input_info)
def test_layernorm():
    """Round-trip LayerNorm normalizing over the trailing (10, 10) dims."""

    class LayerNormModule(Module):
        def __init__(self):
            super().__init__()
            self.layernorm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.layernorm(data)

    verify_model(LayerNormModule(), [([1, 3, 10, 10], "float32")])
def test_functional_layernorm():
    """Round-trip functional layer_norm using explicit weight/bias parameters."""

    class FunctionalLayerNorm(Module):
        def __init__(self, shape):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(shape))
            self.bias = torch.nn.Parameter(torch.zeros(shape))

        def forward(self, data):
            normalized_shape = self.weight.shape
            return torch.nn.functional.layer_norm(
                data, normalized_shape, self.weight, self.bias, 1e-5
            )

    verify_model(FunctionalLayerNorm((10, 10)), [([1, 3, 10, 10], "float32")])
def test_cross_entropy():
    """Round-trip CrossEntropyLoss: default, class-weighted, ignore_index + sum."""

    class PlainCE(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class WeightedCE(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class IgnoreIndexSumCE(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    input_info = [([3, 2], "float32"), ([3], "int32")]
    for model_cls in (PlainCE, WeightedCE, IgnoreIndexSumCE):
        verify_model(model_cls(), input_info)
def test_functional_cross_entropy():
    """Round-trip torch.nn.functional.cross_entropy with default options."""

    class FunctionalCE(Module):
        def forward(self, logits, targets):
            return torch.nn.functional.cross_entropy(logits, targets)

    logits_spec = ([3, 10], "float32")
    targets_spec = ([3], "int32")
    verify_model(FunctionalCE(), [logits_spec, targets_spec])
def test_silu():
    """Round-trip SiLU both as an nn.Module and as the functional op."""

    class SiLUModule(Module):
        def __init__(self):
            super().__init__()
            self.silu = torch.nn.SiLU()

        def forward(self, data):
            return self.silu(data)

    class SiLUFunctional(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (SiLUModule, SiLUFunctional):
        verify_model(model_cls(), input_info)
def test_groupnorm():
    """Round-trip GroupNorm with three groups over three channels."""

    class GroupNormLayer(Module):
        def __init__(self):
            super().__init__()
            self.groupnorm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.groupnorm(data)

    verify_model(GroupNormLayer(), [([1, 3, 10, 10], "float32")])
def test_softmax():
    """Round-trip a Softmax module applied along dim 1."""

    class SoftmaxLayer(Module):
        def __init__(self):
            super().__init__()
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.softmax(data)

    verify_model(SoftmaxLayer(), [([1, 3, 10, 10], "float32")])
def test_binary():
    """Round-trip elementwise binary ops with tensor-tensor and tensor-scalar operands."""
    two_inputs = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    one_input = [([1, 3, 10, 10], "float32")]

    class AddTensor(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class AddScalar(Module):
        def forward(self, lhs):
            return lhs + 1.0

    class SubTensor(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class SubScalar(Module):
        def forward(self, lhs):
            return lhs - 1.0

    class MulTensor(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class MulScalar(Module):
        def forward(self, lhs):
            return lhs * 1.0

    class TrueDivTensor(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDivScalar(Module):
        def forward(self, lhs):
            return lhs / 1.0

    class FloorDivTensor(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDivScalar(Module):
        def forward(self, lhs):
            return lhs // 1.0

    class PowerTensor(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class PowerScalar(Module):
        def forward(self, lhs):
            return lhs**1.0

    class LessThanTensor(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    class LessThanScalar(Module):
        def forward(self, lhs):
            return lhs < 1.0

    # Each op is checked tensor-tensor first, then tensor-scalar.
    cases = [
        (AddTensor, two_inputs),
        (AddScalar, one_input),
        (SubTensor, two_inputs),
        (SubScalar, one_input),
        (MulTensor, two_inputs),
        (MulScalar, one_input),
        (TrueDivTensor, two_inputs),
        (TrueDivScalar, one_input),
        (FloorDivTensor, two_inputs),
        (FloorDivScalar, one_input),
        (PowerTensor, two_inputs),
        (PowerScalar, one_input),
        (LessThanTensor, two_inputs),
        (LessThanScalar, one_input),
    ]
    for model_cls, input_info in cases:
        verify_model(model_cls(), input_info)
def test_size():
    """Round-trip a model that returns ``tensor.size()``."""

    class SizeOf(Module):
        def forward(self, data):
            return data.size()

    verify_model(SizeOf(), [([1, 3, 10, 10], "float32")])
def test_squeeze():
    """Round-trip squeeze of one explicit axis and of every unit axis."""

    class SqueezeAxis(Module):
        def forward(self, data):
            return data.squeeze(1)

    class SqueezeAll(Module):
        def forward(self, data):
            return data.squeeze()

    input_info = [([3, 1, 4, 1], "float32")]
    for model_cls in (SqueezeAxis, SqueezeAll):
        verify_model(model_cls(), input_info)
def test_unsqueeze():
    """Round-trip unsqueeze at a positive and at a negative axis."""

    class UnsqueezeFront(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class UnsqueezeBack(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (UnsqueezeFront, UnsqueezeBack):
        verify_model(model_cls(), input_info)
def test_getattr():
    """Round-trip a model that returns the ``.shape`` attribute of its input."""

    class ShapeAttr(Module):
        def forward(self, data):
            return data.shape

    verify_model(ShapeAttr(), [([1, 3, 10, 10], "float32")])
def test_getitem():
    """Round-trip tensor indexing: mixed int/strided slicing and None-insertion."""

    class StridedSlice(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    class NoneIndexing(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    cases = [
        (StridedSlice, [([1, 3, 10, 10], "float32")]),
        (NoneIndexing, [([8, 16], "float32")]),
    ]
    for model_cls, input_info in cases:
        verify_model(model_cls(), input_info)
def test_unary():
    """Round-trip a set of elementwise unary torch functions."""

    class SinOp(Module):
        def forward(self, data):
            return torch.sin(data)

    class CosOp(Module):
        def forward(self, data):
            return torch.cos(data)

    class ExpOp(Module):
        def forward(self, data):
            return torch.exp(data)

    class SqrtOp(Module):
        def forward(self, data):
            return torch.sqrt(data)

    class SigmoidOp(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    class RoundOp(Module):
        def forward(self, data):
            return torch.round(data)

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (SinOp, CosOp, ExpOp, SqrtOp, SigmoidOp, RoundOp):
        verify_model(model_cls(), input_info)
def test_gelu():
    """Round-trip the functional GELU activation."""

    class GeluActivation(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    verify_model(GeluActivation(), [([1, 3, 10, 10], "float32")])
def test_tanh():
    """Round-trip torch.tanh."""

    class TanhActivation(Module):
        def forward(self, data):
            return torch.tanh(data)

    verify_model(TanhActivation(), [([1, 3, 10, 10], "float32")])
def test_clamp():
    """Round-trip torch.clamp with both a lower and an upper bound."""

    class ClampRange(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    verify_model(ClampRange(), [([1, 3, 10, 10], "float32")])
def test_interpolate():
    """Round-trip functional interpolate down to a fixed 5x5 spatial size."""

    class Resize(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    verify_model(Resize(), [([1, 3, 10, 10], "float32")])
def test_addmm():
    """Round-trip torch.addmm on three square matrices."""

    class AddMatMul(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    input_info = [([10, 10], "float32") for _ in range(3)]
    verify_model(AddMatMul(), input_info)
def test_split():
    """Round-trip torch.split by a fixed chunk size and by explicit section sizes."""

    class SplitEven(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    class SplitSections(Module):
        def forward(self, data):
            return torch.split(data, [1, 2], dim=1)

    input_info = [([1, 3, 10, 10], "float32")]
    for model_cls in (SplitEven, SplitSections):
        verify_model(model_cls(), input_info)
def test_unbind():
    """Round-trip torch.unbind along the default axis and along dim 1."""

    class UnbindDefault(Module):
        def forward(self, data):
            return torch.unbind(data)

    class UnbindDim1(Module):
        def forward(self, data):
            return torch.unbind(data, dim=1)

    input_info = [([3, 3, 10, 10], "float32")]
    for model_cls in (UnbindDefault, UnbindDim1):
        verify_model(model_cls(), input_info)
def test_cumsum():
    """Round-trip torch.cumsum along dim 1 with an int32 output dtype."""

    class CumulativeSum(Module):
        def forward(self, data):
            return torch.cumsum(data, dim=1, dtype=torch.int32)

    verify_model(CumulativeSum(), [([1, 2, 3, 4], "float32")])
def test_chunk():
    """Round-trip torch.chunk into three pieces along dim 1."""

    class ChunkThree(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    verify_model(ChunkThree(), [([1, 3, 10, 10], "float32")])
def test_inplace_fill():
    """Round-trip an in-place ``fill_`` with MSC optimizations disabled (opt_level 0)."""

    class FillInPlace(Module):
        def forward(self, data):
            data.fill_(1.5)
            return data

    no_opt = {"opt_level": 0}
    verify_model(FillInPlace(), [([10, 10], "float32")], opt_config=no_opt)
def test_arange():
    """Round-trip a constant torch.arange.

    forward() takes no arguments, yet one input spec is still supplied.
    """

    class ArangeConst(Module):
        def forward(self):
            return torch.arange(0, 20, dtype=torch.int32)

    verify_model(ArangeConst(), [([10, 10], "float32")])
def test_empty():
    """Round-trip torch.empty with an explicit float32 dtype.

    forward() takes no arguments, yet one input spec is still supplied.
    """

    class EmptyConst(Module):
        def forward(self):
            return torch.empty((10, 10), dtype=torch.float32)

    verify_model(EmptyConst(), [([10, 10], "float32")])
def test_tensor():
    """Round-trip scalar torch.tensor constants, with and without an explicit dtype."""

    class TensorFloat32(Module):
        def forward(self):
            return torch.tensor(3, dtype=torch.float32)

    class TensorDefaultDtype(Module):
        def forward(self):
            return torch.tensor(3)

    input_info = [([10, 10], "float32")]
    for model_cls in (TensorFloat32, TensorDefaultDtype):
        verify_model(model_cls(), input_info)
def test_tril():
    """Round-trip lower-triangular masking, out-of-place and in-place."""

    class TrilOutOfPlace(Module):
        def forward(self, data):
            return torch.tril(data, 1)

    class TrilInPlace(Module):
        def forward(self, data):
            data.tril_(1)
            return data

    input_info = [([10, 10], "float32")]
    for model_cls in (TrilOutOfPlace, TrilInPlace):
        verify_model(model_cls(), input_info)
def test_triu():
    """Round-trip upper-triangular masking, out-of-place and in-place."""

    class TriuOutOfPlace(Module):
        def forward(self, data):
            return torch.triu(data, 1)

    class TriuInPlace(Module):
        def forward(self, data):
            data.triu_(1)
            return data

    input_info = [([10, 10], "float32")]
    for model_cls in (TriuOutOfPlace, TriuInPlace):
        verify_model(model_cls(), input_info)
def test_new_ones():
    """Round-trip ``Tensor.new_ones`` with MSC optimizations disabled (opt_level 0)."""

    class OnesLike(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    no_opt = {"opt_level": 0}
    verify_model(OnesLike(), [([1, 2, 3], "float32")], opt_config=no_opt)
def test_expand():
    """Round-trip Tensor.expand with explicit sizes and with -1 placeholders."""

    class ExpandExplicit(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    class ExpandKeepDims(Module):
        def forward(self, x):
            return x.expand(4, -1, -1, 4)

    input_info = [([1, 2, 3, 4], "float32")]
    for model_cls in (ExpandExplicit, ExpandKeepDims):
        verify_model(model_cls(), input_info)
def test_reduce():
    """Round-trip a torch.sum reduction over two (unordered) axes."""

    class SumAxes(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    verify_model(SumAxes(), [([1, 2, 3, 4], "float32")])
def test_datatype():
    """Round-trip dtype conversions: float(), half(), type(), and astype()."""

    class CastFloat(Module):
        def forward(self, x):
            return x.float()

    class CastHalf(Module):
        def forward(self, x):
            return x.half()

    class CastByType(Module):
        def forward(self, x):
            return x.type(torch.float32)

    class CastByDtypeAttr(Module):
        def forward(self, x):
            return x.type(x.getattr("dtype"))

    class CastAsType(Module):
        def forward(self, x):
            return x.astype(torch.float32)

    input_info = [([1, 2, 3, 4], "float32")]
    for model_cls in (CastFloat, CastHalf, CastByType, CastByDtypeAttr, CastAsType):
        verify_model(model_cls(), input_info)
def test_permute():
    """Round-trip Tensor.permute with a non-trivial axis order."""

    class PermuteAxes(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    verify_model(PermuteAxes(), [([1, 2, 3, 4], "float32")])
def test_reshape():
    """Round-trip Tensor.reshape from a rank-4 shape to 2x12."""

    class ReshapeTo2D(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    verify_model(ReshapeTo2D(), [([1, 2, 3, 4], "float32")])
def test_transpose():
    """Round-trip Tensor.transpose swapping axes 1 and 3."""

    class SwapAxes(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    verify_model(SwapAxes(), [([1, 2, 3, 4], "float32")])
def test_view():
    """Round-trip Tensor.view from a rank-4 shape to 2x12."""

    class ViewAs2D(Module):
        def forward(self, x):
            return x.view(2, 12)

    verify_model(ViewAs2D(), [([1, 2, 3, 4], "float32")])
def test_keep_params():
    """Round-trip a biased Conv2d whose weights are carried as model parameters."""

    class ParamConv(Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.conv(data)

    verify_model(ParamConv(), [([1, 3, 10, 10], "float32")])
def test_unwrap_unit_return_tuple():
    """Round-trip a model whose output is a one-element tuple."""

    class SingletonTuple(Module):
        def forward(self, x):
            return (x,)

    verify_model(SingletonTuple(), [([256, 256], "float32")])
def test_no_bind_return_tuple():
    """Round-trip a model that returns its two inputs unchanged as a tuple."""

    class PassThroughPair(Module):
        def forward(self, x, y):
            return (x, y)

    input_info = [([256, 256], "float32") for _ in range(2)]
    verify_model(PassThroughPair(), input_info)
def test_argmax():
    """Round-trip torch.argmax over the last axis, with and without keepdim."""

    class ArgmaxLast(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class ArgmaxLastKeepDim(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    for model_cls in (ArgmaxLast, ArgmaxLastKeepDim):
        verify_model(model_cls(), [([256, 256], "float32")])
def test_argmin():
    """Round-trip torch.argmin over all elements, with and without keepdim."""

    class ArgminFlat(Module):
        def forward(self, data):
            return torch.argmin(data)

    class ArgminFlatKeepDim(Module):
        def forward(self, data):
            return torch.argmin(data, keepdim=True)

    for model_cls in (ArgminFlat, ArgminFlatKeepDim):
        verify_model(model_cls(), [([256, 256], "float32")])
def test_to():
    """Round-trip Tensor.to with a dtype target and with a device string."""

    class ToFloat16(Module):
        def forward(self, data):
            return data.to(torch.float16)

    class ToCPU(Module):
        def forward(self, data):
            return data.to("cpu")

    for model_cls in (ToFloat16, ToCPU):
        verify_model(model_cls(), [([256, 256], "float32")])
def test_mean():
    """Round-trip Tensor.mean over the last axis, with and without keepdim."""

    class MeanLast(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanLastKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    for model_cls in (MeanLast, MeanLastKeepDim):
        verify_model(model_cls(), [([256, 256], "float32")])
def test_rsqrt():
    """Round-trip the reciprocal square root, torch.rsqrt."""

    class ReciprocalSqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    verify_model(ReciprocalSqrt(), [([256, 256], "float32")])
def test_neg():
    """Round-trip unary negation written with the ``-`` operator."""

    class Negate(Module):
        def forward(self, data):
            return -data

    verify_model(Negate(), [([256, 256], "float32")])
def test_max():
    """Round-trip the elementwise two-tensor form of torch.max."""

    class ElementwiseMax(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    input_info = [([256, 256], "float32") for _ in range(2)]
    verify_model(ElementwiseMax(), input_info)
def test_cat():
    """Round-trip torch.cat of three inputs, and of one input plus two constants."""

    class CatInputs(Module):
        def forward(self, data, data1, data2):
            return torch.cat((data, data1, data2), dim=1)

    class CatWithConsts(Module):
        def forward(self, data):
            const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            return torch.cat((data, const1, const2), dim=1)

    cases = [
        (CatInputs, [([1, 3, 10, 10], "float32") for _ in range(3)]),
        (CatWithConsts, [([1, 3, 10, 10], "float32")]),
    ]
    for model_cls, input_info in cases:
        verify_model(model_cls(), input_info)
def test_stack():
    """Round-trip torch.stack of three inputs, and of one input plus two constants."""

    class StackInputs(Module):
        def forward(self, data, data1, data2):
            return torch.stack((data, data1, data2), dim=0)

    class StackWithConsts(Module):
        def forward(self, data):
            const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            return torch.stack((data, const1, const2), dim=1)

    cases = [
        (StackInputs, [([1, 3, 10, 10], "float32") for _ in range(3)]),
        (StackWithConsts, [([1, 3, 10, 10], "float32")]),
    ]
    for model_cls, input_info in cases:
        verify_model(model_cls(), input_info)
def test_scatter():
    """Round-trip Tensor.scatter with a stored constant index and with an input index."""

    class ScatterConstIndex(Module):
        def __init__(self):
            super().__init__()
            # Random int64 indices in [0, 5), generated via msc_utils.
            self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5)

        def forward(self, data, src):
            return data.scatter(dim=0, index=self.index, src=src)

    class ScatterInputIndex(Module):
        def forward(self, data, index, src):
            return data.scatter(0, index, src)

    cases = [
        (ScatterConstIndex, [([20, 20], "float32"), ([2, 5], "float32")]),
        (
            ScatterInputIndex,
            [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")],
        ),
    ]
    for model_cls, input_info in cases:
        verify_model(model_cls(), input_info)
def test_masked_scatter():
    """Round-trip Tensor.masked_scatter with 1-D and 2-D random boolean masks."""

    class MaskedScatter(Module):
        def __init__(self, mask_shape):
            super().__init__()
            self.mask = msc_utils.random_data([mask_shape, "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    cases = [
        ((5,), [([5], "float32"), ([10], "float32")]),
        ((2, 5), [([2, 5], "float32"), ([3, 5], "float32")]),
    ]
    for mask_shape, input_info in cases:
        verify_model(MaskedScatter(mask_shape), input_info)
def test_attention():
    """Round-trip scaled_dot_product_attention: plain, causal, and with a float mask."""
    # pylint: disable=import-outside-toplevel
    import torch.nn.functional as F

    class PlainAttention(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data)

    class CausalAttention(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

    class MaskedAttention(Module):
        def forward(self, q_data, k_data, v_data, mask):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)

    qkv_info = [([32, 8, 128, 64], "float32") for _ in range(3)]
    mask_info = ([32, 8, 128, 128], "float32")
    cases = [
        (PlainAttention, qkv_info),
        (CausalAttention, qkv_info),
        (MaskedAttention, qkv_info + [mask_info]),
    ]
    for model_cls, input_info in cases:
        verify_model(model_cls(), input_info)
if __name__ == "__main__":
    # Allow running this test file directly by handing control to tvm.testing.main().
    tvm.testing.main()