| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import operator |
| import pytest |
| import torch |
| import numpy as np |
| from torch import nn |
| from torch.nn import Module |
| from torch.export import export |
| |
| import tvm |
| from tvm import relax |
| import tvm.testing |
| from tvm.script import ir as I |
| from tvm.script import relax as R |
| from tvm.script import tir as T |
| from tvm.relax.frontend.torch import from_exported_program |
| |
| |
| def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None): |
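| """Export torch_model with torch.export (optionally with dynamic_shapes), |
| convert it to Relax via from_exported_program, bind `binding` as constant |
| parameters of `expected`, and assert the two IRModules are structurally equal.""" |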
| exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes) |
| mod = from_exported_program(exported_program) |
| |
| binding = {k: tvm.runtime.tensor(v) for k, v in binding.items()} |
| expected = relax.transform.BindParams("main", binding)(expected) |
| tvm.ir.assert_structural_equal(mod, expected) |
| |
| |
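| # (PyTorch op, Relax op) pairs for elementwise unary ops that keep the input dtype. |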
| operator_basic_unary = [ |
| (torch.abs, R.abs), |
| (torch.acos, R.acos), |
| (torch.acosh, R.acosh), |
| (torch.asin, R.asin), |
| (torch.asinh, R.asinh), |
| (torch.atan, R.atan), |
| (torch.atanh, R.atanh), |
| (torch.bitwise_not, R.bitwise_not), |
| (torch.ceil, R.ceil), |
| (torch.cos, R.cos), |
| (torch.cosh, R.cosh), |
| (torch.erf, R.erf), |
| (torch.exp, R.exp), |
| (torch.floor, R.floor), |
| (torch.ops.aten.gelu, R.nn.gelu), |
| (torch.log, R.log), |
| (torch.neg, R.negative), |
| (torch.relu, R.nn.relu), |
| (torch.relu_, R.nn.relu), |
| (torch.round, R.round), |
| (torch.rsqrt, R.rsqrt), |
| (torch.selu, R.nn.selu), |
| (torch.sigmoid, R.sigmoid), |
| (torch.ops.aten.silu, R.nn.silu), |
| (torch.ops.aten.silu_, R.nn.silu), |
| (torch.sin, R.sin), |
| (torch.sinh, R.sinh), |
| (torch.sign, R.sign), |
| (torch.sqrt, R.sqrt), |
| (torch.square, R.square), |
| (torch.tan, R.tan), |
| (torch.tanh, R.tanh), |
| (torch.trunc, R.trunc), |
| ] |
| |
| |
| @pytest.mark.parametrize("pytorch_op, relax_op", operator_basic_unary) |
| def test_basic_unary_ops(pytorch_op, relax_op): |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| class UnaryOp(Module): |
| def forward(self, input): |
| return pytorch_op(input) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = relax_op(input_1) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(UnaryOp(), example_args, {}, expected) |
| |
| |
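| # Unary predicate ops whose results are bool tensors. |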
| operator_bool_unary = [ |
| (torch.isfinite, R.isfinite), |
| (torch.isinf, R.isinf), |
| (torch.isnan, R.isnan), |
| ] |
| |
| |
| @pytest.mark.parametrize("pytorch_op, relax_op", operator_bool_unary) |
| def test_bool_unary_ops(pytorch_op, relax_op): |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| class UnaryOp(Module): |
| def forward(self, input): |
| return pytorch_op(input) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="bool") = relax_op(input_1) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(UnaryOp(), example_args, {}, expected) |
| |
| |
| def test_extended_unary_ops(): |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| # celu |
| class Celu1(Module): |
| def __init__(self): |
| super().__init__() |
| self.celu = torch.nn.CELU() |
| |
| def forward(self, input): |
| return self.celu(input) |
| |
| class Celu2(Module): |
| def forward(self, input): |
| return torch.nn.functional.celu(input) |
| |
| # alpha * min(0, exp(x / alpha) - 1) + max(0, x) |
| @tvm.script.ir_module |
| class expected_celu: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) |
| lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( |
| lv, R.const(1.0, "float32") |
| ) |
| lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( |
| lv_div, R.const(1.0, "float32") |
| ) |
| lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum( |
| R.const(0.0, "float32"), lv_sub |
| ) |
| lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( |
| R.const(1.0, "float32"), lv_min |
| ) |
| lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) |
| lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_celu,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Celu1(), example_args, {}, expected_celu) |
| verify_model(Celu2(), example_args, {}, expected_celu) |
| |
| # clamp |
| class Clamp(Module): |
| def forward(self, input): |
| return torch.clamp(input, min=0.1, max=0.5) |
| |
| @tvm.script.ir_module |
| class expected_clamp: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( |
| input, |
| R.prim_value(T.float64(0.10000000000000001)), |
| R.prim_value(T.float64(0.5)), |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Clamp(), example_args, {}, expected_clamp) |
| |
| class ClampMinOnly(Module): |
| def forward(self, input): |
| return torch.clamp(input, min=0.5, max=None) |
| |
| @tvm.script.ir_module |
| class expected_clamp_min_only: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( |
| input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf")) |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only) |
| |
| class ClampTensors(Module): |
| def forward(self, input): |
| return torch.clamp(input, min=input, max=input) |
| |
| @tvm.script.ir_module |
| class expected_clamp_tensors: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( |
| input, R.shape([1, 3, 10, 10]) |
| ) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv) |
| lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to( |
| input, R.shape([1, 3, 10, 10]) |
| ) |
| lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2) |
| lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( |
| lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf")) |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,) |
| R.output(gv) |
| return gv |
| |
| verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors) |
| |
| # dropout (mapped to a no-op: the expected module just returns its input) |
| |
| class Dropout1(Module): |
| def __init__(self): |
| super().__init__() |
| self.dropout = torch.nn.Dropout(0.5) |
| |
| def forward(self, input): |
| return self.dropout(input) |
| |
| class Dropout2(Module): |
| def forward(self, input): |
| return torch.dropout(input, 0.5, train=True) |
| |
| class Dropout3(Module): |
| def forward(self, input): |
| return torch.ops.aten.dropout_(input, 0.5, train=True) |
| |
| @tvm.script.ir_module |
| class expected_dropout: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Dropout1(), example_args, {}, expected_dropout) |
| verify_model(Dropout2(), example_args, {}, expected_dropout) |
| verify_model(Dropout3(), example_args, {}, expected_dropout) |
| |
| # elu |
| class Elu(Module): |
| def __init__(self): |
| super().__init__() |
| self.elu = torch.nn.ELU() |
| |
| def forward(self, input): |
| return self.elu(input) |
| |
| class Elu2(Module): |
| def forward(self, input): |
| return torch.nn.functional.elu(input) |
| |
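| # elu(x) = -1 * relu(1 - exp(x)) + relu(x) |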
| @tvm.script.ir_module |
| class expected_elu: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1) |
| lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( |
| R.const(1.0, dtype="float32"), lv_exp |
| ) |
| lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu( |
| lv_one_minus_exp |
| ) |
| lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply( |
| R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp |
| ) |
| lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) |
| lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv_elu,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Elu(), example_args, {}, expected_elu) |
| verify_model(Elu2(), example_args, {}, expected_elu) |
| |
| # hardsigmoid |
| class Hardsigmoid(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.hs = torch.nn.Hardsigmoid() |
| |
| def forward(self, input): |
| return self.hs(input) |
| |
| class Hardsigmoid2(torch.nn.Module): |
| def forward(self, input): |
| return torch.nn.functional.hardsigmoid(input) |
| |
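| # hardsigmoid(x) = clip(x + 3, 0, 6) / 6 |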
| @tvm.script.ir_module |
| class expected_hardsigmoid: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) |
| lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( |
| lv1, R.const(6, "float32") |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv2,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Hardsigmoid(), example_args, {}, expected_hardsigmoid) |
| verify_model(Hardsigmoid2(), example_args, {}, expected_hardsigmoid) |
| |
| # hardswish |
| class Hardswish(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.hs = torch.nn.Hardswish() |
| |
| def forward(self, input): |
| return self.hs(input) |
| |
| class Hardswish2(torch.nn.Module): |
| def forward(self, input): |
| return torch.nn.functional.hardswish(input) |
| |
| class Hardswish3(torch.nn.Module): |
| def forward(self, input): |
| return torch.ops.aten.hardswish_(input) |
| |
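| # hardswish(x) = x * clip(x + 3, 0, 6) / 6 |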
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(inp_0, R.const(3, "float32")) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(lv, 0, 6) |
| lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( |
| lv1, R.const(6, "float32") |
| ) |
| lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(inp_0, lv2) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Hardswish(), example_args, {}, expected1) |
| verify_model(Hardswish2(), example_args, {}, expected1) |
| verify_model(Hardswish3(), example_args, {}, expected1) |
| |
| # log2 |
| class Log2(Module): |
| def forward(self, x): |
| return torch.log2(x) |
| |
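| # log2(x) = log(x) / log(2) |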
| @tvm.script.ir_module |
| class Expected_log2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(inp_0) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( |
| lv, R.const(0.69314718246459961, "float32") |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Log2(), example_args, {}, Expected_log2) |
| |
| # log10 |
| class Log10(Module): |
| def forward(self, x): |
| return torch.log10(x) |
| |
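| # log10(x) = log(x) / log(10) |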
| @tvm.script.ir_module |
| class Expected_log10: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(inp_0) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( |
| lv, R.const(2.302585092994046, "float32") |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Log10(), example_args, {}, Expected_log10) |
| |
| # log1p |
| class Log1p(Module): |
| def forward(self, x): |
| return torch.log1p(x) |
| |
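| # log1p(x) = log(1 + x) |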
| @tvm.script.ir_module |
| class Expected_log1p: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log( |
| R.add(inp_0, R.const(1, "float32")) |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Log1p(), example_args, {}, Expected_log1p) |
| |
| # reciprocal |
| class Reciprocal(Module): |
| def forward(self, input): |
| return torch.reciprocal(input) |
| |
| @tvm.script.ir_module |
| class expected_reciprocal: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide( |
| R.const(1.0, "float32"), input_1 |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Reciprocal(), example_args, {}, expected_reciprocal) |
| |
| # Returns the maximum value of all elements in the input tensor. |
| class MaxModel(Module): |
| def forward(self, input): |
| return torch.max(input) |
| |
| @tvm.script.ir_module |
| class expected_max: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.max(input, axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(MaxModel(), example_args, {}, expected_max) |
| |
| # Returns the minimum value of all elements in the input tensor. |
| class MinModel(Module): |
| def forward(self, input): |
| return torch.min(input) |
| |
| @tvm.script.ir_module |
| class expected_min: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.min(input, axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(MinModel(), example_args, {}, expected_min) |
| |
| # relu6 |
| class ReLU6_1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu6 = torch.nn.ReLU6() |
| |
| def forward(self, x): |
| return self.relu6(x) |
| |
| class ReLU6_2(torch.nn.Module): |
| def forward(self, x): |
| return torch.nn.functional.relu6(x) |
| |
| class ReLU6_3(torch.nn.Module): |
| def forward(self, x): |
| return torch.ops.aten.relu6_(x) |
| |
| @tvm.script.ir_module |
| class expected_relu6_1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( |
| x, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(6.0)) |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_relu6_2: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu6(x) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(ReLU6_1(), example_args, {}, expected_relu6_1) |
| verify_model(ReLU6_2(), example_args, {}, expected_relu6_2) |
| verify_model(ReLU6_3(), example_args, {}, expected_relu6_2) |
| |
| |
| def test_hardtanh(): |
| class Hardtanh(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.ht = torch.nn.Hardtanh() |
| |
| def forward(self, input): |
| return self.ht(input) |
| |
| class Hardtanh2(torch.nn.Module): |
| def forward(self, input): |
| return torch.nn.functional.hardtanh(input) |
| |
| class Hardtanh3(torch.nn.Module): |
| def forward(self, input): |
| return torch.ops.aten.hardtanh_(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip( |
| inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0)) |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Hardtanh(), example_args, {}, expected1) |
| verify_model(Hardtanh2(), example_args, {}, expected1) |
| verify_model(Hardtanh3(), example_args, {}, expected1) |
| |
| |
| def test_softplus(): |
| torch.set_grad_enabled(False) |
| |
| class Softplus0(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.softplus = torch.nn.Softplus(1.0, 20.0) |
| |
| def forward(self, x): |
| return self.softplus(x) |
| |
| class Softplus1(Module): |
| def forward(self, input): |
| return torch.nn.functional.softplus(input, 1.0, 20.0) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus( |
| x, beta=1.0, threshold=20.0 |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Softplus0(), example_args, {}, expected) |
| verify_model(Softplus1(), example_args, {}, expected) |
| |
| |
| def test_leakyrelu(): |
| torch.set_grad_enabled(False) |
| |
| class LeakyReLU0(Module): |
| def __init__(self): |
| super().__init__() |
| self.leakyrelu = torch.nn.LeakyReLU(0.02) |
| |
| def forward(self, input): |
| return self.leakyrelu(input) |
| |
| class LeakyReLU1(Module): |
| def forward(self, input): |
| return torch.nn.functional.leaky_relu(input, 0.02) |
| |
| class LeakyReLU2(Module): |
| def forward(self, input): |
| return torch.ops.aten.leaky_relu_(input, 0.02) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.leakyrelu(input_1, 0.02) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(LeakyReLU0(), example_args, {}, expected) |
| verify_model(LeakyReLU1(), example_args, {}, expected) |
| verify_model(LeakyReLU2(), example_args, {}, expected) |
| |
| |
| def test_logaddexp(): |
| class LogAddExp(Module): |
| def forward(self, input1, input2): |
| return torch.logaddexp(input1, input2) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| input_2: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log_add_exp(input_1, input_2) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(1, 3, 10, 10, dtype=torch.float32), |
| torch.randn(1, 3, 10, 10, dtype=torch.float32), |
| ) |
| verify_model(LogAddExp(), example_args, {}, expected) |
| |
| |
| def test_logsoftmax(): |
| class LogSoftmax(Module): |
| def __init__(self): |
| super().__init__() |
| self.lsm = torch.nn.LogSoftmax(dim=1) |
| |
| def forward(self, input): |
| return self.lsm(input) |
| |
| class LogSoftmax2(Module): |
| def forward(self, input): |
| return torch.nn.functional.log_softmax(input, dim=1) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.log_softmax(input_1, axis=1) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(LogSoftmax(), example_args, {}, expected1) |
| verify_model(LogSoftmax2(), example_args, {}, expected1) |
| |
| |
| def test_prelu(): |
| class Prelu1(Module): |
| def __init__(self, num_parameters=1, alpha=0.25): |
| super().__init__() |
| self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha) |
| |
| def forward(self, x): |
| return self.prelu(x) |
| |
| class Prelu2(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.alpha = torch.nn.Parameter(torch.tensor([0.25])) |
| |
| def forward(self, x): |
| return torch.nn.functional.prelu(x, self.alpha) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu( |
| x, R.const([0.25], dtype="float32"), axis=1 |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Prelu1(), example_args, {}, expected) |
| verify_model(Prelu2(), example_args, {}, expected) |
| |
| |
| def test_softmax(): |
| class Softmax(Module): |
| def __init__(self): |
| super().__init__() |
| self.sm = torch.nn.Softmax(dim=1) |
| |
| def forward(self, input): |
| return self.sm(input) |
| |
| class Softmax2(Module): |
| def forward(self, input): |
| return torch.nn.functional.softmax(input, dim=1) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softmax(input_1, axis=1) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Softmax(), example_args, {}, expected1) |
| verify_model(Softmax2(), example_args, {}, expected1) |
| |
| |
| def test_softsign(): |
| class Softsign(Module): |
| def __init__(self): |
| super().__init__() |
| self.ss = torch.nn.Softsign() |
| |
| def forward(self, input): |
| return self.ss(input) |
| |
| class Softsign2(Module): |
| def forward(self, input): |
| return torch.nn.functional.softsign(input) |
| |
| @tvm.script.ir_module |
| class expected_softsign: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| abs_val = R.abs(input) |
| denom = R.add(abs_val, R.const(1.0, "float32")) |
| result = R.divide(input, denom) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (result,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Softsign(), example_args, {}, expected_softsign) |
| verify_model(Softsign2(), example_args, {}, expected_softsign) |
| |
| |
| def test_softshrink(): |
| class Softshrink(Module): |
| def __init__(self): |
| super().__init__() |
| self.softshrink = torch.nn.Softshrink(lambd=0.5) |
| |
| def forward(self, input): |
| return self.softshrink(input) |
| |
| class Softshrink2(Module): |
| def forward(self, input): |
| return torch.nn.functional.softshrink(input, lambd=0.5) |
| |
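| # softshrink(x) = (x - lambd) * (x > lambd) + (x + lambd) * (x < -lambd) |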
| @tvm.script.ir_module |
| class expected_softshrink: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract( |
| input, R.const(0.5, "float32") |
| ) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater( |
| input, R.const(0.5, "float32") |
| ) |
| lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32") |
| lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2) |
| |
| lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add( |
| input, R.const(0.5, "float32") |
| ) |
| lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32")) |
| lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5) |
| lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32") |
| lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7) |
| |
| lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8) |
| |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Softshrink(), example_args, {}, expected_softshrink) |
| verify_model(Softshrink2(), example_args, {}, expected_softshrink) |
| |
| |
| def test_tril_triu(): |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| |
| class Tril(Module): |
| def forward(self, input): |
| return torch.tril(input, 1) |
| |
| @tvm.script.ir_module |
| class expected_tril: |
| @R.function |
| def main( |
| input_1: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Tril(), example_args, {}, expected_tril) |
| |
| class Triu(Module): |
| def forward(self, input): |
| return torch.triu(input, 1) |
| |
| @tvm.script.ir_module |
| class expected_triu: |
| @R.function |
| def main( |
| input_1: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Triu(), example_args, {}, expected_triu) |
| |
| |
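| # (Python operator / PyTorch op, Relax op) pairs for binary ops that keep the operand dtype. |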
| operator_binary_1 = [ |
| (operator.add, R.add), |
| (torch.ops.aten.add_, R.add), |
| (torch.ops.aten.bitwise_or, R.bitwise_or), |
| (torch.ops.aten.bitwise_or_, R.bitwise_or), |
| (operator.sub, R.subtract), |
| (operator.mul, R.multiply), |
| (torch.ops.aten.mul_, R.multiply), |
| (operator.truediv, R.divide), |
| (operator.floordiv, R.floor_divide), |
| (torch.ops.aten.fmod, R.mod), |
| (operator.pow, R.power), |
| (operator.mod, R.floor_mod), |
| (operator.and_, R.bitwise_and), |
| (operator.or_, R.bitwise_or), |
| (operator.xor, R.bitwise_xor), |
| ] |
| |
| |
| @pytest.mark.parametrize("op, relax_op", operator_binary_1) |
| def test_binary1(op, relax_op): |
| example_args1 = ( |
| torch.randn(10, 10, dtype=torch.float32), |
| torch.randn(10, 10, dtype=torch.float32), |
| ) |
| example_args2 = (torch.randn(10, 10, dtype=torch.float32),) |
| |
| class Binary1(Module): |
| def __init__(self, op): |
| super().__init__() |
| self.op = op |
| |
| def forward(self, lhs, rhs): |
| return self.op(lhs, rhs) |
| |
| @tvm.script.ir_module |
| class expected_binary1: |
| @R.function |
| def main( |
| lhs: R.Tensor((10, 10), dtype="float32"), |
| rhs: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class Binary2(Module): |
| def __init__(self, op): |
| super().__init__() |
| self.op = op |
| |
| def forward(self, lhs): |
| return self.op(lhs, 1.0) |
| |
| @tvm.script.ir_module |
| class expected_binary2: |
| @R.function |
| def main( |
| lhs: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0)) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Binary1(op), example_args1, {}, expected_binary1) |
| verify_model(Binary2(op), example_args2, {}, expected_binary2) |
| |
| |
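| # Comparison ops; the corresponding Relax ops return bool tensors. |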
| operator_binary_2 = [ |
| (operator.eq, R.equal), |
| (operator.ne, R.not_equal), |
| (operator.lt, R.less), |
| (operator.le, R.less_equal), |
| (operator.gt, R.greater), |
| (operator.ge, R.greater_equal), |
| ] |
| |
| |
| @pytest.mark.parametrize("op, relax_op", operator_binary_2) |
| def test_binary2(op, relax_op): |
| example_args1 = ( |
| torch.randn(10, 10, dtype=torch.float32), |
| torch.randn(10, 10, dtype=torch.float32), |
| ) |
| example_args2 = (torch.randn(10, 10, dtype=torch.float32),) |
| |
| class Binary1(Module): |
| def __init__(self, op): |
| super().__init__() |
| self.op = op |
| |
| def forward(self, lhs, rhs): |
| return self.op(lhs, rhs) |
| |
| @tvm.script.ir_module |
| class expected_binary1: |
| @R.function |
| def main( |
| lhs: R.Tensor((10, 10), dtype="float32"), |
| rhs: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, rhs) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class Binary2(Module): |
| def __init__(self, op): |
| super().__init__() |
| self.op = op |
| |
| def forward(self, lhs): |
| return self.op(lhs, 1.0) |
| |
| @tvm.script.ir_module |
| class expected_binary2: |
| @R.function |
| def main( |
| lhs: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="bool") = relax_op(lhs, R.const(1.0)) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Binary1(op), example_args1, {}, expected_binary1) |
| verify_model(Binary2(op), example_args2, {}, expected_binary2) |
| |
| |
| def test_binary3(): |
| example_args1 = ( |
| torch.randn(10, 10, dtype=torch.float32), |
| torch.randn(10, 10, dtype=torch.float32), |
| ) |
| example_args2 = (torch.randn(10, 10, dtype=torch.float32),) |
| |
| # Max |
| class Max1(Module): |
| def forward(self, x, y): |
| return torch.max(x, y) |
| |
| @I.ir_module |
| class expected_max1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((10, 10), dtype="float32"), |
| inp_1: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.maximum(inp_0, inp_1) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Max1(), example_args1, {}, expected_max1) |
| |
| # Min |
| class Min1(Module): |
| def forward(self, x, y): |
| return torch.min(x, y) |
| |
| @I.ir_module |
| class expected_min1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((10, 10), dtype="float32"), |
| inp_1: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.minimum(inp_0, inp_1) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Min1(), example_args1, {}, expected_min1) |
| |
| # RSub |
| class RSub1(Module): |
| def forward(self, x, y): |
| return torch.rsub(x, y) |
| |
| class RSub2(Module): |
| def forward(self, x): |
| return torch.rsub(x, 5.0) |
| |
| @tvm.script.ir_module |
| class expected_rsub1: |
| @R.function |
| def main( |
| x: R.Tensor((10, 10), dtype="float32"), y: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.subtract(y, x) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_rsub2: |
| @R.function |
| def main( |
| x: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.subtract(R.const(5.0, "float32"), x) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(RSub1(), example_args1, {}, expected_rsub1) |
| verify_model(RSub2(), example_args2, {}, expected_rsub2) |
| |
| |
| # IsIn |
| |
| |
| def test_isin(): |
| class IsInModel(torch.nn.Module): |
| def forward(self, x, test_elements): |
| return torch.isin(x, test_elements) |
| |
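| # isin lowers to a broadcasted equality check followed by a reduction: |
| # expand_dims(x) == test_elements, sum over the last axis, then compare > 0. |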
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="bool")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1]) |
| lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8])) |
| lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1) |
| lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False) |
| lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32")) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(10, 10, dtype=torch.float32), |
| torch.randn(8, dtype=torch.float32), |
| ) |
| verify_model(IsInModel(), example_args, {}, expected) |
| |
| |
| def test_div_mode(): |
| # Case 1: Basic division (no rounding mode) |
| class DivModel(torch.nn.Module): |
| def forward(self, a, b): |
| return torch.div(a, b) |
| |
| @tvm.script.ir_module |
| class expected_div: |
| @R.function |
| def main( |
| a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b) |
| gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(64, 64, dtype=torch.float32), |
| torch.randn(64, dtype=torch.float32), |
| ) |
| verify_model(DivModel(), example_args, {}, expected_div) |
| |
| # Case 2: Division with trunc rounding |
| class DivTruncModel(torch.nn.Module): |
| def forward(self, a, b): |
| return torch.div(a, b, rounding_mode="trunc") |
| |
| @tvm.script.ir_module |
| class expected_div_trunc: |
| @R.function |
| def main( |
| a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((64, 64), dtype="float32") = R.divide(a, b) |
| lv1: R.Tensor((64, 64), dtype="float32") = R.trunc(lv) |
| gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| verify_model(DivTruncModel(), example_args, {}, expected_div_trunc) |
| |
| # Case 3: Division with floor rounding |
| class DivFloorModel(torch.nn.Module): |
| def forward(self, a, b): |
| return torch.div(a, b, rounding_mode="floor") |
| |
| @tvm.script.ir_module |
| class expected_div_floor: |
| @R.function |
| def main( |
| a: R.Tensor((64, 64), dtype="float32"), b: R.Tensor((64,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((64, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((64, 64), dtype="float32") = R.floor_divide(a, b) |
| gv: R.Tuple(R.Tensor((64, 64), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(DivFloorModel(), example_args, {}, expected_div_floor) |
| |
| |
| def test_batchnorm2d(): |
| class BatchNorm2d(Module): |
| def __init__(self): |
| super().__init__() |
| self.bn = torch.nn.BatchNorm2d(3) |
| |
| def forward(self, input): |
| return self.bn(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((3,), dtype="float32"), |
| w2: R.Tensor((3,), dtype="float32"), |
| w3: R.Tensor((3,), dtype="float32"), |
| w4: R.Tensor((3,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| R.Tensor((3,), dtype="float32"), |
| R.Tensor((3,), dtype="float32"), |
| ) = R.nn.batch_norm( |
| input_1, |
| w1, |
| w2, |
| w3, |
| w4, |
| axis=1, |
| epsilon=1e-05, |
| center=True, |
| scale=True, |
| ) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = BatchNorm2d().eval() |
| binding = { |
| "w1": model.bn.weight.detach().numpy(), |
| "w2": model.bn.bias.detach().numpy(), |
| "w3": model.bn.running_mean.detach().numpy(), |
| "w4": model.bn.running_var.detach().numpy(), |
| } |
| verify_model(model, example_args, binding, expected1) |
| |
| |
| def test_adaptive_avgpool1d(): |
| class AdaptiveAvgPool1d0(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AdaptiveAvgPool1d(output_size=5) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class AdaptiveAvgPool1d1(torch.nn.Module): |
| def forward(self, input): |
| return torch.nn.functional.adaptive_avg_pool1d(input, output_size=5) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 5), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 5), dtype="float32") = R.nn.adaptive_avg_pool1d( |
| input_1, output_size=[5], layout="NCW" |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 5), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) |
| verify_model(AdaptiveAvgPool1d0(), example_args, {}, expected1) |
| verify_model(AdaptiveAvgPool1d1(), example_args, {}, expected1) |
| |
| |
| def test_adaptive_avgpool2d(): |
| class AdaptiveAvgPool2d0(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class AdaptiveAvgPool2d1(Module): |
| def forward(self, input): |
| return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( |
| input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) |
| verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) |
| |
| |
| def test_adaptive_avgpool3d(): |
| class AdaptiveAvgPool3d0(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AdaptiveAvgPool3d([4, 4, 4]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class AdaptiveAvgPool3d1(torch.nn.Module): |
| def forward(self, input): |
| return torch.nn.functional.adaptive_avg_pool3d(input, [4, 4, 4]) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.adaptive_avg_pool3d( |
| input_1, output_size=[4, 4, 4], layout="NCDHW", out_layout="NCDHW" |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) |
| verify_model(AdaptiveAvgPool3d0(), example_args, {}, expected1) |
| verify_model(AdaptiveAvgPool3d1(), example_args, {}, expected1) |
| |
| |
| def test_addmm(): |
| class Addmm1(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x1, x2, x3): |
| return torch.addmm(x1, x2, x3) |
| |
| class Addmm2(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x1, x2, x3): |
| return torch.addmm(x1, x2, x3, beta=0.8, alpha=0.5) |
| |
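| # addmm(x1, x2, x3) = beta * x1 + alpha * (x2 @ x3), with alpha = beta = 1 by default. |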
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x1: R.Tensor((10, 10), dtype="float32"), |
| x2: R.Tensor((10, 10), dtype="float32"), |
| x3: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") |
| lv1: R.Tensor((10, 10), dtype="float32") = R.add(x1, lv) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| x1: R.Tensor((10, 10), dtype="float32"), |
| x2: R.Tensor((10, 10), dtype="float32"), |
| x3: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.matmul(x2, x3, out_dtype="float32") |
| lv1: R.Tensor((10, 10), dtype="float32") = R.multiply(lv, R.const(0.5, "float32")) |
| lv2: R.Tensor((10, 10), dtype="float32") = R.multiply(x1, R.const(0.8, "float32")) |
| lv3: R.Tensor((10, 10), dtype="float32") = R.add(lv2, lv1) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(10, 10, dtype=torch.float32), |
| torch.randn(10, 10, dtype=torch.float32), |
| torch.randn(10, 10, dtype=torch.float32), |
| ) |
| |
| verify_model(Addmm1(), example_args, {}, expected1) |
| verify_model(Addmm2(), example_args, {}, expected2) |
| |
| |
| def test_avg_pool1d(): |
| class AvgPool1d1(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AvgPool1d(kernel_size=1) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10), dtype="float32") = R.nn.avg_pool1d( |
| input_1, |
| pool_size=[1], |
| strides=[1], |
| dilation=[1], |
| padding=[0, 0], |
| ceil_mode=False, |
| count_include_pad=True, |
| layout="NCW", |
| out_layout="NCW", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class AvgPool1d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AvgPool1d(kernel_size=3, stride=2, padding=1, ceil_mode=True) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class AvgPool1d3(Module): |
| def forward(self, input): |
| return torch.nn.functional.avg_pool1d( |
| input, kernel_size=3, stride=2, padding=1, ceil_mode=True |
| ) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.avg_pool1d( |
| input_1, |
| pool_size=[3], |
| strides=[2], |
| dilation=[1], |
| padding=[1, 1], |
| ceil_mode=True, |
| count_include_pad=True, |
| layout="NCW", |
| out_layout="NCW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| class AvgPool1d4(Module): |
| def forward(self, input): |
| return torch.nn.functional.avg_pool1d(input, kernel_size=2, stride=2, padding=0) |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main(input_1: R.Tensor((1, 3, 10), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.avg_pool1d( |
| input_1, |
| pool_size=[2], |
| strides=[2], |
| dilation=[1], |
| padding=[0, 0], |
| ceil_mode=False, |
| count_include_pad=True, |
| layout="NCW", |
| out_layout="NCW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) |
| verify_model(AvgPool1d1(), example_args, {}, expected1) |
| verify_model(AvgPool1d2(), example_args, {}, expected2) |
| verify_model(AvgPool1d3(), example_args, {}, expected2) |
| verify_model(AvgPool1d4(), example_args, {}, expected3) |
| |
| |
| def test_avg_pool2d(): |
| class AvgPool2d1(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AvgPool2d(kernel_size=[1, 1]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.avg_pool2d( |
| input_1, |
| pool_size=[1, 1], |
| strides=[1, 1], |
| dilation=[1, 1], |
| padding=[0, 0, 0, 0], |
| layout="NCHW", |
| out_layout="NCHW", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class AvgPool2d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AvgPool2d(kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class AvgPool2d3(Module): |
| def forward(self, input): |
| return torch.nn.functional.avg_pool2d( |
| input, kernel_size=[4, 4], stride=2, padding=2, ceil_mode=True |
| ) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.avg_pool2d( |
| input_1, |
| pool_size=[4, 4], |
| strides=[2, 2], |
| dilation=[1, 1], |
| padding=[2, 2, 2, 2], |
| ceil_mode=True, |
| layout="NCHW", |
| out_layout="NCHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| class AvgPool2d4(Module): |
| def forward(self, input): |
| return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2) |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.avg_pool2d( |
| input_1, |
| pool_size=[2, 1], |
| strides=[2, 1], |
| dilation=[1, 1], |
| padding=[0, 0, 0, 0], |
| ceil_mode=False, |
| layout="NCHW", |
| out_layout="NCHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(AvgPool2d1(), example_args, {}, expected1) |
| verify_model(AvgPool2d2(), example_args, {}, expected2) |
| verify_model(AvgPool2d3(), example_args, {}, expected2) |
| verify_model(AvgPool2d4(), example_args, {}, expected3) |
| |
| |
| def test_avg_pool3d(): |
| class AvgPool3d1(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AvgPool3d(kernel_size=1) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 8, 8, 8), dtype="float32") = R.nn.avg_pool3d( |
| input_1, |
| pool_size=[1, 1, 1], |
| strides=[1, 1, 1], |
| dilation=[1, 1, 1], |
| padding=[0, 0, 0, 0, 0, 0], |
| ceil_mode=False, |
| count_include_pad=True, |
| layout="NCDHW", |
| out_layout="NCDHW", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 8, 8, 8), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class AvgPool3d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.AvgPool3d(kernel_size=3, stride=2, padding=1, ceil_mode=True) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class AvgPool3d3(Module): |
| def forward(self, input): |
| return torch.nn.functional.avg_pool3d( |
| input, kernel_size=3, stride=2, padding=1, ceil_mode=True |
| ) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.avg_pool3d( |
| input_1, |
| pool_size=[3, 3, 3], |
| strides=[2, 2, 2], |
| dilation=[1, 1, 1], |
| padding=[1, 1, 1, 1, 1, 1], |
| ceil_mode=True, |
| count_include_pad=True, |
| layout="NCDHW", |
| out_layout="NCDHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| class AvgPool3d4(Module): |
| def forward(self, input): |
| return torch.nn.functional.avg_pool3d(input, kernel_size=[2, 1, 2], stride=[2, 1, 2]) |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main(input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.avg_pool3d( |
| input_1, |
| pool_size=[2, 1, 2], |
| strides=[2, 1, 2], |
| dilation=[1, 1, 1], |
| padding=[0, 0, 0, 0, 0, 0], |
| ceil_mode=False, |
| count_include_pad=True, |
| layout="NCDHW", |
| out_layout="NCDHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) |
| verify_model(AvgPool3d1(), example_args, {}, expected1) |
| verify_model(AvgPool3d2(), example_args, {}, expected2) |
| verify_model(AvgPool3d3(), example_args, {}, expected2) |
| verify_model(AvgPool3d4(), example_args, {}, expected3) |
| |
| |
| def test_baddbmm(): |
| class BAddBMM1(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, c, x, y): |
| return torch.baddbmm(c, x, y) |
| |
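| # baddbmm(c, x, y) = beta * c + alpha * bmm(x, y); with the default |
| # alpha = beta = 1 this is just the batched matmul plus the input. |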
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((4, 128, 512), dtype="float32"), |
| inp_1: R.Tensor((4, 128, 256), dtype="float32"), |
| inp_2: R.Tensor((4, 256, 512), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) |
| lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0) |
| gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| class BAddBMM2(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, c, x, y): |
| return torch.baddbmm(c, x, y, alpha=2, beta=0) |
| |
| @tvm.script.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((4, 128, 512), dtype="float32"), |
| inp_1: R.Tensor((4, 128, 256), dtype="float32"), |
| inp_2: R.Tensor((4, 256, 512), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) |
| lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( |
| lv, R.const(2, "float32") |
| ) |
| gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| class BAddBMM3(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, c, x, y): |
| return torch.baddbmm(c, x, y, alpha=2, beta=3) |
| |
| @tvm.script.ir_module |
| class Expected3: |
| @R.function |
| def main( |
| inp_0: R.Tensor((4, 128, 512), dtype="float32"), |
| inp_1: R.Tensor((4, 128, 256), dtype="float32"), |
| inp_2: R.Tensor((4, 256, 512), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2) |
| lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( |
| lv, R.const(2, "float32") |
| ) |
| lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply( |
| inp_0, R.const(3, "float32") |
| ) |
| lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(4, 128, 512, dtype=torch.float32), |
| torch.randn(4, 128, 256, dtype=torch.float32), |
| torch.randn(4, 256, 512, dtype=torch.float32), |
| ) |
| verify_model( |
| BAddBMM1(), |
| example_args, |
| {}, |
| Expected1, |
| ) |
| |
| verify_model( |
| BAddBMM2(), |
| example_args, |
| {}, |
| Expected2, |
| ) |
| |
| verify_model( |
| BAddBMM3(), |
| example_args, |
| {}, |
| Expected3, |
| ) |
| |
| |
| def test_bmm(): |
| class BMM(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x, y): |
| return torch.bmm(x, y) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input_1: R.Tensor((4, 128, 256), dtype="float32"), |
| input_2: R.Tensor((4, 256, 512), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul( |
| input_1, input_2, out_dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(4, 128, 256, dtype=torch.float32), |
| torch.randn(4, 256, 512, dtype=torch.float32), |
| ) |
| verify_model( |
| BMM(), |
| example_args, |
| {}, |
| Expected, |
| ) |
| |
| |
| def test_conv_transpose1d(): |
| class ConvTranspose1d1(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=True) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| class ConvTranspose1d1Func(Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(size=[6, 6, 3]) |
| self.bias = torch.randn(size=[6]) |
| |
| def forward(self, input): |
| return torch.nn.functional.conv_transpose1d(input, self.weight, self.bias) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 6, 4), dtype="float32"), |
| w1: R.Tensor((6, 6, 3), dtype="float32"), |
| w2: R.Tensor((6,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( |
| input_1, |
| w1, |
| strides=[1], |
| padding=[0, 0], |
| output_padding=[0], |
| dilation=[1], |
| data_layout="NCW", |
| kernel_layout="IOW", |
| out_layout="NCW", |
| out_dtype="float32", |
| ) |
| lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) |
| lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| class ConvTranspose1d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.ConvTranspose1d(6, 6, 3, bias=False) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 6, 4), dtype="float32"), |
| w1: R.Tensor((6, 6, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 6), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 6), dtype="float32") = R.nn.conv1d_transpose( |
| input_1, |
| w1, |
| strides=[1], |
| padding=[0, 0], |
| output_padding=[0], |
| dilation=[1], |
| data_layout="NCW", |
| kernel_layout="IOW", |
| out_layout="NCW", |
| out_dtype="float32", |
| ) |
| gv: R.Tuple(R.Tensor((1, 6, 6), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 6, 4, dtype=torch.float32),) |
| |
| model = ConvTranspose1d1() |
| binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = ConvTranspose1d1Func() |
| binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = ConvTranspose1d2() |
| binding = {"w1": model.conv.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected2) |
| |
| |
| def test_conv_transpose2d(): |
| class ConvTranspose2d1(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=True) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| class ConvTranspose2d1Func(Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(size=[3, 3, 7, 7]) |
| self.bias = torch.randn(size=[3]) |
| |
| def forward(self, input): |
| return torch.nn.functional.conv_transpose2d(input, self.weight, self.bias) |
| |
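| # Same pattern in 2-D with kernel_layout "IOHW"; with stride 1 and no padding |
| # each spatial dim grows to 10 + 7 - 1 = 16. |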
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((3, 3, 7, 7), dtype="float32"), |
| w2: R.Tensor((3,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( |
| input_1, |
| w1, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| output_padding=[0, 0], |
| dilation=[1, 1], |
| data_layout="NCHW", |
| kernel_layout="IOHW", |
| out_layout="NCHW", |
| out_dtype="float32", |
| ) |
| lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, [1, 3, 1, 1]) |
| lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| class ConvTranspose2d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.ConvTranspose2d(3, 3, 7, bias=False) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((3, 3, 7, 7), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 3, 16, 16), dtype="float32") = R.nn.conv2d_transpose( |
| input_1, |
| w1, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| output_padding=[0, 0], |
| dilation=[1, 1], |
| data_layout="NCHW", |
| kernel_layout="IOHW", |
| out_layout="NCHW", |
| out_dtype="float32", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 16, 16), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = ConvTranspose2d1() |
| binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = ConvTranspose2d1Func() |
| binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = ConvTranspose2d2() |
| binding = {"w1": model.conv.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected2) |
| |
| |
| def test_conv1d(): |
| class Conv1D1(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv1d(3, 6, 7, bias=True) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| class Conv1D1Func(Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(size=[6, 3, 7]) |
| self.bias = torch.randn(size=[6]) |
| |
| def forward(self, input): |
| return torch.nn.functional.conv1d(input, self.weight, self.bias) |
| |
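| # Regular convolutions use the "OIW" kernel layout; a 7-wide kernel with no |
| # padding shrinks the output width to 10 - 7 + 1 = 4. |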
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| w1: R.Tensor((6, 3, 7), dtype="float32"), |
| w2: R.Tensor((6,), dtype="float32"), |
| input_1: R.Tensor((1, 3, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( |
| input_1, |
| w1, |
| strides=[1], |
| padding=[0, 0], |
| dilation=[1], |
| data_layout="NCW", |
| kernel_layout="OIW", |
| out_layout="NCW", |
| out_dtype="float32", |
| ) |
| lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1]) |
| lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| class Conv1D2(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv1d(3, 6, 7, bias=False) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| w1: R.Tensor((6, 3, 7), dtype="float32"), |
| input_1: R.Tensor((1, 3, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4), dtype="float32") = R.nn.conv1d( |
| input_1, |
| w1, |
| strides=[1], |
| padding=[0, 0], |
| dilation=[1], |
| data_layout="NCW", |
| kernel_layout="OIW", |
| out_layout="NCW", |
| out_dtype="float32", |
| ) |
| gv: R.Tuple(R.Tensor((1, 6, 4), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, dtype=torch.float32),) |
| |
| model = Conv1D1() |
| binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Conv1D1Func() |
| binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Conv1D2() |
| binding = {"w1": model.conv.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected2) |
| |
| |
| def test_conv2d(): |
| class Conv2D1(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| class Conv2D1Func(Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(size=[6, 3, 7, 7]) |
| self.bias = torch.randn(size=[6]) |
| |
| def forward(self, input): |
| return torch.nn.functional.conv2d(input, self.weight, self.bias) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((6, 3, 7, 7), dtype="float32"), |
| w2: R.Tensor((6,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( |
| input_1, |
| w1, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| data_layout="NCHW", |
| kernel_layout="OIHW", |
| out_layout="NCHW", |
| out_dtype="float32", |
| ) |
| lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1]) |
| lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| class Conv2D2(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((6, 3, 7, 7), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( |
| input_1, |
| w1, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| data_layout="NCHW", |
| kernel_layout="OIHW", |
| out_layout="NCHW", |
| out_dtype="float32", |
| ) |
| gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = Conv2D1() |
| binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Conv2D1Func() |
| binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Conv2D2() |
| binding = {"w1": model.conv.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected2) |
| |
| |
| def test_conv3d(): |
| class Conv3D1(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv3d(3, 6, 7, bias=True) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| class Conv3D1Func(Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(size=[6, 3, 7, 7, 7]) |
| self.bias = torch.randn(size=[6]) |
| |
| def forward(self, input): |
| return torch.nn.functional.conv3d(input, self.weight, self.bias) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), |
| w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), |
| w2: R.Tensor((6,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( |
| input_1, |
| w1, |
| strides=[1], |
| padding=[0, 0, 0], |
| dilation=[1], |
| data_layout="NCDHW", |
| kernel_layout="OIDHW", |
| out_layout="NCDHW", |
| out_dtype="float32", |
| ) |
| lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1, 1]) |
| lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| class Conv3D2(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv3d(3, 6, 7, bias=False) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32"), |
| w1: R.Tensor((6, 3, 7, 7, 7), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.nn.conv3d( |
| input_1, |
| w1, |
| strides=[1], |
| padding=[0, 0, 0], |
| dilation=[1], |
| data_layout="NCDHW", |
| kernel_layout="OIDHW", |
| out_layout="NCDHW", |
| out_dtype="float32", |
| ) |
| gv: R.Tuple(R.Tensor((1, 6, 4, 4, 4), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) |
| |
| model = Conv3D1() |
| binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Conv3D1Func() |
| binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Conv3D2() |
| binding = {"w1": model.conv.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected2) |
| |
| |
| def test_pad(): |
| class PadModel(torch.nn.Module): |
| def __init__(self, pad, mode="constant", value=0.0): |
| super().__init__() |
| self.pad = pad |
| self.mode = mode |
| self.value = value |
| |
| def forward(self, x): |
| if self.mode == "constant": |
| return torch.nn.functional.pad(x, self.pad, mode=self.mode, value=self.value) |
| else: |
| return torch.nn.functional.pad(x, self.pad, mode=self.mode) |
| |
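| # F.pad lists pads innermost-dimension-first as (before, after) pairs, while |
| # R.nn.pad orders pad_width from the first axis outward, so pad=[1, 1, 2, 2] |
| # becomes pad_width=[0, 0, 0, 0, 2, 2, 1, 1] on a 4-D input. |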
| @tvm.script.ir_module |
| class expected_constant: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( |
| x, |
| pad_width=[0, 0, 0, 0, 2, 2, 1, 1], |
| pad_mode="constant", |
| pad_value=0.0, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_reflect: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( |
| x, |
| pad_width=[0, 0, 0, 0, 2, 2, 1, 1], |
| pad_mode="reflect", |
| pad_value=0.0, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_replicate: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( |
| x, |
| pad_width=[0, 0, 0, 0, 2, 2, 1, 1], |
| pad_mode="replicate", |
| pad_value=0.0, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_circular: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 14, 12), dtype="float32") = R.nn.pad( |
| x, |
| pad_width=[0, 0, 0, 0, 2, 2, 1, 1], |
| pad_mode="circular", |
| pad_value=0.0, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 14, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(PadModel(pad=[1, 1, 2, 2]), example_args, {}, expected_constant) |
| verify_model(PadModel(pad=[1, 1, 2, 2], mode="reflect"), example_args, {}, expected_reflect) |
| verify_model(PadModel(pad=[1, 1, 2, 2], mode="replicate"), example_args, {}, expected_replicate) |
| verify_model(PadModel(pad=[1, 1, 2, 2], mode="circular"), example_args, {}, expected_circular) |
| |
| |
| def test_pixel_shuffle(): |
| class PixelShuffle1(torch.nn.Module): |
| def __init__(self, upscale_factor=2): |
| super().__init__() |
| self.pixel_shuffle = torch.nn.PixelShuffle(upscale_factor) |
| |
| def forward(self, x): |
| return self.pixel_shuffle(x) |
| |
| class PixelShuffle2(torch.nn.Module): |
| def __init__(self, upscale_factor=2): |
| super().__init__() |
| self.upscale_factor = upscale_factor |
| |
| def forward(self, x): |
| return torch.nn.functional.pixel_shuffle(x, self.upscale_factor) |
| |
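| # PixelShuffle with upscale_factor r=2 trades r*r channels for space: |
| # (1, 8, 10, 15) -> (1, 8 / 4, 10 * 2, 15 * 2) = (1, 2, 20, 30). |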
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 8, 10, 15), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle( |
| x, upscale_factor=2 |
| ) |
| gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),) |
| verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected) |
| verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected) |
| |
| |
| def test_einsum(): |
| class Einsum1(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return torch.einsum("ii", x) |
| |
| class Einsum2(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x, y): |
| return torch.einsum("i,j->ij", x, y) |
| |
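| # "ii" sums the diagonal (the trace, a scalar); "i,j->ij" is an outer product. |
| # Both pass through to R.einsum with the subscripts unchanged. |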
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((4, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii") |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((5, 4), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 4), dtype="float32") = R.einsum( |
| (inp_0, inp_1), subscripts="i,j->ij" |
| ) |
| gv: R.Tuple(R.Tensor((5, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(4, 4, dtype=torch.float32),) |
| verify_model(Einsum1(), example_args, {}, Expected1) |
| |
| example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32)) |
| verify_model(Einsum2(), example_args, {}, Expected2) |
| |
| |
| def test_outer(): |
| class Outer(torch.nn.Module): |
| def forward(self, x, y): |
| return torch.outer(x, y) |
| |
| @tvm.script.ir_module |
| class expected: |
| @R.function |
| def main( |
| a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b) |
| gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(3, dtype=torch.float32), |
| torch.randn(4, dtype=torch.float32), |
| ) |
| verify_model(Outer(), example_args, {}, expected) |
| |
| |
| def test_embedding(): |
| class Embedding(Module): |
| def __init__(self): |
| super().__init__() |
| self.embedding = torch.nn.Embedding(10, 3) |
| |
| def forward(self, input): |
| return self.embedding(input) |
| |
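| # nn.Embedding lowers to a gather: the int64 indices are cast to int32 and fed |
| # to R.take along axis 0 of the weight table. |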
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((4,), dtype="int64"), w1: R.Tensor((10, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((4,), dtype="int32") = R.astype(input_1, dtype="int32") |
| lv1: R.Tensor((4, 3), dtype="float32") = R.take(w1, lv, axis=0) |
| gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randint(low=-int(1e5), high=int(1e5), size=(4,), dtype=torch.int64),) |
| |
| model = Embedding() |
| binding = {"w1": model.embedding.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| |
| def test_groupnorm(): |
| torch.set_grad_enabled(False) |
| torch.random.manual_seed(0) |
| |
| class GroupNorm(Module): |
| def __init__(self): |
| super().__init__() |
| self.gn = torch.nn.GroupNorm(3, 3) |
| |
| def forward(self, input): |
| return self.gn(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((3,), dtype="float32"), |
| w2: R.Tensor((3,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.group_norm( |
| input_1, |
| w1, |
| w2, |
| num_groups=3, |
| channel_axis=1, |
| axes=[2, 3], |
| epsilon=1e-05, |
| center=True, |
| scale=True, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = GroupNorm() |
| binding = { |
| "w1": model.gn.weight.detach().numpy(), |
| "w2": model.gn.bias.detach().numpy(), |
| } |
| verify_model(model, example_args, binding, expected1) |
| |
| |
| def test_instancenorm2d(): |
| torch.set_grad_enabled(False) |
| torch.random.manual_seed(0) |
| |
| class InstanceNorm2d(Module): |
| def __init__(self): |
| super().__init__() |
| self.gn = torch.nn.InstanceNorm2d(3) |
| |
| def forward(self, input): |
| return self.gn(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((3,), dtype="float32"), |
| w2: R.Tensor((3,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.instance_norm( |
| input_1, |
| w1, |
| w2, |
| channel_axis=1, |
| axes=[2, 3], |
| epsilon=1e-05, |
| center=True, |
| scale=True, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = InstanceNorm2d() |
| binding = { |
| "w1": torch.ones(3).detach().numpy(), |
| "w2": torch.zeros(3).detach().numpy(), |
| } |
| verify_model(model, example_args, binding, expected1) |
| |
| |
| def test_layernorm(): |
| class LayerNorm(Module): |
| def __init__(self): |
| super().__init__() |
| self.ln = torch.nn.LayerNorm((10, 10)) |
| |
| def forward(self, input): |
| return self.ln(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| w1: R.Tensor((10, 10), dtype="float32"), |
| w2: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.layer_norm( |
| input_1, |
| w1, |
| w2, |
| axes=[-2, -1], |
| epsilon=1e-05, |
| center=True, |
| scale=True, |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = LayerNorm() |
| binding = { |
| "w1": model.ln.weight.detach().numpy(), |
| "w2": model.ln.bias.detach().numpy(), |
| } |
| verify_model(model, example_args, binding, expected1) |
| |
| |
| def test_linear(): |
| class Dense1(Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(10, 7, bias=True) |
| |
| def forward(self, input): |
| return self.linear(input) |
| |
| class Dense1Func(Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = torch.randn(size=[7, 10]) |
| self.bias = torch.randn(size=[7]) |
| |
| def forward(self, input): |
| return torch.nn.functional.linear(input, self.weight, self.bias) |
| |
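| # nn.Linear stores its weight as (out_features, in_features), so the converter |
| # transposes it with R.permute_dims before the matmul, then adds the bias. |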
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| w1: R.Tensor((7, 10), dtype="float32"), |
| w2: R.Tensor((7,), dtype="float32"), |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) |
| lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( |
| input_1, lv, out_dtype="float32" |
| ) |
| lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,) |
| R.output(gv) |
| return gv |
| |
| class Dense2(Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(10, 7, bias=False) |
| |
| def forward(self, input): |
| return self.linear(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| w1: R.Tensor((7, 10), dtype="float32"), |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) |
| lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( |
| input_1, lv, out_dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| model = Dense1() |
| binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Dense1Func() |
| binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} |
| verify_model(model, example_args, binding, expected1) |
| |
| model = Dense2() |
| binding = {"w1": model.linear.weight.detach().numpy()} |
| verify_model(model, example_args, binding, expected2) |
| |
| |
| def test_maxpool1d(): |
| class MaxPool1d(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool1d(kernel_size=2) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class MaxPool1d_functional(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.nn.functional.max_pool1d(input, kernel_size=2) |
| |
| class MaxPool1d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool1d(kernel_size=3, stride=2) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 8), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.max_pool1d( |
| input_1, |
| pool_size=[2], |
| strides=[2], |
| dilation=[1], |
| padding=[0, 0], |
| layout="NCW", |
| out_layout="NCW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 8), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.max_pool1d( |
| input_1, |
| pool_size=[2], |
| strides=[2], |
| dilation=[1], |
| padding=[0, 0], |
| layout="NCW", |
| out_layout="NCW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.max_pool1d( |
| input_1, |
| pool_size=[3], |
| strides=[2], |
| dilation=[1], |
| padding=[0, 0], |
| layout="NCW", |
| out_layout="NCW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Example inputs |
| example_args1 = (torch.randn(1, 3, 8, dtype=torch.float32),) |
| example_args2 = (torch.randn(1, 3, 8, dtype=torch.float32),) |
| example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),) |
| |
| # Verify the models |
| verify_model(MaxPool1d(), example_args1, {}, expected1) |
| verify_model(MaxPool1d_functional(), example_args2, {}, expected2) |
| verify_model(MaxPool1d2(), example_args3, {}, expected3) |
| |
| |
| def test_maxpool2d(): |
| class MaxPool2d(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class MaxPool2d_functional(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1]) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( |
| input_1, |
| pool_size=[1, 1], |
| strides=[1, 1], |
| dilation=[1, 1], |
| padding=[0, 0, 0, 0], |
| layout="NCHW", |
| out_layout="NCHW", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class MaxPool2d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
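| # nn.MaxPool2d defaults stride to kernel_size, hence strides=[2, 2] below even |
| # though only the dilation was set explicitly. |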
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( |
| input_1, |
| pool_size=[2, 2], |
| strides=[2, 2], |
| dilation=[2, 3], |
| padding=[0, 0, 0, 0], |
| layout="NCHW", |
| out_layout="NCHW", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class MaxPool2d3(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( |
| input_1, |
| pool_size=[4, 4], |
| strides=[2, 2], |
| dilation=[1, 1], |
| padding=[2, 2, 2, 2], |
| layout="NCHW", |
| out_layout="NCHW", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(MaxPool2d(), example_args, {}, expected1) |
| verify_model(MaxPool2d_functional(), example_args, {}, expected1) |
| verify_model(MaxPool2d2(), example_args, {}, expected2) |
| verify_model(MaxPool2d3(), example_args, {}, expected3) |
| |
| |
| def test_maxpool3d(): |
| class MaxPool3d(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool3d(kernel_size=[1, 1, 1]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| class MaxPool3d_functional(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.nn.functional.max_pool3d(input, kernel_size=[1, 1, 1]) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.max_pool3d( |
| input_1, |
| pool_size=[1, 1, 1], |
| strides=[1, 1, 1], |
| dilation=[1, 1, 1], |
| padding=[0, 0, 0, 0, 0, 0], |
| layout="NCDHW", |
| out_layout="NCDHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| class MaxPool3d2(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool3d(kernel_size=[2, 2, 2], dilation=[2, 2, 2]) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.max_pool3d( |
| input_1, |
| pool_size=[2, 2, 2], |
| strides=[2, 2, 2], |
| dilation=[2, 2, 2], |
| padding=[0, 0, 0, 0, 0, 0], |
| layout="NCDHW", |
| out_layout="NCDHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| class MaxPool3d3(Module): |
| def __init__(self): |
| super().__init__() |
| self.pool = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], padding=1, stride=2) |
| |
| def forward(self, input): |
| return self.pool(input) |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")): |
| with R.dataflow(): |
| lv = R.nn.max_pool3d( |
| input_1, |
| pool_size=[3, 3, 3], |
| strides=[2, 2, 2], |
| dilation=[1, 1, 1], |
| padding=[1, 1, 1, 1, 1, 1], |
| layout="NCDHW", |
| out_layout="NCDHW", |
| ) |
| gv = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Example input tensors |
| example_args1 = (torch.randn(1, 3, 4, 4, 4, dtype=torch.float32),) |
| example_args2 = (torch.randn(1, 3, 8, 8, 8, dtype=torch.float32),) |
| example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),) |
| |
| # Verify the models with expected IR modules |
| verify_model(MaxPool3d(), example_args1, {}, expected1) |
| verify_model(MaxPool3d_functional(), example_args1, {}, expected1) |
| verify_model(MaxPool3d2(), example_args2, {}, expected2) |
| verify_model(MaxPool3d3(), example_args3, {}, expected3) |
| |
| |
| def test_scaled_dot_product_attention(): |
| class Attention1(Module): |
| def forward(self, q, k, v): |
| return torch.nn.functional.scaled_dot_product_attention(q, k, v) |
| |
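| # torch SDPA takes (batch, heads, seq, head_dim) while R.nn.attention expects |
| # (batch, seq, heads, head_dim), so q/k/v are permuted on the way in and the |
| # result is permuted back. |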
| @I.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), |
| inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), |
| inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( |
| inp_0, axes=[0, 2, 1, 3] |
| ) |
| lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( |
| inp_1, axes=[0, 2, 1, 3] |
| ) |
| lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( |
| inp_2, axes=[0, 2, 1, 3] |
| ) |
| lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( |
| lv, lv1, lv2, scale=None |
| ) |
| lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( |
| lv3, axes=[0, 2, 1, 3] |
| ) |
| gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) |
| R.output(gv) |
| return gv |
| |
| class Attention2(Module): |
| def forward(self, q, k, v, mask): |
| return torch.nn.functional.scaled_dot_product_attention(q, k, v, mask) |
| |
| @I.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"), |
| inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"), |
| inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"), |
| inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( |
| inp_0, axes=[0, 2, 1, 3] |
| ) |
| lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( |
| inp_1, axes=[0, 2, 1, 3] |
| ) |
| lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims( |
| inp_2, axes=[0, 2, 1, 3] |
| ) |
| lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention( |
| lv, lv1, lv2, inp_3, scale=None |
| ) |
| lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims( |
| lv3, axes=[0, 2, 1, 3] |
| ) |
| gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,) |
| R.output(gv) |
| return gv |
| |
| verify_model( |
| Attention1(), |
| ( |
| torch.randn(32, 8, 128, 64, dtype=torch.float32), |
| torch.randn(32, 8, 128, 64, dtype=torch.float32), |
| torch.randn(32, 8, 128, 64, dtype=torch.float32), |
| ), |
| {}, |
| Expected1, |
| ) |
| |
| verify_model( |
| Attention2(), |
| ( |
| torch.randn(32, 8, 128, 64, dtype=torch.float32), |
| torch.randn(32, 8, 128, 64, dtype=torch.float32), |
| torch.randn(32, 8, 128, 64, dtype=torch.float32), |
| torch.randn(32, 8, 128, 128, dtype=torch.float32), |
| ), |
| {}, |
| Expected2, |
| ) |
| |
| |
| def test_unbind(): |
| class Unbind1(Module): |
| def forward(self, data): |
| return torch.unbind(data) |
| |
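| # unbind is decomposed: split along the chosen dim into single-slice chunks, |
| # squeeze that dim out of each chunk, and repack the pieces as a tuple. |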
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((3, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) = R.split(input_1, indices_or_sections=3, axis=0) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) |
| lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] |
| lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) |
| lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] |
| lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) |
| lv7: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv2, lv4, lv6) |
| lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] |
| lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] |
| lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] |
| gv: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv8, lv9, lv10) |
| R.output(gv) |
| return gv |
| |
| class Unbind2(Module): |
| def forward(self, data): |
| return torch.unbind(data, dim=1) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((3, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((3, 1, 10, 10), dtype="float32"), |
| R.Tensor((3, 1, 10, 10), dtype="float32"), |
| R.Tensor((3, 1, 10, 10), dtype="float32"), |
| ) = R.split(input_1, indices_or_sections=3, axis=1) |
| lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) |
| lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] |
| lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) |
| lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] |
| lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) |
| lv7: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv2, lv4, lv6) |
| lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] |
| lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] |
| lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] |
| gv: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv8, lv9, lv10) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected3: |
| @R.function |
| def main( |
| data: R.Tensor((3, 1, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3, 3), dtype="float32") = R.squeeze(data, axis=[1]) |
| lv1: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv,) |
| lv2: R.Tensor((3, 3), dtype="float32") = lv1[0] |
| gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv2,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Unbind1(), example_args, {}, expected1) |
| verify_model(Unbind2(), example_args, {}, expected2) |
| single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),) |
| verify_model(Unbind2(), single_dim_args, {}, expected3) |
| |
| |
| def test_interpolate(): |
| class InterpolateBilinear(Module): |
| def forward(self, input): |
| return torch.nn.functional.interpolate(input, (224, 224), mode="bilinear") |
| |
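| # F.interpolate modes map onto R.image.resize2d methods: "bilinear" -> |
| # "linear", "nearest" -> "nearest_neighbor", "bicubic" -> "cubic". |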
| @tvm.script.ir_module |
| class expected_bilinear: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 112, 112), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( |
| input, |
| R.shape([224, 224]), |
| roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], |
| layout="NCHW", |
| method="linear", |
| coordinate_transformation_mode="half_pixel", |
| rounding_method="round", |
| cubic_alpha=-0.75, |
| cubic_exclude=0, |
| extrapolation_value=0.0, |
| out_dtype="void", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class InterpolateNearest(Module): |
| def forward(self, input): |
| return torch.nn.functional.interpolate(input, (224, 224), mode="nearest") |
| |
| @tvm.script.ir_module |
| class expected_nearest: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 112, 112), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( |
| input, |
| R.shape([224, 224]), |
| roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], |
| layout="NCHW", |
| method="nearest_neighbor", |
| coordinate_transformation_mode="half_pixel", |
| rounding_method="round", |
| cubic_alpha=-0.75, |
| cubic_exclude=0, |
| extrapolation_value=0.0, |
| out_dtype="void", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class InterpolateBicubic(Module): |
| def forward(self, input): |
| return torch.nn.functional.interpolate(input, (224, 224), mode="bicubic") |
| |
| @tvm.script.ir_module |
| class expected_bicubic: |
| @R.function |
| def main( |
| input: R.Tensor((1, 3, 112, 112), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 224, 224), dtype="float32") = R.image.resize2d( |
| input, |
| R.shape([224, 224]), |
| roi=[T.float32(0.0), T.float32(0.0), T.float32(0.0), T.float32(0.0)], |
| layout="NCHW", |
| method="cubic", |
| coordinate_transformation_mode="half_pixel", |
| rounding_method="round", |
| cubic_alpha=-0.75, |
| cubic_exclude=0, |
| extrapolation_value=0.0, |
| out_dtype="void", |
| ) |
| gv: R.Tuple(R.Tensor((1, 3, 224, 224), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 112, 112, dtype=torch.float32),) |
| verify_model(InterpolateBilinear(), example_args, {}, expected_bilinear) |
| verify_model(InterpolateNearest(), example_args, {}, expected_nearest) |
| verify_model(InterpolateBicubic(), example_args, {}, expected_bicubic) |
| |
| |
| def test_mean(): |
| class Mean(Module): |
| def forward(self, input): |
| return input.mean(-1) |
| |
| class MeanKeepDim(Module): |
| def forward(self, input: torch.Tensor): |
| return input.mean(-1, keepdim=True) |
| |
| @I.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tuple(R.Tensor((256,), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((256,), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=False) |
| gv: R.Tuple(R.Tensor((256,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @I.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tuple(R.Tensor((256, 1), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((256, 1), dtype="float32") = R.mean(inp_0, axis=[-1], keepdims=True) |
| gv: R.Tuple(R.Tensor((256, 1), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(256, 256, dtype=torch.float32),) |
| verify_model(Mean(), example_args, {}, Expected1) |
| verify_model(MeanKeepDim(), example_args, {}, Expected2) |
| |
| |
| def test_sum(): |
| class Sum(Module): |
| def forward(self, x): |
| return torch.sum(x, (2, 1)) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 4), dtype="float32") = R.sum(inp_0, axis=[2, 1], keepdims=False) |
| gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(Sum(), example_args, {}, expected1) |
| |
| |
| def test_argmax_argmin(): |
| example_args = (torch.randn(256, 256, dtype=torch.float32),) |
| |
| class Argmax1(Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.argmax(input, dim=-1) |
| |
| class Argmax2(Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.argmax(input, dim=-1, keepdim=True) |
| |
| @tvm.script.ir_module |
| class expected_argmax1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tuple(R.Tensor((256,), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((256,), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=False) |
| gv: R.Tuple(R.Tensor((256,), dtype="int64")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_argmax2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tuple(R.Tensor((256, 1), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((256, 1), dtype="int64") = R.argmax(inp_0, axis=-1, keepdims=True) |
| gv: R.Tuple(R.Tensor((256, 1), dtype="int64")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Argmax1(), example_args, {}, expected_argmax1) |
| verify_model(Argmax2(), example_args, {}, expected_argmax2) |
| |
| class Argmin1(Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.argmin(input) |
| |
| class Argmin2(Module): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, input): |
| return torch.argmin(input, keepdim=True) |
| |
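| # With no dim given, argmin reduces over the flattened tensor (axis=None); |
| # keepdim=True then keeps every axis at extent 1, giving shape (1, 1). |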
| @tvm.script.ir_module |
| class expected_argmin1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tuple(R.Tensor((), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="int64")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected_argmin2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 1), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 1), dtype="int64") = R.argmin(inp_0, axis=None, keepdims=True) |
| gv: R.Tuple(R.Tensor((1, 1), dtype="int64")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Argmin1(), example_args, {}, expected_argmin1) |
| verify_model(Argmin2(), example_args, {}, expected_argmin2) |
| |
| |
| def test_cat_concat(): |
| class Cat0(Module): |
| def forward(self, x, y): |
| return torch.cat((x, y)) |
| |
| class Cat1(Module): |
| def forward(self, x, y): |
| return torch.cat((x, y), dim=1) |
| |
| class Cat2(Module): |
| def forward(self, x, y): |
| return torch.cat((x, y), 1) |
| |
| class Cat3(Module): |
| def forward(self, x, y): |
| return torch.concat((x, y), dim=0) |
| |
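| # torch.cat and torch.concat are aliases with default dim=0, so Cat0 and Cat3 |
| # share Expected1 while the dim=1 variants share Expected2. |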
| @I.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((4, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, inp_1), axis=0) |
| gv: R.Tuple(R.Tensor((4, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @I.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 6), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, inp_1), axis=1) |
| gv: R.Tuple(R.Tensor((2, 6), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32)) |
| verify_model(Cat0(), example_args, {}, Expected1) |
| verify_model(Cat1(), example_args, {}, Expected2) |
| verify_model(Cat2(), example_args, {}, Expected2) |
| verify_model(Cat3(), example_args, {}, Expected1) |
| |
| |
| def test_cumsum(): |
| class Cumsum(Module): |
| def forward(self, input): |
| return torch.cumsum(input, dim=1, dtype=torch.int32) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3, 4), dtype="int32") = R.cumsum(input_1, axis=1, dtype="int32") |
| gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="int32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(Cumsum(), example_args, {}, expected1) |
| |
| |
| def test_expand(): |
| class Expand1(Module): |
| def forward(self, x): |
| return x.expand(4, 2, 3, 4) |
| |
| class Expand2(Module): |
| def forward(self, x): |
| return x.expand(4, -1, -1, 4) |
| |
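| # Tensor.expand maps to R.broadcast_to; -1 keeps the existing extent, so both |
| # variants broadcast to the same (4, 2, 3, 4) shape. |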
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((4, 2, 3, 4), dtype="float32") = R.broadcast_to(x, (4, 2, 3, 4)) |
| gv: R.Tuple(R.Tensor((4, 2, 3, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(Expand1(), example_args, {}, expected1) |
| verify_model(Expand2(), example_args, {}, expected1) |
| |
| |
| def test_flatten(): |
| class Flatten(Module): |
| def __init__(self): |
| super().__init__() |
| self.f = torch.nn.Flatten(2, -1) |
| |
| def forward(self, input): |
| return self.f(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 100), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 100), dtype="float32") = R.reshape(input_1, (1, 3, 100)) |
| gv: R.Tuple(R.Tensor((1, 3, 100), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Flatten(), example_args, {}, expected1) |
| |
| |
| def test_meshgrid(): |
| class Meshgrid1(Module): |
| def forward(self, input1, input2): |
| return torch.meshgrid((input1, input2), indexing="ij") |
| |
| class Meshgrid2(Module): |
| def forward(self, input1, input2): |
| return torch.meshgrid((input1, input2), indexing="xy") |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") |
| ) = R.meshgrid((input1, input2), indexing="ij") |
| lv1: R.Tensor((3, 3), dtype="float32") = lv[0] |
| lv2: R.Tensor((3, 3), dtype="float32") = lv[1] |
| gv: R.Tuple( |
| R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") |
| ) = (lv1, lv2) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") |
| ) = R.meshgrid((input1, input2), indexing="xy") |
| lv1: R.Tensor((3, 3), dtype="float32") = lv[0] |
| lv2: R.Tensor((3, 3), dtype="float32") = lv[1] |
| gv: R.Tuple( |
| R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32") |
| ) = (lv1, lv2) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(3, dtype=torch.float32), |
| torch.randn(3, dtype=torch.float32), |
| ) |
| verify_model(Meshgrid1(), example_args, {}, expected1) |
| verify_model(Meshgrid2(), example_args, {}, expected2) |
| |
| |
| def test_permute(): |
| class Permute1(Module): |
| def forward(self, x): |
| return x.permute(0, 3, 2, 1) |
| |
| class Permute2(Module): |
| def forward(self, x): |
| return torch.permute(x, (0, 3, 2, 1)) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) |
| gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(Permute1(), example_args, {}, expected1) |
| verify_model(Permute2(), example_args, {}, expected1) |
| |
| |
| def test_repeat(): |
| class Tile1(Module): |
| def forward(self, x: torch.Tensor): |
| return x.repeat(2) |
| |
| class Tile2(Module): |
| def forward(self, x: torch.Tensor): |
| return x.repeat(4, 2) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main(x: R.Tensor((3,), dtype="float32")) -> R.Tuple(R.Tensor((6,), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2) |
| gv: R.Tuple(R.Tensor((6,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) |
| gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(3, dtype=torch.float32),) |
| verify_model(Tile1(), example_args, {}, expected1) |
| |
| example_args = (torch.randn(1, 3, dtype=torch.float32),) |
| verify_model(Tile2(), example_args, {}, expected2) |
| |
| |
| def test_reshape(): |
| class Reshape(Module): |
| def forward(self, x): |
| return x.reshape(2, 12) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) |
| gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(Reshape(), example_args, {}, expected1) |
| |
| |
| def test_reshape_as(): |
| class ReshapeAs(Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor): |
| return x.reshape_as(y) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32"), |
| y: R.Tensor((2, 12), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) |
| gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(1, 2, 3, 4, dtype=torch.float32), |
| torch.randn(2, 12, dtype=torch.float32), |
| ) |
| verify_model(ReshapeAs(), example_args, {}, expected1) |
| |
| |
| def test_roll(): |
| class Roll1(Module): |
| def forward(self, x): |
| return torch.roll(x, 1) |
| |
| class Roll2(Module): |
| def forward(self, x): |
| return torch.roll(x, -1, 0) |
| |
| class Roll3(Module): |
| def forward(self, x): |
| return torch.roll(x, shifts=(2, 1), dims=(0, 1)) |
| |
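| # Relax has no dedicated roll op; as the expected modules below show, |
| # torch.roll is expressed with strided_slice + concat, plus a |
| # flatten/reshape round trip when no dims are given. |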
| # Test case 1: torch.roll(x, 1) |
| @I.ir_module |
| class Expected1: |
| @R.function |
| def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8])) |
| lv1: R.Tensor((7,), dtype="int64") = R.strided_slice( |
| lv, |
| axes=[0], |
| begin=[R.prim_value(0)], |
| end=[R.prim_value(7)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv2: R.Tensor((1,), dtype="int64") = R.strided_slice( |
| lv, |
| axes=[0], |
| begin=[R.prim_value(7)], |
| end=[R.prim_value(8)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), axis=0) |
| lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, R.shape([4, 2])) |
| gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,) |
| R.output(gv) |
| return gv |
| |
| # Test case 2: torch.roll(x, -1, 0) |
| @I.ir_module |
| class Expected2: |
| @R.function |
| def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice( |
| x, |
| axes=[0], |
| begin=[R.prim_value(0)], |
| end=[R.prim_value(1)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice( |
| x, |
| axes=[0], |
| begin=[R.prim_value(1)], |
| end=[R.prim_value(4)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) |
| gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,) |
| R.output(gv) |
| return gv |
| |
| # Test case 3: torch.roll(x, shifts=(2,1), dims=(0,1)) |
| @I.ir_module |
| class Expected3: |
| @R.function |
| def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 2), dtype="int64")): |
| with R.dataflow(): |
| # First roll along dim=0 with shift=2 |
| lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice( |
| x, |
| axes=[0], |
| begin=[R.prim_value(0)], |
| end=[R.prim_value(2)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice( |
| x, |
| axes=[0], |
| begin=[R.prim_value(2)], |
| end=[R.prim_value(4)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), axis=0) |
| |
| # Second roll along dim=1 with shift=1 |
| lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice( |
| lv2, |
| axes=[1], |
| begin=[R.prim_value(0)], |
| end=[R.prim_value(1)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice( |
| lv2, |
| axes=[1], |
| begin=[R.prim_value(1)], |
| end=[R.prim_value(2)], |
| strides=[R.prim_value(1)], |
| assume_inbound=False, |
| ) |
| lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), axis=1) |
| gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,) |
| R.output(gv) |
| return gv |
| |
| # Test inputs |
| example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64) |
| |
| # Run verification for each case |
| verify_model(Roll1(), (example_input,), {}, Expected1) |
| verify_model(Roll2(), (example_input,), {}, Expected2) |
| verify_model(Roll3(), (example_input,), {}, Expected3) |
| |
| |
| def test_select_slice(): |
| class Slice1(Module): |
| def forward(self, x): |
| return x[0, 1::2, :, :3] |
| |
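| # As expected1/expected2 below show, integer indexing lowers to R.take, |
| # slices to R.strided_slice, and None (newaxis) entries to R.expand_dims. |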
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 10, 3), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((3, 10, 10), dtype="float32") = R.take(x, R.const(0, "int64"), axis=0) |
| lv1: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice( |
| lv, |
| (R.prim_value(0),), |
| (R.prim_value(1),), |
| (R.prim_value(9223372036854775807),), |
| (R.prim_value(2),), |
| assume_inbound=False, |
| ) |
| lv2: R.Tensor((1, 10, 10), dtype="float32") = R.strided_slice( |
| lv1, |
| (R.prim_value(1),), |
| (R.prim_value(0),), |
| (R.prim_value(9223372036854775807),), |
| (R.prim_value(1),), |
| assume_inbound=False, |
| ) |
| lv3: R.Tensor((1, 10, 3), dtype="float32") = R.strided_slice( |
| lv2, |
| (R.prim_value(2),), |
| (R.prim_value(0),), |
| (R.prim_value(3),), |
| (R.prim_value(1),), |
| assume_inbound=False, |
| ) |
| gv: R.Tuple(R.Tensor((1, 10, 3), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| class Slice2(Module): |
| def forward(self, x): |
| return x[:, None, None, :, None] |
| |
| @I.ir_module |
| class expected2: |
| @R.function |
| def main( |
| x: R.Tensor((8, 16), dtype="float32") |
| ) -> R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((8, 16), dtype="float32") = R.strided_slice( |
| x, |
| (R.prim_value(0),), |
| (R.prim_value(0),), |
| (R.prim_value(9223372036854775807),), |
| (R.prim_value(1),), |
| assume_inbound=False, |
| ) |
| lv1: R.Tensor((8, 1, 16), dtype="float32") = R.expand_dims(lv, axis=[1]) |
| lv2: R.Tensor((8, 1, 1, 16), dtype="float32") = R.expand_dims(lv1, axis=[2]) |
| lv3: R.Tensor((8, 1, 1, 16), dtype="float32") = R.strided_slice( |
| lv2, |
| (R.prim_value(3),), |
| (R.prim_value(0),), |
| (R.prim_value(9223372036854775807),), |
| (R.prim_value(1),), |
| assume_inbound=False, |
| ) |
| lv4: R.Tensor((8, 1, 1, 16, 1), dtype="float32") = R.expand_dims(lv3, axis=[4]) |
| gv: R.Tuple(R.Tensor((8, 1, 1, 16, 1), dtype="float32")) = (lv4,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Slice1(), example_args, {}, expected1) |
| |
| example_args = (torch.randn(8, 16, dtype=torch.float32),) |
| verify_model(Slice2(), example_args, {}, expected2) |
| |
| |
| def test_slice_scatter(): |
| class SliceScatter1(Module): |
| def forward(self, input, src): |
| return torch.slice_scatter(input, src, dim=1, start=1, end=7, step=2) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| a: R.Tensor((8, 8, 10, 10), dtype="float32"), |
| b: R.Tensor((8, 3, 10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((8, 8, 10, 10), dtype="float32") = R.slice_scatter( |
| a, b, R.prim_value(1), R.prim_value(7), R.prim_value(2), axis=1 |
| ) |
| gv: R.Tuple(R.Tensor((8, 8, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class SliceScatter2(Module): |
| def forward(self, input, src): |
| return torch.slice_scatter(input, src, dim=0, start=0, end=6, step=1) |
| |
| @I.ir_module |
| class expected2: |
| @R.function |
| def main( |
| a: R.Tensor((8, 16), dtype="float32"), b: R.Tensor((6, 16), dtype="float32") |
| ) -> R.Tuple(R.Tensor((8, 16), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((8, 16), dtype="float32") = R.slice_scatter( |
| a, b, R.prim_value(0), R.prim_value(6), R.prim_value(1), axis=0 |
| ) |
| gv: R.Tuple(R.Tensor((8, 16), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(8, 8, 10, 10, dtype=torch.float32), |
| torch.randn(8, 3, 10, 10, dtype=torch.float32), |
| ) |
| verify_model(SliceScatter1(), example_args, {}, expected1) |
| |
| example_args = ( |
| torch.randn(8, 16, dtype=torch.float32), |
| torch.randn(6, 16, dtype=torch.float32), |
| ) |
| verify_model(SliceScatter2(), example_args, {}, expected2) |
| |
| |
| def test_split(): |
| class Chunk(Module): |
| def forward(self, input): |
| return torch.chunk(input, 3, dim=1) |
| |
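| # torch.chunk maps to a single R.split whose tuple result is unpacked |
| # element by element; torch.unbind additionally squeezes the split axis. |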
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple( |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| ): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| ) = R.split(input_1, indices_or_sections=3, axis=1) |
| lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1] |
| lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2] |
| gv: R.Tuple( |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| R.Tensor((1, 1, 10, 10), dtype="float32"), |
| ) = (lv1, lv2, lv3) |
| R.output(gv) |
| return gv |
| |
| class Unbind1(Module): |
| def forward(self, data): |
| return torch.unbind(data) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((3, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| R.Tensor((1, 3, 10, 10), dtype="float32"), |
| ) = R.split(input_1, indices_or_sections=3, axis=0) |
| lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0]) |
| lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1] |
| lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[0]) |
| lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2] |
| lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[0]) |
| lv7: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv2, lv4, lv6) |
| lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] |
| lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] |
| lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] |
| gv: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv8, lv9, lv10) |
| R.output(gv) |
| return gv |
| |
| class Unbind2(Module): |
| def forward(self, data): |
| return torch.unbind(data, dim=1) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((3, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((3, 1, 10, 10), dtype="float32"), |
| R.Tensor((3, 1, 10, 10), dtype="float32"), |
| R.Tensor((3, 1, 10, 10), dtype="float32"), |
| ) = R.split(input_1, indices_or_sections=3, axis=1) |
| lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1]) |
| lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1] |
| lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, axis=[1]) |
| lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2] |
| lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, axis=[1]) |
| lv7: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv2, lv4, lv6) |
| lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0] |
| lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1] |
| lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2] |
| gv: R.Tuple( |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| R.Tensor((3, 10, 10), dtype="float32"), |
| ) = (lv8, lv9, lv10) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Chunk(), example_args, {}, Expected) |
| |
| example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),) |
| verify_model(Unbind1(), example_args, {}, expected1) |
| verify_model(Unbind2(), example_args, {}, expected2) |
| |
| |
| def test_squeeze(): |
| class Squeeze1(Module): |
| def forward(self, input): |
| return input.squeeze(1) |
| |
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 4, 1), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[1]) |
| gv: R.Tuple(R.Tensor((3, 4, 1), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class Squeeze2(Module): |
| def forward(self, input): |
| return input.squeeze() |
| |
| @tvm.script.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, axis=None) |
| gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),) |
| |
| verify_model(Squeeze1(), example_args, {}, Expected1) |
| verify_model(Squeeze2(), example_args, {}, Expected2) |
| |
| |
| def test_stack(): |
| class Stack0(Module): |
| def forward(self, x, y): |
| return torch.stack((x, y)) # default dim=0 |
| |
| class Stack1(Module): |
| def forward(self, x, y): |
| return torch.stack((x, y), dim=1) |
| |
| class Stack2(Module): |
| def forward(self, x, y): |
| return torch.stack((x, y), 1) # positional dim |
| |
| class Stack3(Module): |
| def forward(self, x, y): |
| return torch.stack((x, y), dim=-1) # negative dim |
| |
| @I.ir_module |
| class Expected0: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=0) |
| gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @I.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 2, 3), dtype="float32") = R.stack((inp_0, inp_1), axis=1) |
| gv: R.Tuple(R.Tensor((2, 2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @I.ir_module |
| class Expected3: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 3, 2), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3, 2), dtype="float32") = R.stack((inp_0, inp_1), axis=-1) |
| gv: R.Tuple(R.Tensor((2, 3, 2), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, dtype=torch.float32)) |
| |
| verify_model(Stack0(), example_args, {}, Expected0) |
| verify_model(Stack1(), example_args, {}, Expected1) |
| verify_model(Stack2(), example_args, {}, Expected1) |
| verify_model(Stack3(), example_args, {}, Expected3) |
| |
| |
| def test_tile(): |
| class Tile1(Module): |
| def forward(self, x): |
| return x.tile((2,)) |
| |
| class Tile2(Module): |
| def forward(self, x): |
| return x.tile(4, 2) |
| |
| class Tile3(Module): |
| def forward(self, x): |
| return torch.tile(x, (4, 2)) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 6), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 6), dtype="float32") = R.tile(x, [2]) |
| gv: R.Tuple(R.Tensor((1, 6), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| x: R.Tensor((1, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((4, 6), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2]) |
| gv: R.Tuple(R.Tensor((4, 6), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, dtype=torch.float32),) |
| verify_model(Tile1(), example_args, {}, expected1) |
| verify_model(Tile2(), example_args, {}, expected2) |
| verify_model(Tile3(), example_args, {}, expected2) |
| |
| |
| def test_transpose(): |
| class Transpose(Module): |
| def forward(self, x): |
| return x.transpose(1, 3) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 3, 2), dtype="float32") = R.permute_dims(x, axes=[0, 3, 2, 1]) |
| gv: R.Tuple(R.Tensor((1, 4, 3, 2), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(Transpose(), example_args, {}, expected1) |
| |
| |
| def test_unsqueeze(): |
| class Unsqueeze1(Module): |
| def forward(self, input): |
| return input.unsqueeze(1) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 1, 3, 10, 10), dtype="float32") = R.expand_dims(input_1, 1) |
| gv: R.Tuple(R.Tensor((1, 1, 3, 10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class Unsqueeze2(Module): |
| def forward(self, input): |
| return input.unsqueeze(-1) |
| |
| @tvm.script.ir_module |
| class expected2: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 10, 10, 1), dtype="float32") = R.expand_dims(input_1, -1) |
| gv: R.Tuple(R.Tensor((1, 3, 10, 10, 1), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| |
| verify_model(Unsqueeze1(), example_args, {}, expected1) |
| verify_model(Unsqueeze2(), example_args, {}, expected2) |
| |
| |
| def test_view(): |
| class View(Module): |
| def forward(self, x): |
| return x.view(2, 12) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) |
| gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(View(), example_args, {}, expected1) |
| |
| |
| def test_arange(): |
| class Arange(Module): |
| def forward(self, input): |
| return torch.arange(0, 20, dtype=torch.int32) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((20,), dtype="int32")): |
| with R.dataflow(): |
| lv: R.Tensor((20,), dtype="int32") = R.arange(0, 20, 1, dtype="int32") |
| gv: R.Tuple(R.Tensor((20,), dtype="int32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| verify_model(Arange(), example_args, {}, Expected) |
| |
| |
| def test_hamming_window(): |
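| # The expected module pins the standard Hamming coefficients |
| # alpha = 0.54 and beta = 0.46. |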
| class HammingWindow(Module): |
| def forward(self, input): |
| return torch.hamming_window(20, True, dtype=torch.float32) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((20,), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((20,), dtype="float32") = R.hamming_window( |
| R.prim_value(20), |
| R.prim_value(1), |
| R.prim_value(T.float32(0.54)), |
| R.prim_value(T.float32(0.46)), |
| dtype="float32", |
| ) |
| gv: R.Tuple(R.Tensor((20,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| verify_model(HammingWindow(), example_args, {}, Expected) |
| |
| |
| def test_contiguous(): |
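| # contiguous() is a layout hint with no graph-level effect, so the |
| # imported module simply returns its input. |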
| class Contiguous(Module): |
| def forward(self, input): |
| return input.contiguous() |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((10, 10), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| verify_model(Contiguous(), example_args, {}, Expected) |
| |
| |
| def test_clone(): |
| class Clone(Module): |
| def forward(self, input): |
| return torch.clone(input) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (input,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| verify_model(Clone(), example_args, {}, Expected) |
| |
| |
| def test_empty(): |
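| # torch.empty leaves its contents unspecified; the expected module |
| # materializes zeros, and only the structure is compared. |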
| class Empty(Module): |
| def forward(self, input): |
| return torch.empty((10, 10), dtype=torch.float32) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.zeros( |
| R.shape([10, 10]), dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| verify_model(Empty(), example_args, {}, Expected) |
| |
| |
| def test_fill(): |
| class Fill(Module): |
| def forward(self, input: torch.Tensor): |
| return torch.fill(input, 1.5) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((10, 10), dtype="float32") |
| ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((10, 10), dtype="float32") = R.full( |
| R.shape([10, 10]), R.const(1.5, "float32"), dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(10, 10, dtype=torch.float32),) |
| verify_model(Fill(), example_args, {}, Expected) |
| |
| |
| def test_fill_inplace(): |
| class FillInplace(Module): |
| def forward(self, input: torch.Tensor): |
| input.fill_(42.0) |
| return input |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3), dtype="float32") = R.full( |
| R.shape([2, 3]), R.const(42.0, "float32"), dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 3, dtype=torch.float32),) |
| verify_model(FillInplace(), example_args, {}, Expected) |
| |
| |
| def test_masked_fill(): |
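| # masked_fill lowers to full_like (broadcasting the fill value to the |
| # input shape) followed by where(mask, fill, input). |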
| class MaskedFill(Module): |
| def forward(self, input: torch.Tensor, mask: torch.Tensor): |
| return torch.masked_fill(input, mask, 0) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool") |
| ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((128, 128), dtype="float32") = R.full_like( |
| input, R.const(0, "int32"), dtype="void" |
| ) |
| lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input) |
| gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5) |
| verify_model(MaskedFill(), example_args, {}, Expected) |
| |
| |
| def test_masked_fill_inplace(): |
| class MaskedFillInplace(Module): |
| def forward(self, input: torch.Tensor, mask: torch.Tensor): |
| return input.masked_fill_(mask, 1.5) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32"), mask: R.Tensor((128, 128), dtype="bool") |
| ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((128, 128), dtype="float32") = R.full_like( |
| input, R.const(1.5, "float32"), dtype="void" |
| ) |
| lv1: R.Tensor((128, 128), dtype="float32") = R.where(mask, lv, input) |
| gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(128, 128, dtype=torch.float32), torch.rand(128, 128) < 0.5) |
| verify_model(MaskedFillInplace(), example_args, {}, Expected) |
| |
| |
| def test_new_ones(): |
| class NewOnes(Module): |
| def forward(self, x): |
| return x.new_ones(1, 2, 3) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3), dtype="float32") = R.full( |
| (1, 2, 3), R.const(1, "float32"), dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((1, 2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, dtype=torch.float32),) |
| verify_model(NewOnes(), example_args, {}, expected1) |
| |
| |
| def test_new_zeros(): |
| class NewZeros(torch.nn.Module): |
| def forward(self, x): |
| return x.new_zeros(1, 128, 128) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| x: R.Tensor((1, 128, 128), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 128, 128), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 128, 128), dtype="float32") = R.full( |
| R.shape([1, 128, 128]), R.const(0, "float32"), dtype="float32" |
| ) |
| gv: R.Tuple(R.Tensor((1, 128, 128), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 128, 128, dtype=torch.float32),) |
| verify_model(NewZeros(), example_args, {}, expected1) |
| |
| |
| def test_to_copy(): |
| # float |
| class ToFloat(Module): |
| def forward(self, x): |
| return x.float() |
| |
| @tvm.script.ir_module |
| class expected_float: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") |
| gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # half |
| class ToHalf(Module): |
| def forward(self, x): |
| return x.half() |
| |
| @tvm.script.ir_module |
| class expected_half: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(x, dtype="float16") |
| gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # type |
| class Type(Module): |
| def forward(self, x): |
| return x.type(torch.float32) |
| |
| @tvm.script.ir_module |
| class expected_type: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): |
| # block 0 |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(x, dtype="float32") |
| gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class To1(Module): |
| def forward(self, input): |
| return input.to(torch.float16) |
| |
| @I.ir_module |
| class expected_to1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3, 4), dtype="float16") = R.astype(inp_0, dtype="float16") |
| gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float16")) = (lv,) |
| R.output(gv) |
| return gv |
| |
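| # A device-only .to("cpu") keeps the dtype, so it still lowers to an |
| # astype with the original dtype. |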
| class To2(Module): |
| def forward(self, input): |
| return input.to("cpu") |
| |
| @I.ir_module |
| class expected_to2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 2, 3, 4), dtype="float32") |
| ) -> R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3, 4), dtype="float32") = R.astype(inp_0, dtype="float32") |
| gv: R.Tuple(R.Tensor((1, 2, 3, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) |
| verify_model(ToFloat(), example_args, {}, expected_float) |
| verify_model(ToHalf(), example_args, {}, expected_half) |
| verify_model(Type(), example_args, {}, expected_type) |
| verify_model(To1(), example_args, {}, expected_to1) |
| verify_model(To2(), example_args, {}, expected_to2) |
| |
| |
| def test_keep_params(): |
| class Conv2D1(Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) |
| |
| def forward(self, input): |
| return self.conv(input) |
| |
| @tvm.script.ir_module |
| class expected1: |
| @R.function |
| def main( |
| input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), |
| conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"), |
| conv_bias: R.Tensor((6,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): |
| R.func_attr({"num_input": 1}) |
| # block 0 |
| with R.dataflow(): |
| lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( |
| input_1, |
| conv_weight, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| data_layout="NCHW", |
| kernel_layout="OIHW", |
| out_layout="NCHW", |
| out_dtype="float32", |
| ) |
| lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1]) |
| lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) |
| gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| from tvm.relax.frontend import detach_params |
| |
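| # With keep_params_as_input=True, the conv weight and bias become extra |
| # function parameters (the "num_input" attribute marks how many leading |
| # params are real inputs); detach_params then splits them back out as |
| # concrete tensors. |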
| example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) |
| model = Conv2D1() |
| exported_program = torch.export.export(model, example_args) |
| mod = from_exported_program(exported_program, keep_params_as_input=True) |
| mod, params = detach_params(mod) |
| tvm.ir.assert_structural_equal(mod, expected1) |
| func = mod["main"] |
| params = params["main"] |
| |
| assert len(params) == len(func.params) - 1 |
| for param_var, param_tensor in zip(func.params[1:], params): |
| assert tuple(x.value for x in param_var.struct_info.shape.values) == param_tensor.shape |
| assert param_var.struct_info.dtype == param_tensor.dtype |
| |
| tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().numpy()) |
| tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().numpy()) |
| |
| |
| def test_unwrap_unit_return_tuple(): |
| class Identity(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x): |
| return (x,) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32") |
| ) -> R.Tensor((256, 256), dtype="float32"): |
| with R.dataflow(): |
| gv: R.Tensor((256, 256), dtype="float32") = inp_0 |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(256, 256, dtype=torch.float32),) |
| exported_program = export(Identity(), args=example_args) |
| mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True) |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_no_bind_return_tuple(): |
| class Identity(Module): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x, y): |
| return (x, y) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((256, 256), dtype="float32"), |
| inp_1: R.Tensor((256, 256), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tensor((256, 256), dtype="float32") = inp_0 |
| gv1: R.Tensor((256, 256), dtype="float32") = inp_1 |
| R.output(gv, gv1) |
| return (gv, gv1) |
| |
| example_args = ( |
| torch.randn(256, 256, dtype=torch.float32), |
| torch.randn(256, 256, dtype=torch.float32), |
| ) |
| exported_program = export(Identity(), args=example_args) |
| mod = from_exported_program(exported_program, no_bind_return_tuple=True) |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_empty_like(): |
| class EmptyLike(Module): |
| def forward(self, data): |
| return torch.empty_like(data) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((5,), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0, dtype="void") |
| gv: R.Tuple(R.Tensor((5,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(5, dtype=torch.float32),) |
| |
| verify_model(EmptyLike(), example_args, {}, Expected) |
| |
| |
| def test_one_hot(): |
| class OneHot(Module): |
| def forward(self, indices): |
| return torch.nn.functional.one_hot(indices, num_classes=10) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5,), dtype="int64"), |
| ) -> R.Tuple(R.Tensor((5, 10), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 10), dtype="int64") = R.one_hot( |
| inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1 |
| ) |
| gv: R.Tuple(R.Tensor((5, 10), dtype="int64")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randint(0, 10, (5,), dtype=torch.int64),) |
| |
| verify_model(OneHot(), example_args, {}, Expected) |
| |
| |
| def test_ones_like(): |
| class OnesLike(Module): |
| def forward(self, input): |
| return torch.ones_like(input) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32") |
| ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((128, 128), dtype="float32") = R.ones_like(input, dtype="void") |
| gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.rand(128, 128, dtype=torch.float32),) |
| |
| verify_model(OnesLike(), example_args, {}, Expected) |
| |
| |
| def test_zero_inplace(): |
| class ZeroInplace(Module): |
| def forward(self, input): |
| return input.zero_() |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32") |
| ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") |
| gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.rand(128, 128, dtype=torch.float32),) |
| |
| verify_model(ZeroInplace(), example_args, {}, Expected) |
| |
| |
| def test_zeros(): |
| class Zeros(Module): |
| def forward(self, input): |
| return torch.zeros(5, 2) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32") |
| ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 2), dtype="float32") = R.zeros(R.shape([5, 2]), dtype="float32") |
| gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.rand(128, 128, dtype=torch.float32),) |
| |
| verify_model(Zeros(), example_args, {}, Expected) |
| |
| |
| def test_zeros_like(): |
| class ZerosLike(Module): |
| def forward(self, input): |
| return torch.zeros_like(input) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32") |
| ) -> R.Tuple(R.Tensor((128, 128), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((128, 128), dtype="float32") = R.zeros_like(input, dtype="void") |
| gv: R.Tuple(R.Tensor((128, 128), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.rand(128, 128, dtype=torch.float32),) |
| verify_model(ZerosLike(), example_args, {}, Expected) |
| |
| |
| def test_type_as(): |
| class TypeAs(Module): |
| def forward(self, input, other): |
| return input.type_as(other) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((128, 128), dtype="float32"), |
| other: R.Tensor((128, 128), dtype="float16"), |
| ) -> R.Tuple(R.Tensor((128, 128), dtype="float16")): |
| with R.dataflow(): |
| lv: R.Tensor((128, 128), dtype="float16") = R.astype(input, dtype="float16") |
| gv: R.Tuple(R.Tensor((128, 128), dtype="float16")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.rand(128, 128, dtype=torch.float32), |
| torch.rand(128, 128, dtype=torch.float16), |
| ) |
| |
| verify_model(TypeAs(), example_args, {}, Expected) |
| |
| |
| def test_select(): |
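| # torch.select(input, dim, index) lowers to R.take with a constant index |
| # along the given axis. |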
| class Select(Module): |
| def forward(self, input): |
| return torch.select(input, 0, 1) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((3,), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3,), dtype="float32") = R.take(inp_0, R.const(1, "int64"), axis=0) |
| gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 3, dtype=torch.float32),) |
| |
| verify_model(Select(), example_args, {}, Expected) |
| |
| |
| def test_unflatten(): |
| class Unflatten(Module): |
| def forward(self, input): |
| return torch.ops.aten.unflatten(input, 1, (3, 5)) |
| |
| class Unflatten1(Module): |
| def forward(self, input): |
| return torch.ops.aten.unflatten(input, -2, (3, 5)) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 15, 7), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3, 5, 7), dtype="float32") = R.reshape(inp_0, [2, 3, 5, 7]) |
| gv: R.Tuple(R.Tensor((2, 3, 5, 7), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 15, 7, dtype=torch.float32),) |
| |
| verify_model(Unflatten(), example_args, {}, Expected) |
| verify_model(Unflatten1(), example_args, {}, Expected) |
| |
| |
| def test_gather(): |
| class Gather0(Module): |
| def forward(self, data, indices): |
| return torch.gather(data, 0, indices) |
| |
| class Gather1(Module): |
| def forward(self, data, indices): |
| return torch.gather(data, 1, indices) |
| |
| class Gather2(Module): |
| def forward(self, data, indices): |
| return torch.gather(data, -1, indices) |
| |
| class Gather3(Module): |
| def forward(self, data, indices): |
| return torch.gather(data, -2, indices) |
| |
| @tvm.script.ir_module |
| class Expected0: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="int64"), |
| ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=0) |
| gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="int64"), |
| ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=1) |
| gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="int64"), |
| ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-1) |
| gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected3: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 3), dtype="float32"), |
| inp_1: R.Tensor((2, 3), dtype="int64"), |
| ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 3), dtype="float32") = R.gather_elements(inp_0, inp_1, axis=-2) |
| gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(2, 3, dtype=torch.float32), |
| torch.randint(0, 3, (2, 3), dtype=torch.int64), |
| ) |
| |
| verify_model(Gather0(), example_args, {}, Expected0) |
| verify_model(Gather1(), example_args, {}, Expected1) |
| verify_model(Gather2(), example_args, {}, Expected2) |
| verify_model(Gather3(), example_args, {}, Expected3) |
| |
| |
| def test_index_put(): |
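| # The in-place index_put_ maps to the functional R.index_put, with the |
| # Python tuple of index tensors carried through as an R.tuple argument. |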
| # Test case 1: 1D input |
| class IndexPut1D(Module): |
| def forward(self, data, indices_0, values): |
| indices_tuple = (indices_0,) |
| return data.index_put_(indices_tuple, values, accumulate=False) |
| |
| example_args_1d = ( |
| torch.randn(64, dtype=torch.float32), |
| torch.randint(0, 64, (128,), dtype=torch.int64), |
| torch.randn(128, dtype=torch.float32), |
| ) |
| |
| @I.ir_module |
| class Expected1D: |
| @R.function |
| def main( |
| data: R.Tensor((64,), dtype="float32"), |
| indices_0: R.Tensor((128,), dtype="int64"), |
| values: R.Tensor((128,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((64,), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((64,), dtype="float32") = R.index_put( |
| data, R.tuple(indices_0), values, accumulate=False |
| ) |
| gv: R.Tuple(R.Tensor((64,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Test case 2: 2D input |
| class IndexPut2D(Module): |
| def forward(self, data, indices_0, indices_1, values): |
| indices_tuple = (indices_0, indices_1) |
| return data.index_put_(indices_tuple, values, accumulate=False) |
| |
| example_args_2d = ( |
| torch.randn(32, 64, dtype=torch.float32), |
| torch.randint(0, 32, (128,), dtype=torch.int64), |
| torch.randint(0, 64, (128,), dtype=torch.int64), |
| torch.randn(128, dtype=torch.float32), |
| ) |
| |
| @I.ir_module |
| class Expected2D: |
| @R.function |
| def main( |
| data: R.Tensor((32, 64), dtype="float32"), |
| indices_0: R.Tensor((128,), dtype="int64"), |
| indices_1: R.Tensor((128,), dtype="int64"), |
| values: R.Tensor((128,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((32, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((32, 64), dtype="float32") = R.index_put( |
| data, R.tuple(indices_0, indices_1), values, accumulate=False |
| ) |
| gv: R.Tuple(R.Tensor((32, 64), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Test case 3: 3D input |
| class IndexPut3D(Module): |
| def forward(self, data, indices_0, indices_1, indices_2, values): |
| indices_tuple = (indices_0, indices_1, indices_2) |
| return data.index_put_(indices_tuple, values, accumulate=False) |
| |
| example_args_3d = ( |
| torch.randn(16, 32, 64, dtype=torch.float32), |
| torch.randint(0, 16, (128,), dtype=torch.int64), |
| torch.randint(0, 32, (128,), dtype=torch.int64), |
| torch.randint(0, 64, (128,), dtype=torch.int64), |
| torch.randn(128, dtype=torch.float32), |
| ) |
| |
| @I.ir_module |
| class Expected3D: |
| @R.function |
| def main( |
| data: R.Tensor((16, 32, 64), dtype="float32"), |
| indices_0: R.Tensor((128,), dtype="int64"), |
| indices_1: R.Tensor((128,), dtype="int64"), |
| indices_2: R.Tensor((128,), dtype="int64"), |
| values: R.Tensor((128,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((16, 32, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((16, 32, 64), dtype="float32") = R.index_put( |
| data, R.tuple(indices_0, indices_1, indices_2), values, accumulate=False |
| ) |
| gv: R.Tuple(R.Tensor((16, 32, 64), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Test case 4: 4D input |
| class IndexPut4D(Module): |
| def forward(self, data, indices_0, indices_1, indices_2, indices_3, values): |
| indices_tuple = (indices_0, indices_1, indices_2, indices_3) |
| return data.index_put_(indices_tuple, values, accumulate=False) |
| |
| example_args_4d = ( |
| torch.randn(8, 16, 32, 64, dtype=torch.float32), |
| torch.randint(0, 8, (128,), dtype=torch.int64), |
| torch.randint(0, 16, (128,), dtype=torch.int64), |
| torch.randint(0, 32, (128,), dtype=torch.int64), |
| torch.randint(0, 64, (128,), dtype=torch.int64), |
| torch.randn(128, dtype=torch.float32), |
| ) |
| |
| @I.ir_module |
| class Expected4D: |
| @R.function |
| def main( |
| data: R.Tensor((8, 16, 32, 64), dtype="float32"), |
| indices_0: R.Tensor((128,), dtype="int64"), |
| indices_1: R.Tensor((128,), dtype="int64"), |
| indices_2: R.Tensor((128,), dtype="int64"), |
| indices_3: R.Tensor((128,), dtype="int64"), |
| values: R.Tensor((128,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((8, 16, 32, 64), dtype="float32") = R.index_put( |
| data, |
| R.tuple(indices_0, indices_1, indices_2, indices_3), |
| values, |
| accumulate=False, |
| ) |
| gv: R.Tuple(R.Tensor((8, 16, 32, 64), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Test case 5: 5D input |
| class IndexPut5D(Module): |
| def forward(self, data, indices_0, indices_1, indices_2, indices_3, indices_4, values): |
| indices_tuple = (indices_0, indices_1, indices_2, indices_3, indices_4) |
| return data.index_put_(indices_tuple, values, accumulate=False) |
| |
| example_args_5d = ( |
| torch.randn(4, 8, 16, 32, 64, dtype=torch.float32), |
| torch.randint(0, 4, (128,), dtype=torch.int64), |
| torch.randint(0, 8, (128,), dtype=torch.int64), |
| torch.randint(0, 16, (128,), dtype=torch.int64), |
| torch.randint(0, 32, (128,), dtype=torch.int64), |
| torch.randint(0, 64, (128,), dtype=torch.int64), |
| torch.randn(128, dtype=torch.float32), |
| ) |
| |
| @I.ir_module |
| class Expected5D: |
| @R.function |
| def main( |
| data: R.Tensor((4, 8, 16, 32, 64), dtype="float32"), |
| indices_0: R.Tensor((128,), dtype="int64"), |
| indices_1: R.Tensor((128,), dtype="int64"), |
| indices_2: R.Tensor((128,), dtype="int64"), |
| indices_3: R.Tensor((128,), dtype="int64"), |
| indices_4: R.Tensor((128,), dtype="int64"), |
| values: R.Tensor((128,), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((4, 8, 16, 32, 64), dtype="float32") = R.index_put( |
| data, |
| R.tuple(indices_0, indices_1, indices_2, indices_3, indices_4), |
| values, |
| accumulate=False, |
| ) |
| gv: R.Tuple(R.Tensor((4, 8, 16, 32, 64), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| # Run verification for each case |
| verify_model(IndexPut1D(), example_args_1d, {}, Expected1D) |
| verify_model(IndexPut2D(), example_args_2d, {}, Expected2D) |
| verify_model(IndexPut3D(), example_args_3d, {}, Expected3D) |
| verify_model(IndexPut4D(), example_args_4d, {}, Expected4D) |
| verify_model(IndexPut5D(), example_args_5d, {}, Expected5D) |
| |
| |
| def test_flip(): |
| class Flip0(Module): |
| def forward(self, data): |
| return torch.flip(data, [0]) |
| |
| class Flip1(Module): |
| def forward(self, data): |
| return torch.flip(data, [1]) |
| |
| @tvm.script.ir_module |
| class Expected0: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=0) |
| gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.flip(inp_0, axis=1) |
| gv: R.Tuple(R.Tensor((2, 2), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 2, dtype=torch.float32),) |
| |
| verify_model(Flip0(), example_args, {}, Expected0) |
| verify_model(Flip1(), example_args, {}, Expected1) |
| |
| |
| def test_take(): |
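| # torch.take indexes the flattened input (axis=None); the int64 indices |
| # are cast to int32 before R.take, as the expected module shows. |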
| class Take(Module): |
| def forward(self, data, indices): |
| return torch.take(data, indices) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5,), dtype="float32"), |
| inp_1: R.Tensor((3,), dtype="int64"), |
| ) -> R.Tuple(R.Tensor((3,), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3,), dtype="int32") = R.astype(inp_1, dtype="int32") |
| lv1: R.Tensor((3,), dtype="float32") = R.take(inp_0, lv, axis=None) |
| gv: R.Tuple(R.Tensor((3,), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args = ( |
| torch.randn(5, dtype=torch.float32), |
| torch.randint(0, 5, (3,), dtype=torch.int64), |
| ) |
| |
| verify_model(Take(), example_args, {}, Expected) |
| |
| |
| def test_std(): |
| class Std(Module): |
| def forward(self, x): |
| return torch.std(x) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.std(inp_0, axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(5, 3, dtype=torch.float32),) |
| verify_model(Std(), example_args, {}, Expected) |
| |
| |
| def test_var(): |
| class Var(Module): |
| def forward(self, x): |
| return torch.var(x) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.variance(inp_0, axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(5, 3, dtype=torch.float32),) |
| verify_model(Var(), example_args, {}, Expected) |
| |
| |
| def test_prod(): |
| class Prod(Module): |
| def forward(self, x): |
| return torch.prod(x) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.prod(inp_0, axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(5, 3, dtype=torch.float32),) |
| verify_model(Prod(), example_args, {}, Expected) |
| |
| |
| def test_cumprod(): |
| class Cumprod(Module): |
| def forward(self, x): |
| return torch.cumprod(x, 0) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 3), dtype="float32") = R.cumprod(inp_0, axis=0, exclusive=False) |
| gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_input = torch.randn(5, 3, dtype=torch.float32) |
| verify_model(Cumprod(), (example_input,), {}, Expected) |
| |
| |
| def test_where(): |
| class Where(Module): |
| def forward(self, condition, x, y): |
| return torch.where(condition, x, y) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| inp_0: R.Tensor((5, 3), dtype="bool"), |
| inp_1: R.Tensor((5, 3), dtype="float32"), |
| inp_2: R.Tensor((5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 3), dtype="float32") = R.where(inp_0, inp_1, inp_2) |
| gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| condition = torch.randint(0, 2, (5, 3), dtype=torch.bool) |
| x = torch.randn(5, 3, dtype=torch.float32) |
| y = torch.randn(5, 3, dtype=torch.float32) |
| |
| verify_model(Where(), (condition, x, y), {}, Expected) |
| |
| |
| def test_bucketize(): |
| class Bucketize(Module): |
| def forward(self, input_tensor, boundaries): |
| return torch.bucketize(input_tensor, boundaries) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((20,), dtype="int64"), boundaries: R.Tensor((10,), dtype="int64") |
| ) -> R.Tuple(R.Tensor((20,), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tensor((20,), dtype="int64") = R.bucketize( |
| input, boundaries, out_int32=False, right=False |
| ) |
| gv: R.Tuple(R.Tensor((20,), dtype="int64")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| input_tensor = torch.arange(0, 20) |
| boundaries = torch.arange(0, 20, 2) |
| |
| verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected) |
| |
| |
| def test_argsort(): |
| class Argsort(Module): |
| def forward(self, x): |
| return torch.argsort(x, dim=1, descending=True) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype="int32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 3), dtype="int32") = R.argsort( |
| x, axis=1, descending=True, dtype="int32" |
| ) |
| gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(5, 3, dtype=torch.float32),) |
| verify_model(Argsort(), example_args, {}, Expected) |
| |
| |
| def test_topk(): |
| class Topk(Module): |
| def forward(self, x): |
| return torch.topk(x, k=2, dim=1, largest=True, sorted=True) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((5, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")): |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64") |
| ) = R.topk(x, k=2, axis=1, ret_type="both", largest=True, dtype="int64") |
| lv1: R.Tensor((5, 2), dtype="float32") = lv[0] |
| lv2: R.Tensor((5, 2), dtype="int64") = lv[1] |
| gv: R.Tuple(R.Tensor((5, 2), dtype="float32"), R.Tensor((5, 2), dtype="int64")) = ( |
| lv1, |
| lv2, |
| ) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(5, 3, dtype=torch.float32),) |
| verify_model(Topk(), example_args, {}, Expected) |
| |
| |
| def test_dynamic_shape(): |
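| # A torch.export.Dim in dynamic_shapes surfaces as a symbolic |
| # tvm.tir.SizeVar ("BatchSize") in the Relax signature. |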
| class DynamicModel(torch.nn.Module): |
| def forward(self, x1, x2): |
| return torch.ops.aten.add.Tensor(x1, x2) |
| |
| B = tvm.tir.SizeVar("BatchSize", dtype="int64") |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| lhs: R.Tensor((B, 4), dtype="float32"), |
| rhs: R.Tensor((B, 4), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs) |
| gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(2, 4), torch.randn(2, 4)) |
| batch = torch.export.Dim("batch") |
| dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}} |
| |
| verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes) |
| |
| |
| def test_broadcast_to(): |
| class BroadcastTo(Module): |
| def forward(self, x): |
| return torch.broadcast_to(x, (5, 3)) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((5, 1), dtype="float32") |
| ) -> R.Tuple(R.Tensor((5, 3), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(x, R.shape([5, 3])) |
| gv: R.Tuple(R.Tensor((5, 3), dtype="float32")) = (lv,) |
|                 R.output(gv) |
|             return gv |
| |
| example_args = (torch.randn(5, 1, dtype=torch.float32),) |
| verify_model(BroadcastTo(), example_args, {}, Expected) |
| |
| |
| def test_narrow(): |
| class Narrow(Module): |
| def forward(self, x): |
| return torch.narrow(x, 1, 0, 2) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((5, 3), dtype="float32") |
| ) -> R.Tuple(R.Tensor((5, 2), dtype="float32")): |
| with R.dataflow(): |
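|                 # narrow(x, 1, 0, 2) lowers to a strided slice over [0, 2) along axis 1. |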
| lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice( |
| x, |
| (R.prim_value(1),), |
| (R.prim_value(0),), |
| (R.prim_value(2),), |
| assume_inbound=False, |
| ) |
| gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) |
|                 R.output(gv) |
|             return gv |
| |
| example_args = (torch.randn(5, 3, dtype=torch.float32),) |
| verify_model(Narrow(), example_args, {}, Expected) |
| |
| |
| def test_item(): |
| class Item(Module): |
| def forward(self, x): |
| return x.item() |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
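|                 # x.item() on a one-element tensor lowers to a scalar R.take at index 0. |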
| lv: R.Tensor((), dtype="float32") = R.take(input, R.const(0, "int64"), axis=0) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(1, dtype=torch.float32),) |
| verify_model(Item(), example_args, {}, Expected) |
| |
| |
| def test_norm(): |
| class Norm(Module): |
| def __init__(self, p, dim=None, keepdim=False): |
| super().__init__() |
| self.p = p |
| self.dim = dim |
| self.keepdim = keepdim |
| |
| def forward(self, x): |
| return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim) |
| |
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected3: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) |
| lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32")) |
| lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False) |
| lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32")) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected4: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) |
| lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32")) |
| lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False) |
| lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32")) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected5: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) |
| lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4.0, "float32")) |
| lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True) |
| lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power( |
| lv2, R.const(-0.25, "float32") |
| ) |
| gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
| @tvm.script.ir_module |
| class Expected6: |
| @R.function |
| def main( |
| inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0) |
| lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32")) |
| lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.sum(lv1, axis=None, keepdims=True) |
| lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.power(lv2, R.const(2.0, "float32")) |
| gv: R.Tuple(R.Tensor((1, 1, 1, 1), dtype="float32")) = (lv3,) |
| R.output(gv) |
| return gv |
| |
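|     # Each case is ((p, dim, keepdim), expected): p=inf -> max(|x|), p=-inf -> min(|x|), |
|     # finite p -> sum(|x| ** p) ** (1 / p); keepdim=True keeps the reduced axes as size 1. |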
|     norms = [ |
|         ((float("inf"), None, False), Expected1), |
|         ((float("-inf"), None, False), Expected2), |
|         ((2.0, None, False), Expected3), |
|         ((1.0, None, False), Expected4), |
|         ((-4.0, None, True), Expected5), |
|         ((0.5, None, True), Expected6), |
|     ] |
| |
| example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),) |
| |
| for (p, dim, keepdim), expected in norms: |
| verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected) |
| |
| |
| def test_eye(): |
| class Eye1(Module): |
| def forward(self, input): |
| return torch.eye(3, 5, dtype=torch.float32) |
| |
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main( |
| input: R.Tensor((3, 5), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32") |
| gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| class Eye2(Module): |
| def forward(self, input): |
| return torch.eye(5, dtype=torch.float32) |
| |
| @tvm.script.ir_module |
| class Expected2: |
| @R.function |
| def main( |
| input: R.Tensor((5,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): |
| with R.dataflow(): |
| lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32") |
| gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args1 = (torch.randn(3, 5, dtype=torch.float32),) |
| verify_model(Eye1(), example_args1, {}, Expected1) |
| |
| example_args2 = (torch.randn(5, dtype=torch.float32),) |
| verify_model(Eye2(), example_args2, {}, Expected2) |
| |
| |
| def test_cross_entropy(): |
| class CrossEntropyModule(Module): |
| def __init__(self): |
| super().__init__() |
| self.criterion = nn.CrossEntropyLoss() |
| self.target = torch.tensor([0, 1, 2, 1]) |
| |
| def forward(self, x): |
| return self.criterion(x, self.target) |
| |
| @tvm.script.ir_module |
| class Expected1: |
| @R.function |
| def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): |
| with R.dataflow(): |
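|                 # nn.CrossEntropyLoss decomposes into log_softmax + nll_loss with a constant target. |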
| lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1) |
| lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss( |
| lv, |
| targets=R.const([0, 1, 2, 1], dtype="int64"), |
| reduction="mean", |
| ignore_index=-100, |
| ) |
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) |
| R.output(gv) |
| return gv |
| |
| example_args1 = (torch.randn(4, 3, dtype=torch.float32),) |
| verify_model(CrossEntropyModule(), example_args1, {}, Expected1) |
| |
| |
| def test_linspace(): |
| class Linspace(Module): |
| def forward(self, input): |
| return torch.linspace(0, 1, steps=9, dtype=torch.float32) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((9, 9), dtype="float32") |
| ) -> R.Tuple(R.Tensor((9,), dtype="float32")): |
| with R.dataflow(): |
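|                 # linspace(0, 1, steps=9) lowers to arange with step (1 - 0) / (9 - 1) = 0.125; |
|                 # the end of 1.0625 lies past 1.0, so the final stop value is included. |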
| lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") |
| gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| example_args = (torch.randn(9, 9, dtype=torch.float32),) |
| verify_model(Linspace(), example_args, {}, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "torch_dtype, relax_dtype", |
| [ |
| (torch.float32, "float32"), |
| (torch.float16, "float16"), |
| (torch.bfloat16, "bfloat16"), |
| (torch.int64, "int64"), |
| (torch.int32, "int32"), |
| (torch.bool, "bool"), |
| ], |
| ) |
| def test_dtypes(torch_dtype, relax_dtype): |
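|     # Elementwise add should carry each torch dtype through to the matching Relax dtype. |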
| example_args = ( |
| torch.randint(0, 10, (10, 10)).to(torch_dtype), |
| torch.randint(0, 10, (10, 10)).to(torch_dtype), |
| ) |
| |
| class Model(Module): |
| def forward(self, lhs: torch.Tensor, rhs: torch.Tensor): |
| return torch.ops.aten.add(lhs, rhs) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| lhs: R.Tensor((10, 10), dtype=relax_dtype), |
| rhs: R.Tensor((10, 10), dtype=relax_dtype), |
| ) -> R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)): |
| with R.dataflow(): |
|                 lv: R.Tensor((10, 10), dtype=relax_dtype) = R.add(lhs, rhs) |
| gv: R.Tuple(R.Tensor((10, 10), dtype=relax_dtype)) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(Model(), example_args, {}, Expected) |
| |
| |
| def test_mm(): |
| class MatrixMultiply(Module): |
| def forward(self, a, b): |
| return torch.mm(a, b) |
| |
| example_args = ( |
| torch.randn(2, 3, dtype=torch.float32), |
| torch.randn(3, 4, dtype=torch.float32), |
| ) |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| a: R.Tensor((2, 3), dtype="float32"), |
| b: R.Tensor((3, 4), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): |
| with R.dataflow(): |
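|                 # torch.mm lowers to R.matmul with an explicit float32 out_dtype. |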
| lv: R.Tensor((2, 4), dtype="float32") = R.matmul(a, b, out_dtype="float32") |
| gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (lv,) |
| R.output(gv) |
| return gv |
| |
| verify_model(MatrixMultiply(), example_args, {}, Expected) |
| |
| |
| def test_lstm(): |
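|     # Unlike the structural tests above, LSTM is verified numerically: the imported module |
|     # is built and run on the Relax VM and its output compared against PyTorch's. |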
| class BasicLSTM(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lstm = nn.LSTM( |
| input_size=4, |
| hidden_size=8, |
| num_layers=1, |
| batch_first=True, |
| bidirectional=False, |
| ) |
| |
| def forward(self, x): |
| y, _ = self.lstm(x) |
| return y |
| |
| torch.manual_seed(42) |
| x = torch.randn(2, 3, 4, dtype=torch.float32) |
| model = BasicLSTM() |
| with torch.no_grad(): |
| pytorch_output = model(x) |
| exported_program = export(model, args=(x,)) |
| mod = from_exported_program(exported_program) |
| target = tvm.target.Target("llvm") |
| ex = relax.build(mod, target) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x_tvm = tvm.runtime.tensor(x.numpy()) |
| tvm_output = vm["main"](x_tvm) |
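|     # The VM may return the tensor directly or wrapped in a single-element tuple. |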
| if hasattr(tvm_output, "numpy"): |
| tvm_output_np = tvm_output.numpy() |
| else: |
| tvm_output_np = tvm_output[0].numpy() |
| assert ( |
| pytorch_output.shape == tvm_output_np.shape |
| ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}" |
| np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5) |
| |
| class SeqFirstLSTM(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lstm = nn.LSTM( |
| input_size=3, |
| hidden_size=6, |
| num_layers=1, |
| batch_first=False, |
| bidirectional=False, |
| ) |
| |
| def forward(self, x): |
| y, _ = self.lstm(x) |
| return y |
| |
| torch.manual_seed(43) |
| x2 = torch.randn(4, 2, 3, dtype=torch.float32) |
| model2 = SeqFirstLSTM() |
| with torch.no_grad(): |
| pytorch_output2 = model2(x2) |
| exported_program2 = export(model2, args=(x2,)) |
| mod2 = from_exported_program(exported_program2) |
| ex2 = relax.build(mod2, target) |
| vm2 = relax.VirtualMachine(ex2, tvm.cpu()) |
| x2_tvm = tvm.runtime.tensor(x2.numpy()) |
| tvm_output2 = vm2["main"](x2_tvm) |
| if hasattr(tvm_output2, "numpy"): |
| tvm_output2_np = tvm_output2.numpy() |
| else: |
| tvm_output2_np = tvm_output2[0].numpy() |
|     assert ( |
|         pytorch_output2.shape == tvm_output2_np.shape |
|     ), f"Shape mismatch: PyTorch {pytorch_output2.shape} vs TVM {tvm_output2_np.shape}" |
| np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) |
| |
| |
| def test_tensor_none_tuple(): |
| example_args = (torch.tensor([1.0, 2.0, 3.0]),) |
| |
| class TensorNoneModel(Module): |
| def forward(self, x): |
| return x + 1, None |
| |
| @tvm.script.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((3,), dtype="float32") |
| ) -> R.Tuple(R.Tensor((3,), dtype="float32"), R.Object): |
| with R.dataflow(): |
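|                 # A Python None in the output tuple becomes R.null_value(), typed as R.Object. |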
| lv: R.Tensor((3,), dtype="float32") = R.add(x, R.const(1.0, "float32")) |
| gv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Object) = (lv, R.null_value()) |
| R.output(gv) |
| return gv |
| |
| verify_model(TensorNoneModel(), example_args, {}, Expected) |
| |
| |
| def test_gru(): |
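|     # Same numerical-verification strategy as test_lstm, applied to nn.GRU. |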
| class BasicGRU(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.gru = nn.GRU( |
| input_size=4, |
| hidden_size=8, |
| num_layers=1, |
| batch_first=True, |
| bidirectional=False, |
| ) |
| |
| def forward(self, x): |
| y, _ = self.gru(x) |
| return y |
| |
| torch.manual_seed(42) |
| x = torch.randn(2, 3, 4, dtype=torch.float32) |
| model = BasicGRU() |
| with torch.no_grad(): |
| pytorch_output = model(x) |
| exported_program = export(model, args=(x,)) |
| mod = from_exported_program(exported_program) |
| target = tvm.target.Target("llvm") |
| ex = relax.build(mod, target) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| x_tvm = tvm.runtime.tensor(x.numpy()) |
| tvm_output = vm["main"](x_tvm) |
| if hasattr(tvm_output, "numpy"): |
| tvm_output_np = tvm_output.numpy() |
| else: |
| tvm_output_np = tvm_output[0].numpy() |
| assert ( |
| pytorch_output.shape == tvm_output_np.shape |
| ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}" |
| np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5) |
| |
| class SeqFirstGRU(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.gru = nn.GRU( |
| input_size=3, |
| hidden_size=6, |
| num_layers=1, |
| batch_first=False, |
| bidirectional=False, |
| ) |
| |
| def forward(self, x): |
| y, _ = self.gru(x) |
| return y |
| |
| torch.manual_seed(43) |
| x2 = torch.randn(4, 2, 3, dtype=torch.float32) |
| model2 = SeqFirstGRU() |
| with torch.no_grad(): |
| pytorch_output2 = model2(x2) |
| exported_program2 = export(model2, args=(x2,)) |
| mod2 = from_exported_program(exported_program2) |
| ex2 = relax.build(mod2, target) |
| vm2 = relax.VirtualMachine(ex2, tvm.cpu()) |
| x2_tvm = tvm.runtime.tensor(x2.numpy()) |
| tvm_output2 = vm2["main"](x2_tvm) |
| if hasattr(tvm_output2, "numpy"): |
| tvm_output2_np = tvm_output2.numpy() |
| else: |
| tvm_output2_np = tvm_output2[0].numpy() |
|     assert ( |
|         pytorch_output2.shape == tvm_output2_np.shape |
|     ), f"Shape mismatch: PyTorch {pytorch_output2.shape} vs TVM {tvm_output2_np.shape}" |
| np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |