| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Test translate from relax.""" |
| |
| import torch |
| from torch.nn import Module |
| |
| import numpy as np |
| |
| import tvm.testing |
| from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate |
| |
| from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen |
| from tvm.contrib.msc.core.frontend import translate as core_translate |
| from tvm.contrib.msc.core.utils.namespace import MSCFramework |
| from tvm.contrib.msc.core import utils as msc_utils |
| |
| |
def verify_model(torch_model, input_info, opt_config=None):
    """Translate a torch module to relax twice and compare execution results.

    The module is translated directly with the torch frontend
    (``as_msc=False``) and also round-tripped through the MSC graph
    (``core_translate.from_relax`` followed by ``tvm_codegen.to_relax``).
    Both relax modules are compiled for llvm/cpu, run on identical random
    inputs, and their outputs must match element-wise.

    Parameters
    ----------
    torch_model: torch.nn.Module
        The torch module to verify.
    input_info: list
        ``(shape, dtype)`` pairs describing the module inputs.
    opt_config: dict, optional
        Optimization config forwarded to ``core_translate.from_relax``.
    """

    orig_mod, _ = torch_translate.from_torch(torch_model, input_info, as_msc=False)
    target = "llvm"
    dev = tvm.cpu()
    args = [msc_utils.random_data(i, MSCFramework.TVM) for i in input_info]

    def _tvm_runtime_to_np(obj):
        """Recursively convert tvm runtime objects to numpy-compatible values."""
        if isinstance(obj, tvm.runtime.Tensor):
            return obj.numpy()
        if isinstance(obj, tvm.runtime.ShapeTuple):
            return np.array(obj, dtype="int64")
        if isinstance(obj, (list, tvm.ir.container.Array)):
            return [_tvm_runtime_to_np(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(_tvm_runtime_to_np(item) for item in obj)
        return obj

    def _run_relax(relax_mod):
        """Legalize, compile and execute a relax module on the shared inputs."""
        relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)
        relax_exec = tvm.compile(relax_mod, target)
        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
        res = vm_runner["main"](*args)
        return _tvm_runtime_to_np(res)

    rt_mod = tvm_codegen.to_relax(
        *core_translate.from_relax(orig_mod, opt_config=opt_config),
        codegen_config={"explicit_name": False},
    )

    orig_output = _run_relax(orig_mod)
    rt_output = _run_relax(rt_mod)
    if not isinstance(orig_output, (list, tuple)):
        orig_output = [orig_output]
    if not isinstance(rt_output, (list, tuple)):
        rt_output = [rt_output]
    # zip would silently truncate on a length mismatch; fail loudly instead so
    # a translator that drops or adds outputs cannot pass unnoticed.
    assert len(orig_output) == len(rt_output), "Output num mismatch {} vs {}".format(
        len(orig_output), len(rt_output)
    )
    for o_out, r_out in zip(orig_output, rt_output):
        tvm.testing.assert_allclose(o_out, r_out)
| |
| |
def test_conv1d():
    """Check the relax translation of 1d convolution, with and without bias."""

    class Conv1D1(Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Conv1d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.layer(data)

    class Conv1D2(Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Conv1d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.layer(data)

    for model in (Conv1D1(), Conv1D2()):
        verify_model(model, [([1, 3, 10], "float32")])
| |
| |
def test_conv2d():
    """Check the relax translation of 2d convolution, with and without bias."""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.layer(data)

    class Conv2D2(Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Conv2d(3, 6, 7, bias=False)

        def forward(self, data):
            return self.layer(data)

    for model in (Conv2D1(), Conv2D2()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_linear():
    """Check the relax translation of linear layers and matmul."""

    class Dense1(Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(10, 7, bias=True)

        def forward(self, data):
            return self.proj(data)

    class Dense2(Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(10, 7, bias=False)

        def forward(self, data):
            return self.proj(data)

    class MatMul1(Module):
        def forward(self, x, y):
            return torch.matmul(x, y)

    dense_info = [([1, 3, 10, 10], "float32")]
    verify_model(Dense1(), dense_info)
    verify_model(Dense2(), dense_info)
    verify_model(MatMul1(), [([10, 10], "float32"), ([10, 10], "float32")])
| |
| |
def test_bmm():
    """Check the relax translation of batched matrix multiply."""

    class BMM(Module):
        def forward(self, x, y):
            return torch.bmm(x, y)

    verify_model(BMM(), [((4, 128, 256), "float32"), ((4, 256, 512), "float32")])
| |
| |
def test_baddbmm():
    """Check the relax translation of baddbmm with default and custom scales."""

    class BAddBMM1(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y)

    class BAddBMM2(Module):
        def forward(self, c, x, y):
            return torch.baddbmm(c, x, y, alpha=2, beta=0)

    shapes = [
        ((4, 128, 512), "float32"),
        ((4, 128, 256), "float32"),
        ((4, 256, 512), "float32"),
    ]
    for model in (BAddBMM1(), BAddBMM2()):
        verify_model(model, shapes)
| |
| |
def test_relu():
    """Check the relax translation of relu (module and functional forms)."""

    class ReLU(Module):
        def __init__(self):
            super().__init__()
            self.act = torch.nn.ReLU()

        def forward(self, data):
            return self.act(data)

    class ReLU1(Module):
        def forward(self, data):
            return torch.nn.functional.relu(data)

    for model in (ReLU(), ReLU1()):
        verify_model(model, [([10, 10], "float32")])
| |
| |
def test_relu6():
    """Check the relax translation of relu6."""

    class ReLU6(Module):
        def __init__(self):
            super().__init__()
            self.act = torch.nn.ReLU6()

        def forward(self, data):
            return self.act(data)

    verify_model(ReLU6(), [([10, 10], "float32")])
| |
| |
def test_maxpool2d():
    """Check the relax translation of maxpool2d with various configs."""

    class MaxPool2d(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.MaxPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.op(data)

    class MaxPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3])

        def forward(self, data):
            return self.op(data)

    class MaxPool2d3(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2)

        def forward(self, data):
            return self.op(data)

    for model in (MaxPool2d(), MaxPool2d2(), MaxPool2d3()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_avgpool2d():
    """Check the relax translation of avgpool2d with various configs."""

    class AvgPool2d(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.AvgPool2d(kernel_size=[1, 1])

        def forward(self, data):
            return self.op(data)

    class AvgPool2d2(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True)

        def forward(self, data):
            return self.op(data)

    for model in (AvgPool2d(), AvgPool2d2()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_adaptive_avgpool2d():
    """Check the relax translation of adaptive_avgpool2d."""

    class AdaptiveAvgPool2d0(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.AdaptiveAvgPool2d([10, 10])

        def forward(self, data):
            return self.op(data)

    verify_model(AdaptiveAvgPool2d0(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_flatten():
    """Check the relax translation of flatten (wrapped and bare module)."""

    class Flatten(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.Flatten(2, -1)

        def forward(self, data):
            return self.op(data)

    ishape = [([1, 3, 10, 10], "float32")]
    verify_model(Flatten(), ishape)
    verify_model(torch.nn.Flatten(2, -1), ishape)
| |
| |
def test_batchnorm2d():
    """Check the relax translation of batchnorm2d."""

    class BatchNorm2d(Module):
        def __init__(self):
            super().__init__()
            self.norm = torch.nn.BatchNorm2d(3)

        def forward(self, data):
            return self.norm(data)

    verify_model(BatchNorm2d(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_embedding():
    """Check the relax translation of embedding for 1d and 2d indices."""

    class Embedding(Module):
        def __init__(self):
            super().__init__()
            self.table = torch.nn.Embedding(10, 3)

        def forward(self, data):
            return self.table(data)

    for shape in ([4], [4, 5]):
        verify_model(Embedding(), [(shape, "int64")])
| |
| |
def test_dropout():
    """Check the relax translation of dropout (module and functional forms)."""

    class Dropout1(Module):
        def __init__(self):
            super().__init__()
            self.drop = torch.nn.Dropout(0.5)

        def forward(self, data):
            return self.drop(data)

    class Dropout2(Module):
        def forward(self, data):
            return torch.dropout(data, 0.5, train=True)

    for model in (Dropout1(), Dropout2()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_layernorm():
    """Check the relax translation of layernorm."""

    class LayerNorm(Module):
        def __init__(self):
            super().__init__()
            self.norm = torch.nn.LayerNorm((10, 10))

        def forward(self, data):
            return self.norm(data)

    verify_model(LayerNorm(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_functional_layernorm():
    """Check the relax translation of functional layer_norm with parameters."""

    class LayerNorm(Module):
        def __init__(self, shape):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(shape))
            self.bias = torch.nn.Parameter(torch.zeros(shape))

        def forward(self, data):
            return torch.nn.functional.layer_norm(
                data, self.weight.shape, self.weight, self.bias, 1e-5
            )

    verify_model(LayerNorm((10, 10)), [([1, 3, 10, 10], "float32")])
| |
| |
def test_cross_entropy():
    """Check the relax translation of CrossEntropyLoss variants."""

    class CrossEntropy1(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss()

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy2(Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones((2,)))
            self.loss = torch.nn.CrossEntropyLoss(weight=self.weight)

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    class CrossEntropy3(Module):
        def __init__(self):
            super().__init__()
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=1, reduction="sum")

        def forward(self, logits, targets):
            return self.loss(logits, targets)

    for model in (CrossEntropy1(), CrossEntropy2(), CrossEntropy3()):
        verify_model(model, [([3, 2], "float32"), ([3], "int32")])
| |
| |
def test_functional_cross_entropy():
    """Check the relax translation of functional cross_entropy."""

    class CrossEntropy(Module):
        def forward(self, logits, targets):
            return torch.nn.functional.cross_entropy(logits, targets)

    verify_model(CrossEntropy(), [([3, 10], "float32"), ([3], "int32")])
| |
| |
def test_silu():
    """Check the relax translation of silu (module and functional forms)."""

    class SiLU(Module):
        def __init__(self):
            super().__init__()
            self.act = torch.nn.SiLU()

        def forward(self, data):
            return self.act(data)

    class SiLU2(Module):
        def forward(self, data):
            return torch.nn.functional.silu(data)

    for model in (SiLU(), SiLU2()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_groupnorm():
    """Check the relax translation of groupnorm."""

    class GroupNorm(Module):
        def __init__(self):
            super().__init__()
            self.norm = torch.nn.GroupNorm(3, 3)

        def forward(self, data):
            return self.norm(data)

    verify_model(GroupNorm(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_softmax():
    """Check the relax translation of softmax."""

    class Softmax(Module):
        def __init__(self):
            super().__init__()
            self.op = torch.nn.Softmax(dim=1)

        def forward(self, data):
            return self.op(data)

    verify_model(Softmax(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_binary():
    """Check the relax translation of elementwise binary operators.

    Every operator is exercised twice: once as tensor-tensor and once as
    tensor-scalar.
    """

    two_inputs = [([1, 3, 10, 10], "float32"), ([1, 3, 10, 10], "float32")]
    one_input = [([1, 3, 10, 10], "float32")]

    class Add1(Module):
        def forward(self, lhs, rhs):
            return lhs + rhs

    class Add2(Module):
        def forward(self, lhs):
            return lhs + 1.0

    class Sub1(Module):
        def forward(self, lhs, rhs):
            return lhs - rhs

    class Sub2(Module):
        def forward(self, lhs):
            return lhs - 1.0

    class Mul1(Module):
        def forward(self, lhs, rhs):
            return lhs * rhs

    class Mul2(Module):
        def forward(self, lhs):
            return lhs * 1.0

    class TrueDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs / rhs

    class TrueDiv2(Module):
        def forward(self, lhs):
            return lhs / 1.0

    class FloorDiv1(Module):
        def forward(self, lhs, rhs):
            return lhs // rhs

    class FloorDiv2(Module):
        def forward(self, lhs):
            return lhs // 1.0

    class Power1(Module):
        def forward(self, lhs, rhs):
            return lhs**rhs

    class Power2(Module):
        def forward(self, lhs):
            return lhs**1.0

    class LT1(Module):
        def forward(self, lhs, rhs):
            return lhs < rhs

    class LT2(Module):
        def forward(self, lhs):
            return lhs < 1.0

    op_pairs = [
        (Add1, Add2),
        (Sub1, Sub2),
        (Mul1, Mul2),
        (TrueDiv1, TrueDiv2),
        (FloorDiv1, FloorDiv2),
        (Power1, Power2),
        (LT1, LT2),
    ]
    for tensor_case, scalar_case in op_pairs:
        verify_model(tensor_case(), two_inputs)
        verify_model(scalar_case(), one_input)
| |
| |
def test_size():
    """Check the relax translation of Tensor.size()."""

    class Size(Module):
        def forward(self, data):
            return data.size()

    verify_model(Size(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_squeeze():
    """Check the relax translation of squeeze with and without a dim."""

    class Squeeze1(Module):
        def forward(self, data):
            return data.squeeze(1)

    class Squeeze2(Module):
        def forward(self, data):
            return data.squeeze()

    for model in (Squeeze1(), Squeeze2()):
        verify_model(model, [([3, 1, 4, 1], "float32")])
| |
| |
def test_unsqueeze():
    """Check the relax translation of unsqueeze with positive/negative dims."""

    class Unsqueeze1(Module):
        def forward(self, data):
            return data.unsqueeze(1)

    class Unsqueeze2(Module):
        def forward(self, data):
            return data.unsqueeze(-1)

    for model in (Unsqueeze1(), Unsqueeze2()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_getattr():
    """Check the relax translation of attribute access (Tensor.shape)."""

    class GetAttr1(Module):
        def forward(self, data):
            return data.shape

    verify_model(GetAttr1(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_getitem():
    """Check the relax translation of tensor indexing/slicing."""

    class Slice1(Module):
        def forward(self, x):
            return x[0, 1::2, :, :3]

    class Slice2(Module):
        def forward(self, x):
            return x[:, None, None, :, None]

    verify_model(Slice1(), [([1, 3, 10, 10], "float32")])
    verify_model(Slice2(), [([8, 16], "float32")])
| |
| |
def test_unary():
    """Check the relax translation of elementwise unary operators."""

    class Sin(Module):
        def forward(self, data):
            return torch.sin(data)

    class Cos(Module):
        def forward(self, data):
            return torch.cos(data)

    class Exp(Module):
        def forward(self, data):
            return torch.exp(data)

    class Sqrt(Module):
        def forward(self, data):
            return torch.sqrt(data)

    class Sigmoid(Module):
        def forward(self, data):
            return torch.sigmoid(data)

    class Round(Module):
        def forward(self, data):
            return torch.round(data)

    for op_cls in (Sin, Cos, Exp, Sqrt, Sigmoid, Round):
        verify_model(op_cls(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_gelu():
    """Check the relax translation of functional gelu."""

    class Gelu(Module):
        def forward(self, data):
            return torch.nn.functional.gelu(data)

    verify_model(Gelu(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_tanh():
    """Check the relax translation of tanh."""

    class Tanh(Module):
        def forward(self, data):
            return torch.tanh(data)

    verify_model(Tanh(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_clamp():
    """Check the relax translation of clamp with both bounds set."""

    class Clamp(Module):
        def forward(self, data):
            return torch.clamp(data, min=0.1, max=0.5)

    verify_model(Clamp(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_interpolate():
    """Check the relax translation of functional interpolate."""

    class Interpolate(Module):
        def forward(self, data):
            return torch.nn.functional.interpolate(data, (5, 5))

    verify_model(Interpolate(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_addmm():
    """Check the relax translation of addmm."""

    class Addmm(Module):
        def forward(self, x_1, x_2, x_3):
            return torch.addmm(x_1, x_2, x_3)

    square = ([10, 10], "float32")
    verify_model(Addmm(), [square, square, square])
| |
| |
def test_split():
    """Check the relax translation of split by size and by sections."""

    class Split1(Module):
        def forward(self, data):
            return torch.split(data, 1, dim=1)

    class Split2(Module):
        def forward(self, data):
            return torch.split(data, [1, 2], dim=1)

    for model in (Split1(), Split2()):
        verify_model(model, [([1, 3, 10, 10], "float32")])
| |
| |
def test_unbind():
    """Check the relax translation of unbind on default and explicit dims."""

    class Unbind1(Module):
        def forward(self, data):
            return torch.unbind(data)

    class Unbind2(Module):
        def forward(self, data):
            return torch.unbind(data, dim=1)

    for model in (Unbind1(), Unbind2()):
        verify_model(model, [([3, 3, 10, 10], "float32")])
| |
| |
def test_cumsum():
    """Check the relax translation of cumsum with an output dtype."""

    class Cumsum(Module):
        def forward(self, data):
            return torch.cumsum(data, dim=1, dtype=torch.int32)

    verify_model(Cumsum(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_chunk():
    """Check the relax translation of chunk."""

    class Chunk(Module):
        def forward(self, data):
            return torch.chunk(data, 3, dim=1)

    verify_model(Chunk(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_inplace_fill():
    """Check the relax translation of in-place fill_ (opt_level 0)."""

    class InplaceFill(Module):
        def forward(self, data):
            data.fill_(1.5)
            return data

    verify_model(InplaceFill(), [([10, 10], "float32")], opt_config={"opt_level": 0})
| |
| |
def test_arange():
    """Check the relax translation of arange producing a constant."""

    class Arange(Module):
        def forward(self):
            return torch.arange(0, 20, dtype=torch.int32)

    verify_model(Arange(), [([10, 10], "float32")])
| |
| |
def test_empty():
    """Check the relax translation of torch.empty."""

    class Empty(Module):
        def forward(self):
            return torch.empty((10, 10), dtype=torch.float32)

    verify_model(Empty(), [([10, 10], "float32")])
| |
| |
def test_tensor():
    """Check the relax translation of torch.tensor constants."""

    class Empty1(Module):
        def forward(self):
            return torch.tensor(3, dtype=torch.float32)

    class Empty2(Module):
        def forward(self):
            return torch.tensor(3)

    for model in (Empty1(), Empty2()):
        verify_model(model, [([10, 10], "float32")])
| |
| |
def test_tril():
    """Check the relax translation of tril and in-place tril_."""

    class Tril(Module):
        def forward(self, data):
            return torch.tril(data, 1)

    class InplaceTril(Module):
        def forward(self, data):
            data.tril_(1)
            return data

    for model in (Tril(), InplaceTril()):
        verify_model(model, [([10, 10], "float32")])
| |
| |
def test_triu():
    """Check the relax translation of triu and in-place triu_."""

    class Triu(Module):
        def forward(self, data):
            return torch.triu(data, 1)

    class InplaceTriu(Module):
        def forward(self, data):
            data.triu_(1)
            return data

    for model in (Triu(), InplaceTriu()):
        verify_model(model, [([10, 10], "float32")])
| |
| |
def test_new_ones():
    """Check the relax translation of new_ones (opt_level 0)."""

    class NewOnes(Module):
        def forward(self, x):
            return x.new_ones(1, 2, 3)

    verify_model(NewOnes(), [([1, 2, 3], "float32")], opt_config={"opt_level": 0})
| |
| |
def test_expand():
    """Check the relax translation of expand with explicit and -1 dims."""

    class Expand1(Module):
        def forward(self, x):
            return x.expand(4, 2, 3, 4)

    class Expand2(Module):
        def forward(self, x):
            return x.expand(4, -1, -1, 4)

    for model in (Expand1(), Expand2()):
        verify_model(model, [([1, 2, 3, 4], "float32")])
| |
| |
def test_reduce():
    """Check the relax translation of reductions (sum over multiple dims)."""

    class Sum(Module):
        def forward(self, x):
            return torch.sum(x, (2, 1))

    verify_model(Sum(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_datatype():
    """Check the relax translation of dtype casts (float/half/type/astype)."""

    class ToFloat(Module):
        def forward(self, x):
            return x.float()

    class ToHalf(Module):
        def forward(self, x):
            return x.half()

    class Type(Module):
        def forward(self, x):
            return x.type(torch.float32)

    class TypeFromAttr(Module):
        def forward(self, x):
            return x.type(x.getattr("dtype"))

    class AsType(Module):
        def forward(self, x):
            return x.astype(torch.float32)

    for cast_cls in (ToFloat, ToHalf, Type, TypeFromAttr, AsType):
        verify_model(cast_cls(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_permute():
    """Check the relax translation of permute."""

    class Permute(Module):
        def forward(self, x):
            return x.permute(0, 3, 2, 1)

    verify_model(Permute(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_reshape():
    """Check the relax translation of reshape."""

    class Reshape(Module):
        def forward(self, x):
            return x.reshape(2, 12)

    verify_model(Reshape(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_transpose():
    """Check the relax translation of transpose."""

    class Transpose(Module):
        def forward(self, x):
            return x.transpose(1, 3)

    verify_model(Transpose(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_view():
    """Check the relax translation of view."""

    class View(Module):
        def forward(self, x):
            return x.view(2, 12)

    verify_model(View(), [([1, 2, 3, 4], "float32")])
| |
| |
def test_keep_params():
    """Check the relax translation keeps module parameters intact."""

    class Conv2D1(Module):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Conv2d(3, 6, 7, bias=True)

        def forward(self, data):
            return self.layer(data)

    verify_model(Conv2D1(), [([1, 3, 10, 10], "float32")])
| |
| |
def test_unwrap_unit_return_tuple():
    """Check the relax translation unwraps single-element return tuples."""

    class Identity(Module):
        def forward(self, x):
            return (x,)

    verify_model(Identity(), [([256, 256], "float32")])
| |
| |
def test_no_bind_return_tuple():
    """Check the relax translation of multi-output tuple returns."""

    class Identity(Module):
        def forward(self, x, y):
            return (x, y)

    verify_model(Identity(), [([256, 256], "float32"), ([256, 256], "float32")])
| |
| |
def test_argmax():
    """Check the relax translation of argmax with and without keepdim."""

    class Argmax1(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1)

    class Argmax2(Module):
        def forward(self, data):
            return torch.argmax(data, dim=-1, keepdim=True)

    for model in (Argmax1(), Argmax2()):
        verify_model(model, [([256, 256], "float32")])
| |
| |
def test_argmin():
    """Check the relax translation of argmin with and without keepdim."""

    class Argmin1(Module):
        def forward(self, data):
            return torch.argmin(data)

    class Argmin2(Module):
        def forward(self, data):
            return torch.argmin(data, keepdim=True)

    for model in (Argmin1(), Argmin2()):
        verify_model(model, [([256, 256], "float32")])
| |
| |
def test_to():
    """Check the relax translation of Tensor.to for dtype and device."""

    class To1(Module):
        def forward(self, data):
            return data.to(torch.float16)

    class To2(Module):
        def forward(self, data):
            return data.to("cpu")

    for model in (To1(), To2()):
        verify_model(model, [([256, 256], "float32")])
| |
| |
def test_mean():
    """Check the relax translation of mean with and without keepdim."""

    class Mean(Module):
        def forward(self, data):
            return data.mean(-1)

    class MeanKeepDim(Module):
        def forward(self, data):
            return data.mean(-1, keepdim=True)

    for model in (Mean(), MeanKeepDim()):
        verify_model(model, [([256, 256], "float32")])
| |
| |
def test_rsqrt():
    """Check the relax translation of rsqrt."""

    class Rsqrt(Module):
        def forward(self, data):
            return torch.rsqrt(data)

    verify_model(Rsqrt(), [([256, 256], "float32")])
| |
| |
def test_neg():
    """Check the relax translation of unary negation."""

    class Neg(Module):
        def forward(self, data):
            return -data

    verify_model(Neg(), [([256, 256], "float32")])
| |
| |
def test_max():
    """Check the relax translation of elementwise max over two tensors."""

    class Max(Module):
        def forward(self, x, y):
            return torch.max(x, y)

    verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
| |
| |
def test_cat():
    """Check the relax translation of cat with inputs and constants."""

    class Cat1(Module):
        def forward(self, data, data1, data2):
            return torch.cat((data, data1, data2), dim=1)

    class Cat2(Module):
        def forward(self, data):
            const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            return torch.cat((data, const1, const2), dim=1)

    tensor = ([1, 3, 10, 10], "float32")
    verify_model(Cat1(), [tensor, tensor, tensor])
    verify_model(Cat2(), [tensor])
| |
| |
def test_stack():
    """Check the relax translation of stack with inputs and constants."""

    class Stack1(Module):
        def forward(self, data, data1, data2):
            return torch.stack((data, data1, data2), dim=0)

    class Stack2(Module):
        def forward(self, data):
            const1 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            const2 = torch.ones((1, 3, 10, 10), dtype=torch.float32)
            return torch.stack((data, const1, const2), dim=1)

    tensor = ([1, 3, 10, 10], "float32")
    verify_model(Stack1(), [tensor, tensor, tensor])
    verify_model(Stack2(), [tensor])
| |
| |
def test_scatter():
    """Check the relax translation of scatter with constant and input index."""

    class Scatter1(Module):
        def __init__(self):
            super().__init__()
            # index baked in as a module constant
            self.index = msc_utils.random_data([(2, 5), "int64"], MSCFramework.TORCH, max_val=5)

        def forward(self, data, src):
            return data.scatter(dim=0, index=self.index, src=src)

    class Scatter2(Module):
        def forward(self, data, index, src):
            return data.scatter(0, index, src)

    verify_model(Scatter1(), [([20, 20], "float32"), ([2, 5], "float32")])
    verify_model(Scatter2(), [([20, 20], "float32"), ([2, 5], "int64"), ([2, 5], "float32")])
| |
| |
def test_masked_scatter():
    """Check the relax translation of masked_scatter for 1d and 2d masks."""

    class MaskedScatter1(Module):
        def __init__(self):
            super().__init__()
            # mask baked in as a module constant
            self.mask = msc_utils.random_data([(5,), "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    class MaskedScatter2(Module):
        def __init__(self):
            super().__init__()
            self.mask = msc_utils.random_data([(2, 5), "bool"], MSCFramework.TORCH)

        def forward(self, data, src):
            return data.masked_scatter(self.mask, src)

    verify_model(MaskedScatter1(), [([5], "float32"), ([10], "float32")])
    verify_model(MaskedScatter2(), [([2, 5], "float32"), ([3, 5], "float32")])
| |
| |
def test_attention():
    """Check the relax translation of scaled_dot_product_attention."""

    # pylint: disable=import-outside-toplevel
    import torch.nn.functional as F

    class Attention1(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data)

    class Attention2(Module):
        def forward(self, q_data, k_data, v_data):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, is_causal=True)

    class Attention3(Module):
        def forward(self, q_data, k_data, v_data, mask):
            return F.scaled_dot_product_attention(q_data, k_data, v_data, mask)

    qkv = ([32, 8, 128, 64], "float32")
    verify_model(Attention1(), [qkv, qkv, qkv])
    verify_model(Attention2(), [qkv, qkv, qkv])
    verify_model(Attention3(), [qkv, qkv, qkv, ([32, 8, 128, 128], "float32")])
| |
| |
if __name__ == "__main__":
    # Run every test in this file through tvm's pytest-based test driver.
    tvm.testing.main()