| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=missing-docstring, invalid-name |
| import numpy as np |
| import tvm |
| import tvm.testing |
| from tvm import relax, tir |
| from tvm.relax.frontend.nn import Module, Tensor, op, spec |
| from tvm.script import ir as I |
| from tvm.script import relax as R |
| from tvm.script import tir as T |
| |
| # mypy: disable-error-code="attr-defined,valid-type,name-defined" |
| |
| |
def test_unary():
    """Element-wise unary ops (``square``, ``sqrt``) lower to the matching Relax calls."""

    class Model(Module):
        def test(self, x: Tensor):
            # Returned as a 2-tuple: square first, then sqrt.
            return (op.square(x), op.sqrt(x))

    # fmt: off
    @R.function
    def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            square: R.Tensor((1, 10), dtype="float32") = R.square(x)
            sqrt: R.Tensor((1, 10), dtype="float32") = R.sqrt(x)
            gv1 = (square, sqrt), (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    x_spec = spec.Tensor([1, 10], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_binary():
    """Broadcasting binary arithmetic and comparison ops lower to Relax calls."""

    class Model(Module):
        def test(self, x: Tensor, y: Tensor):
            # Emission order here fixes the binding order in the expected IR.
            arithmetic = (
                op.add(x, y),
                op.multiply(x, y),
                op.divide(x, y),
                op.matmul(x, y),
                op.maximum(x, y),
                op.minimum(x, y),
                op.subtract(x, y),
            )
            comparisons = (
                op.greater(x, y),
                op.greater_equal(x, y),
                op.less(x, y),
                op.less_equal(x, y),
                op.equal(x, y),
                op.not_equal(x, y),
            )
            return arithmetic + comparisons

    # fmt: off
    @R.function
    def test(x: R.Tensor((1, 10), dtype="float32"), y: R.Tensor((10, 1), dtype="float32"), _io: R.Object):
        R.func_attr({"num_input": 3})
        with R.dataflow():
            add: R.Tensor((10, 10), dtype="float32") = R.add(x, y)
            mul: R.Tensor((10, 10), dtype="float32") = R.multiply(x, y)
            divide: R.Tensor((10, 10), dtype="float32") = R.divide(x, y)
            matmul: R.Tensor((1, 1), dtype="float32") = R.matmul(x, y, out_dtype="void")
            maximum: R.Tensor((10, 10), dtype="float32") = R.maximum(x, y)
            minimum: R.Tensor((10, 10), dtype="float32") = R.minimum(x, y)
            subtract: R.Tensor((10, 10), dtype="float32") = R.subtract(x, y)
            greater: R.Tensor((10, 10), dtype="bool") = x > y
            greater_equal: R.Tensor((10, 10), dtype="bool") = x >= y
            less: R.Tensor((10, 10), dtype="bool") = x < y
            less_equal: R.Tensor((10, 10), dtype="bool") = x <= y
            equal: R.Tensor((10, 10), dtype="bool") = R.equal(x, y)
            not_equal: R.Tensor((10, 10), dtype="bool") = R.not_equal(x, y)
            gv1 = (add, mul, divide, matmul, maximum, minimum, subtract, greater, greater_equal, less, less_equal, equal, not_equal), (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    input_specs = {
        "x": spec.Tensor([1, 10], "float32"),
        "y": spec.Tensor([10, 1], "float32"),
    }
    exported, _ = Model().export_tvm(spec={"test": input_specs}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_sum():
    """``op.sum`` over axes [1, 2] with ``keepdims=True`` lowers to ``R.sum``."""

    class Model(Module):
        def test(self, x: Tensor):
            return op.sum(x, axis=[1, 2], keepdims=True)

    # fmt: off
    @R.function
    def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            sum: R.Tensor((3, 1, 1, 4), dtype="float32") = R.sum(x, axis=[1, 2], keepdims=True)
            gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = sum, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    x_spec = spec.Tensor([3, 5, 2, 4], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_max():
    """``op.max`` over axes [1, 2] with ``keepdims=True`` lowers to ``R.max``."""

    class Model(Module):
        def test(self, x: Tensor):
            return op.max(x, axis=[1, 2], keepdims=True)

    # fmt: off
    @R.function
    def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            max: R.Tensor((3, 1, 1, 4), dtype="float32") = R.max(x, axis=[1, 2], keepdims=True)
            gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = max, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    x_spec = spec.Tensor([3, 5, 2, 4], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_min():
    """``op.min`` over axes [1, 2] with ``keepdims=True`` lowers to ``R.min``."""

    class Model(Module):
        def test(self, x: Tensor):
            return op.min(x, axis=[1, 2], keepdims=True)

    # fmt: off
    @R.function
    def test(x: R.Tensor((3, 5, 2, 4), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            min: R.Tensor((3, 1, 1, 4), dtype="float32") = R.min(x, axis=[1, 2], keepdims=True)
            gv1: R.Tuple(R.Tensor((3, 1, 1, 4), dtype="float32"), R.Tuple(R.Object)) = min, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    x_spec = spec.Tensor([3, 5, 2, 4], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_manipulate():
    """Shape-manipulation ops (broadcast, permute, reshape, repeat, squeeze,
    unsqueeze, concat) lower to their Relax counterparts."""

    class Model(Module):
        def test(self, x: Tensor):
            # Tuple elements are evaluated left to right, matching the
            # binding order of the expected function.
            return (
                op.broadcast_to(x, [2, 5, 2]),
                op.permute_dims(x, [2, 1, 0]),
                op.reshape(x, [1, 10]),
                op.repeat(x, repeats=2, axis=1),
                op.squeeze(x, 0),
                op.unsqueeze(x, 0),
                op.concat([x, x], dim=0),
            )

    # fmt: off
    @R.function
    def test(x: R.Tensor((1, 5, 2), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2), dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            broadcast_to: R.Tensor((2, 5, 2), dtype="float32") = R.broadcast_to(x, R.shape([2, 5, 2]))
            permute_dims: R.Tensor((2, 5, 1), dtype="float32") = R.permute_dims(x, axes=[2, 1, 0])
            reshape: R.Tensor((1, 10), dtype="float32") = R.reshape(x, R.shape([1, 10]))
            repeat: R.Tensor((1, 10, 2), dtype="float32") = R.repeat(x, repeats=2, axis=1)
            squeeze: R.Tensor((5, 2), dtype="float32") = R.squeeze(x, axis=[0])
            unsqueeze: R.Tensor((1, 1, 5, 2), dtype="float32") = R.expand_dims(x, axis=0)
            concat: R.Tensor((2, 5, 2), dtype="float32") = R.concat([x, x], axis=0)
            gv1: R.Tuple(R.Tuple(R.Tensor((2, 5, 2), dtype="float32"), R.Tensor((2, 5, 1), dtype="float32"), R.Tensor((1, 10), dtype="float32"), R.Tensor((1, 10, 2), dtype="float32"), R.Tensor((5, 2), dtype="float32"), R.Tensor((1, 1, 5, 2), dtype="float32"), R.Tensor((2, 5, 2), dtype="float32")), R.Tuple(R.Object)) = (broadcast_to, permute_dims, reshape, repeat, squeeze, unsqueeze, concat), (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    x_spec = spec.Tensor([1, 5, 2], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_index():
    """``op.take`` along axis 2 with an int32 index tensor lowers to ``R.take``."""

    class Model(Module):
        def test(self, x: Tensor, y: Tensor):
            z0 = op.take(x, y, axis=2)
            return z0

    # fmt: off
    @R.function
    def test(x: R.Tensor((2, 1, 10), dtype="float32"), y: R.Tensor((5,), dtype="int32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 1, 5), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 3})
        with R.dataflow():
            take: R.Tensor((2, 1, 5), dtype="float32") = R.take(x, y, axis=2)
            gv1: R.Tuple(R.Tensor((2, 1, 5), dtype="float32"), R.Tuple(R.Object)) = take, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    # The exported parameter list is not inspected here.
    irmodule, _ = m.export_tvm(
        spec={"test": {"x": spec.Tensor([2, 1, 10], "float32"), "y": spec.Tensor([5], "int32")}},
        debug=True,
    )

    tvm.ir.assert_structural_equal(irmodule["test"], test)
| |
| |
def test_datatype():
    """``op.astype`` from float32 to float16 lowers to ``R.astype``."""

    class Model(Module):
        def test(self, x: Tensor):
            return op.astype(x, "float16")

    # fmt: off
    @R.function
    def test(x: R.Tensor((2, 1, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((2, 1, 10), dtype="float16"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            astype: R.Tensor((2, 1, 10), dtype="float16") = R.astype(x, dtype="float16")
            gv1: R.Tuple(R.Tensor((2, 1, 10), dtype="float16"), R.Tuple(R.Object)) = astype, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    x_spec = spec.Tensor([2, 1, 10], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_image():
    """Padding, conv2d-with-bias and nearest-neighbour interpolation lower to
    ``R.nn.pad``, ``R.nn.conv2d`` + bias add, and ``R.image.resize2d``."""

    class Model(Module):
        def test(self, x: Tensor, weight: Tensor, bias: Tensor):
            # Pad only the last two (spatial) dims by one on each side.
            padded = op.pad(x, [0, 0, 0, 0, 1, 1, 1, 1])
            return (
                op.conv2d(padded, weight, bias),
                op.interpolate(x, size=[40, 40]),  # type: ignore
            )

    @R.function
    def test(
        x: R.Tensor((1, 3, 32, 32), dtype="float32"),
        weight: R.Tensor((32, 3, 3, 3), dtype="float32"),
        bias: R.Tensor((32,), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(
        R.Tuple(
            R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tensor((1, 3, 40, 40), dtype="float32")
        ),
        R.Tuple(R.Object),
    ):
        R.func_attr({"num_input": 4})
        with R.dataflow():
            lv0: R.Tensor((1, 3, 34, 34), dtype="float32") = R.nn.pad(x, (0, 0, 0, 0, 1, 1, 1, 1))
            lv1: R.Tensor((1, 32, 32, 32), dtype="float32") = R.nn.conv2d(
                lv0,
                weight,
                strides=[1, 1],
                padding=[0, 0, 0, 0],
                dilation=[1, 1],
                groups=1,
                data_layout="NCHW",
                kernel_layout="OIHW",
                out_layout="NCHW",
                out_dtype="void",
            )
            lv2: R.Tensor((1, 32, 1, 1), dtype="float32") = R.reshape(bias, R.shape([1, 32, 1, 1]))
            conv2d: R.Tensor((1, 32, 32, 32), dtype="float32") = R.add(lv1, lv2)
            interpolate: R.Tensor((1, 3, 40, 40), dtype="float32") = R.image.resize2d(
                x,
                R.shape([40, 40]),
                roi=[T.float32(0), T.float32(0), T.float32(0), T.float32(0)],
                layout="NCHW",
                method="nearest_neighbor",
                coordinate_transformation_mode="asymmetric",
                rounding_method="round",
                cubic_alpha=-0.75,
                cubic_exclude=0,
                extrapolation_value=0,
                out_dtype="void",
            )
            gv1: R.Tuple(
                R.Tuple(
                    R.Tensor((1, 32, 32, 32), dtype="float32"),
                    R.Tensor((1, 3, 40, 40), dtype="float32"),
                ),
                R.Tuple(R.Object),
            ) = (conv2d, interpolate), (_io,)
            R.output(gv1)
        return gv1

    input_specs = {
        "x": spec.Tensor([1, 3, 32, 32], "float32"),
        "weight": spec.Tensor([32, 3, 3, 3], "float32"),
        "bias": spec.Tensor([32], "float32"),
    }
    exported, _ = Model().export_tvm(spec={"test": input_specs}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_chunk():
    """``op.chunk`` into 4 pieces lowers to ``R.split`` followed by per-item projections."""

    class Model(Module):
        def test(self, x: Tensor):
            return op.chunk(x, chunks=4)

    @R.function
    def test(
        x: R.Tensor((8,), dtype="float32"), _io: R.Object
    ) -> R.Tuple(
        R.Tuple(
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
            R.Tensor((2,), dtype="float32"),
        ),
        R.Tuple(R.Object),
    ):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            chunk: R.Tuple(
                R.Tensor((2,), dtype="float32"),
                R.Tensor((2,), dtype="float32"),
                R.Tensor((2,), dtype="float32"),
                R.Tensor((2,), dtype="float32"),
            ) = R.split(x, indices_or_sections=4, axis=0)
            chunk_0: R.Tensor((2,), dtype="float32") = chunk[0]
            chunk_1: R.Tensor((2,), dtype="float32") = chunk[1]
            chunk_2: R.Tensor((2,), dtype="float32") = chunk[2]
            chunk_3: R.Tensor((2,), dtype="float32") = chunk[3]
            gv1: R.Tuple(
                R.Tuple(
                    R.Tensor((2,), dtype="float32"),
                    R.Tensor((2,), dtype="float32"),
                    R.Tensor((2,), dtype="float32"),
                    R.Tensor((2,), dtype="float32"),
                ),
                R.Tuple(R.Object),
            ) = (chunk_0, chunk_1, chunk_2, chunk_3), (_io,)
            R.output(gv1)
        return gv1

    x_spec = spec.Tensor([8], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_nn():
    """Activation, normalization and other NN ops lower to the matching Relax calls.

    Each op result is discarded on the Python side and only ``x`` is returned;
    the expected IR still lists every emitted binding in order.
    """

    class Model(Module):
        def test(self, x: Tensor, weight: Tensor, bias: Tensor):
            log_out = op.log(x)
            floor_out = op.floor(x)
            relu_out = op.relu(x)
            relu6_out = op.relu6(x)
            silu_out = op.silu(x)
            gelu_out = op.gelu(x)
            sigmoid_out = op.sigmoid(x)
            tanh_out = op.tanh(x)
            exp_out = op.exp(x)
            negative_out = op.negative(x)
            softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
            softmax_out = op.softmax(x, axis=2)
            prelu_out = op.prelu(x, alpha=bias)
            rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
            # Second call with identical arguments (no bias is passed); the
            # expected IR binds its result as ``rms_norm1``.
            rms_norm_repeat_out = op.rms_norm(x, weight, axes=[-2, -1])
            group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias)
            return x

    @R.function
    def test(
        x: R.Tensor((2, 3, 4, 5), dtype="float32"),
        weight: R.Tensor((4, 5), dtype="float32"),
        bias: R.Tensor((3,), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 4})
        with R.dataflow():
            log: R.Tensor((2, 3, 4, 5), dtype="float32") = R.log(x)
            floor: R.Tensor((2, 3, 4, 5), dtype="float32") = R.floor(x)
            relu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu(x)
            relu6: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.relu6(x)
            silu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.silu(x)
            gelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.gelu(x)
            sigmoid: R.Tensor((2, 3, 4, 5), dtype="float32") = R.sigmoid(x)
            tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
            exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
            negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
            softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(
                x, beta=1.0, threshold=20.0
            )
            softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
            prelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.prelu(x, bias)
            rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
                x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
            )
            rms_norm1: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
                x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
            )
            group_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.group_norm(
                x, bias, bias, num_groups=1, channel_axis=1, axes=[2, 3]
            )
            gv1: R.Tuple(R.Tensor((2, 3, 4, 5), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
            R.output(gv1)
        return gv1

    m = Model()
    # The exported parameter list is not inspected here.
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "x": spec.Tensor([2, 3, 4, 5], "float32"),
                "weight": spec.Tensor([4, 5], "float32"),
                "bias": spec.Tensor([3], "float32"),
            }
        },
        debug=True,
    )

    tvm.ir.assert_structural_equal(irmodule["test"], test)
| |
| |
def test_create():
    """Tensor-creation ops (triu, full, zeros, arange) lower to Relax calls.

    ``op.full`` is exercised with a Python scalar, a ``tir.FloatImm`` and a
    scalar ``Tensor`` fill value; all three produce the same expected call.
    """

    class Model(Module):
        def test(self, x: Tensor):
            triu_out = op.triu(x)
            full_with_scalar_out = op.full([10, 10], fill_value=10)  # type: ignore
            full_with_FloatImm_out = op.full(
                [10, 10], fill_value=tir.FloatImm(dtype="float32", value=10)
            )
            full_with_Tensor_out = op.full(
                [10, 10], fill_value=Tensor.from_scalar(10, dtype="float32")
            )
            zeros_out = op.zeros([10, 10])
            zeros_fp16_out = op.zeros([10, 10], dtype="float16")

            arange_out = op.arange(0, 10, 1, "float32")
            return x

    # fmt: off
    @R.function
    def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            triu: R.Tensor((10, 10), dtype="float32") = R.triu(x, k=0)
            full: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            full1: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            full2: R.Tensor((10, 10), dtype="float32") = R.full(R.shape([10, 10]), R.const(10, "float32"), dtype="float32")
            zeros: R.Tensor((10, 10), dtype="float32") = R.zeros(R.shape([10, 10]), dtype="float32")
            zeros1: R.Tensor((10, 10), dtype="float16") = R.zeros(R.shape([10, 10]), dtype="float16")
            arange: R.Tensor((10,), dtype="float32") = R.arange(T.int64(0), T.int64(10), T.int64(1), dtype="float32")
            gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = x, (_io,)
            R.output(gv1)
        return gv1
    # fmt: on

    m = Model()
    # The exported parameter list is not inspected here.
    irmodule, _ = m.export_tvm(
        spec={"test": {"x": spec.Tensor([10, 10], "float32")}}, debug=True
    )

    tvm.ir.assert_structural_equal(irmodule["test"], test)
| |
| |
def test_timestep_embedding():
    """``op.get_timestep_embedding`` with 10 channels expands to the sin/cos
    Relax sequence shown in the expected function."""

    class Model(Module):
        def test(self, x: Tensor):
            return op.get_timestep_embedding(x, 10)

    @R.function
    def test(
        x: R.Tensor((3,), dtype="float32"), _io: R.Object
    ) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 2})
        with R.dataflow():
            lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
            lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, axis=[1])
            lv3: R.Tensor((5,), dtype="float32") = R.arange(
                R.prim_value(T.int64(0)),
                R.prim_value(T.int64(5)),
                R.prim_value(T.int64(1)),
                dtype="float32",
            )
            lv4: R.Tensor((5,), dtype="float32") = R.multiply(
                R.const(-9.2103404998779297, "float32"), lv3
            )
            lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4, "float32"))
            lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
            lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6, axis=[0])
            lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
            lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
            lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
            lv11: R.Tensor((3, 10), dtype="float32") = R.concat((lv9, lv10), axis=-1)
            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = R.astype(
                lv11, dtype="float32"
            )
            gv1: R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)) = (
                get_timestep_embedding,
                (_io,),
            )
            R.output(gv1)
        return gv1

    x_spec = spec.Tensor([3], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_scaled_dot_product_attention():
    """``op.scaled_dot_product_attention`` with default arguments lowers to
    ``R.nn.attention(..., scale=None, causal_mask=None)``."""

    class Model(Module):
        def test(self, query: Tensor, key: Tensor, value: Tensor):
            return op.scaled_dot_product_attention(query, key, value)

    @R.function
    def test(
        query: R.Tensor((1, 32, 32, 32), dtype="float32"),
        key: R.Tensor((1, 32, 32, 32), dtype="float32"),
        value: R.Tensor((1, 32, 32, 32), dtype="float32"),
        _io: R.Object,
    ) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)):
        R.func_attr({"num_input": 4})
        with R.dataflow():
            scaled_dot_product_attention: R.Tensor(
                (1, 32, 32, 32), dtype="float32"
            ) = R.nn.attention(query, key, value, scale=None, causal_mask=None)
            gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = (
                scaled_dot_product_attention,
                (_io,),
            )
            R.output(gv1)
        return gv1

    # query, key and value all share the same shape and dtype.
    input_specs = {
        "query": spec.Tensor([1, 32, 32, 32], "float32"),
        "key": spec.Tensor([1, 32, 32, 32], "float32"),
        "value": spec.Tensor([1, 32, 32, 32], "float32"),
    }
    exported, _ = Model().export_tvm(spec={"test": input_specs}, debug=True)
    tvm.ir.assert_structural_equal(exported["test"], test)
| |
| |
def test_tensor_expr_op():
    """``op.tensor_expr_op`` turns a TE function into a private prim_func
    invoked through ``R.call_tir``; the whole module is compared."""

    class Model(Module):
        def test(self, x: Tensor):
            def add_one(operand):
                # Element-wise +1 expressed in TE.
                return operand + 1

            return op.tensor_expr_op(tensor_expr_func=add_one, name_hint="add_one", args=[x])

    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(10), T.int64(10)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for ax0, ax1 in T.grid(T.int64(10), T.int64(10)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(A[v_ax0, v_ax1])
                    T.writes(T_add[v_ax0, v_ax1])
                    T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + T.float32(1)

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(x: R.Tensor((10, 10), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)):
            cls = Expected
            R.func_attr({"num_input": 2})
            with R.dataflow():
                lv1 = R.call_tir(cls.add_one, (x,), out_sinfo=R.Tensor((10, 10), dtype="float32"))
                gv1: R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tuple(R.Object)) = lv1, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    x_spec = spec.Tensor([10, 10], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"x": x_spec}}, debug=True)
    tvm.ir.assert_structural_equal(exported, Expected)
| |
| |
def test_tensor_ir_op():
    """``op.tensor_ir_op`` emits ``R.call_tir`` with a scalar passed via ``tir_vars``.

    The prim_func only declares its buffers and evaluates ``offset``; the test
    checks the Relax calling convention, not a real kernel body.
    """
    num_q_heads, num_kv_heads, head_dim = 8, 8, 16
    # Head-axis size of the fused qkv buffer: q heads plus k and v heads.
    fused_heads = num_q_heads + num_kv_heads * 2
    dtype = "float16"

    # The names above are read from this closure when the prim_func is parsed.
    @T.prim_func(private=True)
    def fused_rope(  # pylint: disable=too-many-locals
        var_qkv: T.handle,
        var_q: T.handle,
        var_k: T.handle,
        var_v: T.handle,
        # Scalar arguments must be specified after tensor arguments,
        # including the output tensor arguments
        #
        # TODO(Lunderberg): Update
        # `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue`
        # instead of `tir_vars`, so that the order can be consistent
        # between the function definition and the arguments in
        # `op.tensor_ir_op`.
        offset: T.int64,
    ):
        batch_size = T.int64()
        seq_len = T.int64()
        qkv = T.match_buffer(var_qkv, (batch_size, seq_len, fused_heads, head_dim), dtype)
        q = T.match_buffer(var_q, (batch_size, seq_len, num_q_heads, head_dim), dtype)
        k = T.match_buffer(var_k, (batch_size, seq_len, num_kv_heads, head_dim), dtype)
        v = T.match_buffer(var_v, (batch_size, seq_len, num_kv_heads, head_dim), dtype)
        T.evaluate(offset)

    class Model(Module):
        def test(self, qkv: Tensor, offset: tir.Var):
            # One tensor input (qkv) plus a symbolic scalar (offset); the three
            # placeholders describe the output tensors of the kernel.
            tensor_expr_op_out = op.tensor_ir_op(
                fused_rope,
                "llama_fused_rope",
                args=[qkv, offset],
                out=[
                    Tensor.placeholder((1, 1, num_q_heads, head_dim), dtype),
                    Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
                    Tensor.placeholder((1, 1, num_kv_heads, head_dim), dtype),
                ],
            )
            return tensor_expr_op_out

    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, offset: T.int64):
            batch_size, seq_len = T.int64(), T.int64()
            qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16")
            q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16")
            k = T.match_buffer(var_k, (batch_size, seq_len, 8, 16), "float16")
            v = T.match_buffer(var_v, (batch_size, seq_len, 8, 16), "float16")
            T.evaluate(offset)

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        # offset arrives as a shape argument and reaches the kernel through
        # tir_vars rather than as a tensor argument of call_tir.
        @R.function
        def test(qkv: R.Tensor((1, 1, 24, 16), dtype="float16"), offset: R.Shape(["offset_1"]), _io: R.Object) -> R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)):
            offset_1 = T.int64()
            R.func_attr({"num_input": 3})
            cls = Expected
            with R.dataflow():
                lv1 = R.call_tir(cls.llama_fused_rope, (qkv,), out_sinfo=[R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")], tir_vars=R.shape([offset_1]))
                llama_fused_rope_0: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[0]
                llama_fused_rope_1: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[1]
                llama_fused_rope_2: R.Tensor((1, 1, 8, 16), dtype="float16") = lv1[2]
                gv1: R.Tuple(R.Tuple(R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16"), R.Tensor((1, 1, 8, 16), dtype="float16")), R.Tuple(R.Object)) = (llama_fused_rope_0, llama_fused_rope_1, llama_fused_rope_2), (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    # "offset": int declares a symbolic integer input for the exported function.
    irmodule, _ = m.export_tvm(
        spec={
            "test": {"qkv": spec.Tensor([1, 1, fused_heads, head_dim], "float16"), "offset": int}
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)
| |
| |
def test_tensor_ir_inplace_op():
    """``op.tensor_ir_inplace_op`` emits ``R.call_tir_inplace`` with ``tir_vars``.

    The kernel stores into row ``v0 + offset`` of ``embeddings``, so its
    ``T.writes`` annotation declares that same region. The input kernel and
    ``Expected.inplace_take`` must stay identical for the structural comparison.
    """
    hidden_size = 4096
    dtype = "float16"

    @T.prim_func
    def inplace_take(
        var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
    ):
        T.func_attr({"tir.noalias": True})
        vocab_size = T.int64()
        weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
        seq_len = T.int64()
        total_seq_len = T.int64()
        pos = T.match_buffer(var_pos, (seq_len,), "int32")
        embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
        for ax0, ax1 in T.grid(seq_len, hidden_size):
            with T.block("T_take"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(weight[pos[v0], v1], pos[v0])
                # Declared write region must match the offset store below.
                T.writes(embeddings[v0 + offset, v1])
                embeddings[v0 + offset, v1] = weight[pos[v0], v1]

    class Model(Module):
        def test(
            self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int
        ):
            # Argument index 2 (embedding_dst) is updated in place.
            tensor_expr_op_out = op.tensor_ir_inplace_op(
                inplace_take,
                "inplace_take",
                args=[embedding_table, input_ids, embedding_dst, offset],
                inplace_indices=[2],
                out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype),
            )
            return tensor_expr_op_out

    @I.ir_module
    class Expected:
        @T.prim_func
        def inplace_take(
            var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64
        ):
            T.func_attr({"tir.noalias": True})
            vocab_size = T.int64()
            weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype)
            seq_len = T.int64()
            total_seq_len = T.int64()
            pos = T.match_buffer(var_pos, (seq_len,), "int32")
            embeddings = T.match_buffer(var_embeddings, (total_seq_len, hidden_size), dtype)
            for ax0, ax1 in T.grid(seq_len, hidden_size):
                with T.block("T_take"):
                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(weight[pos[v0], v1], pos[v0])
                    T.writes(embeddings[v0 + offset, v1])
                    embeddings[v0 + offset, v1] = weight[pos[v0], v1]

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(
            embedding_table: R.Tensor(("vocab_size", hidden_size), dtype),
            input_ids: R.Tensor(("seq_len",), "int32"),
            embedding_dst: R.Tensor(("total_seq_len", hidden_size), dtype),
            offset: R.Shape(["offset_1"]),
            packed_params: R.Tuple,
        ) -> R.Tensor(("total_seq_len", hidden_size), dtype):
            total_seq_len = T.int64()
            offset_1 = T.int64()
            R.func_attr({"num_input": 4})
            cls = Expected
            with R.dataflow():
                lv1 = R.call_tir_inplace(
                    cls.inplace_take,
                    (embedding_table, input_ids, embedding_dst),
                    out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                    inplace_indices=[2],
                    tir_vars=R.shape([offset_1]),
                )
                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
                R.output(gv1)
            return gv1

    m = Model()
    irmodule, _ = m.export_tvm(
        spec={
            "test": {
                "embedding_table": spec.Tensor(["vocab_size", hidden_size], dtype),
                "input_ids": spec.Tensor(["seq_len"], "int32"),
                "embedding_dst": spec.Tensor(["total_seq_len", hidden_size], dtype),
                "offset": int,
                "$": {
                    "param_mode": "packed",
                    "effect_mode": "none",
                },
            },
        },
        debug=True,
    )
    tvm.ir.assert_structural_equal(irmodule, Expected)
| |
| |
def test_tensor_ir_op_no_tir_var():
    """``op.tensor_ir_op`` without scalar arguments emits a plain ``R.call_tir``
    (no ``tir_vars``); exported without ``debug``, so there is no effect object."""

    @T.prim_func(private=True)
    def tir_func(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
        T.evaluate(0)

    class Model(Module):
        def test(self, A: Tensor):
            output_placeholders = [Tensor.placeholder((16, 16), "float32")]
            return op.tensor_ir_op(tir_func, "tir_func", args=[A], out=output_placeholders)

    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def tir_func(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
            T.evaluate(0)

        @R.function
        def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="float32"):
            R.func_attr({"num_input": 1})
            cls = Expected
            with R.dataflow():
                lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
                gv: R.Tensor((16, 16), dtype="float32") = lv
                R.output(gv)
            return gv

    a_spec = spec.Tensor([16, 16], "float32")
    exported, _ = Model().export_tvm(spec={"test": {"A": a_spec}})
    tvm.ir.assert_structural_equal(exported, Expected)
| |
| |
def test_extern():
    """``op.extern`` lowers to ``R.call_dps_packed`` with Python scalars wrapped
    as ``R.prim_value`` arguments."""

    class Model(Module):
        def test(self, q: Tensor, k: Tensor, v: Tensor):
            # Output flattens q's last two dims into one.
            q_batch, q_seq, q_heads, q_head_dim = q.shape
            return op.extern(
                name="flashinfer.single_decode",
                args=[q, k, v, 0, 0, 1.0, 10000.0],
                out=Tensor.placeholder((q_batch, q_seq, q_heads * q_head_dim), dtype="float16"),
            )

    # fmt: off
    @I.ir_module
    class Expected:
        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def test(q: R.Tensor((1, 1, 16, 8), dtype="float32"), k: R.Tensor((64, 16, 8), dtype="float32"), v: R.Tensor((64, 16, 8), dtype="float32"), _io: R.Object) -> R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 4})
            with R.dataflow():
                flashinfer_single_decode = R.call_dps_packed("flashinfer.single_decode", (q, k, v, R.prim_value(0), R.prim_value(0), R.prim_value(T.float64(1)), R.prim_value(T.float64(10000))), out_sinfo=R.Tensor((1, 1, 128), dtype="float16"))
                gv1: R.Tuple(R.Tensor((1, 1, 128), dtype="float16"), R.Tuple(R.Object)) = flashinfer_single_decode, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    batch, seq, t, d, h_q, h_kv = 1, 1, 64, 8, 16, 16
    input_specs = {
        "q": spec.Tensor([batch, seq, h_q, d], "float32"),
        "k": spec.Tensor([t, h_kv, d], "float32"),
        "v": spec.Tensor([t, h_kv, d], "float32"),
    }
    exported, _ = Model().export_tvm(spec={"test": input_specs}, debug=True)
    tvm.ir.assert_structural_equal(exported, Expected)
| |
| |
def test_empty():
    """Run ``op.empty`` end to end on CPU and inspect its result via a debug callback."""

    @tvm.register_global_func("test_empty_assert", override=True)
    def _check_empty_result(_lineno, x):
        # Called from the compiled program through op.debug_func with the
        # tensor produced by op.empty. The first argument is presumably
        # call-site line info supplied by debug_func; it is unused here.
        assert x.shape == (10, 10)
        assert x.dtype == "float32"

    class Model(Module):
        def test(self):
            out = op.empty([10, 10], dtype="float32")
            op.debug_func("test_empty_assert", out)
            return out

    exported, _ = Model().export_tvm(spec={"test": {}}, debug=True)
    vm = relax.VirtualMachine(tvm.compile(exported, "llvm"), tvm.cpu())
    vm["test"](*vm["_initialize_effect"]())
| |
| |
@tvm.testing.requires_cuda
def test_multinomial_from_uniform():
    """Export ``op.multinomial_from_uniform`` and run it on CUDA with degenerate samples."""
    prob_shape = (3, 5)
    sample_shape = (6, 1)

    class Model(Module):
        def foo(self, prob: Tensor, uniform_sample: Tensor, sample_indices: Tensor):
            return op.multinomial_from_uniform(prob, uniform_sample, sample_indices)

    # fmt: off
    @I.ir_module
    class Expected:
        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1), dtype="float32"), sample_indices: R.Tensor((6, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 4})
            with R.dataflow():
                multinomial_from_uniform: R.Tensor((6, 1), dtype="int64") = R.multinomial_from_uniform(prob, uniform_sample, sample_indices, dtype="int64")
                gv1: R.Tuple(R.Tensor((6, 1), dtype="int64"), R.Tuple(R.Object)) = multinomial_from_uniform, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    foo_spec = {
        "prob": spec.Tensor(prob_shape, "float32"),
        "uniform_sample": spec.Tensor(sample_shape, "float32"),
        "sample_indices": spec.Tensor(sample_shape, "int64"),
    }
    mod, _ = Model().export_tvm(spec={"foo": foo_spec}, debug=True)

    tvm.ir.assert_structural_equal(mod, Expected)

    target = tvm.target.Target("cuda", host="llvm")
    with target:
        mod = relax.backend.DispatchSampling()(mod)
        mod = tir.transform.DefaultGPUSchedule()(mod)
    executable = tvm.compile(mod, target)
    dev = tvm.device(str(target), 0)
    vm = relax.VirtualMachine(executable, dev)

    effects = vm["_initialize_effect"]()

    raw = np.random.rand(*prob_shape).astype(np.float32)
    # Normalize each row so it forms a probability distribution.
    prob_np = raw / raw.sum(axis=1, keepdims=True)
    # Samples of exactly 0 / 1 select the first / last category (see `expected`),
    # which keeps the outcome deterministic despite the random probabilities.
    uniform_np = np.array([[1], [0], [1], [1], [0], [1]], dtype=np.float32)
    row_ids_np = np.array([[0], [1], [1], [2], [2], [2]], dtype=np.int64)
    expected = np.array([[4], [0], [4], [4], [0], [4]], dtype=np.int64)

    res = vm["foo"](
        tvm.runtime.tensor(prob_np, dev),
        tvm.runtime.tensor(uniform_np, dev),
        tvm.runtime.tensor(row_ids_np, dev),
        effects,
    )
    tvm.testing.assert_allclose(res[0].numpy(), expected)
| |
| |
@tvm.testing.requires_cuda
def test_sample_top_p_top_k_from_sorted_prob():
    """Export ``op.sample_top_p_top_k_from_sorted_prob`` and check it numerically on CUDA.

    The test compiles for an explicit ``cuda`` target and runs on ``tvm.cuda(0)``,
    so it is gated on CUDA availability rather than on "any GPU". This matches
    ``test_multinomial_from_uniform``.
    """
    prob_shape = (2, 3)
    sample_shape = (3, 1)

    class Model(Module):
        def foo(
            self,
            prob: Tensor,
            index: Tensor,
            top_p: Tensor,
            top_k: Tensor,
            uniform_sample: Tensor,
            sample_indices: Tensor,
        ):
            z0 = op.sample_top_p_top_k_from_sorted_prob(
                prob, index, top_p, top_k, uniform_sample, sample_indices
            )
            return z0

    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle):
            batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
            cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
            indices = T.match_buffer(B, (batch, vocab_size), "int64")
            renorm_prob = T.match_buffer(C, (batch, 1))
            out_batch = T.int64(is_size_var=True)
            usample = T.match_buffer(D, (out_batch, 1))
            sample_indices = T.match_buffer(E, (out_batch, 1), "int64")
            output_index = T.match_buffer(F, (out_batch, 1), "int64")
            # with T.block("root"):
            for ax0, ax1 in T.grid(out_batch, vocab_size):
                with T.block("T_get_index_from_sorted"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(usample[v_ax0, T.int64(0)], cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1):v_ax1 - T.int64(1) + T.int64(2)], sample_indices[v_ax0, T.int64(0)], renorm_prob[sample_indices[v_ax0, T.int64(0)], 0], indices[sample_indices[v_ax0, T.int64(0)], T.min(T.int64(0), v_ax1):T.min(T.int64(0), v_ax1) + (T.max(T.int64(0), v_ax1) + T.int64(1) - T.min(T.int64(0), v_ax1))])
                    T.writes(output_index[v_ax0, 0])
                    if usample[v_ax0, T.int64(0)] < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0] or v_ax1 + T.int64(1) == vocab_size:
                        if v_ax1 == T.int64(0):
                            output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                        else:
                            if usample[v_ax0, T.int64(0)] >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - T.int64(1)] / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]:
                                output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], v_ax1]

        @T.prim_func(private=True)
        def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
            batch, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True)
            cumsum_sorted = T.match_buffer(A, (batch, vocab_size))
            top_p = T.match_buffer(B, (batch, 1))
            top_k = T.match_buffer(C, (batch, 1), "int64")
            renorm_prob = T.match_buffer(D, (batch, 1))
            # with T.block("root"):
            for ax0, ax1 in T.grid(batch, vocab_size):
                with T.block("T_get_renorm_prob"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0])
                    T.writes(renorm_prob[v_ax0, 0])
                    if not (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)):
                        renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, 0]
                    else:
                        if cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]:
                            if v_ax1 + T.int64(1) == vocab_size:
                                renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1]
                            else:
                                if not (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]):
                                    renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + T.int64(1)]

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 3), dtype="int64"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: R.Tensor((2, 1), dtype="int64"), uniform_sample: R.Tensor((3, 1), dtype="float32"), sample_indices: R.Tensor((3, 1), dtype="int64"), _io: R.Object,) -> R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 7})
            cls = Expected
            with R.dataflow():
                cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(prob, axis=1, dtype="void", exclusive=None)
                lv1 = R.call_tir(cls.get_renorm_prob, (cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32"))
                lv2 = R.call_tir(cls.get_index_from_sorted, (cumsum, index, lv1, uniform_sample, sample_indices), out_sinfo=R.Tensor((3, 1), dtype="int64"))
                gv1: R.Tuple(R.Tensor((3, 1), dtype="int64"), R.Tuple(R.Object)) = lv2, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    mod, _ = m.export_tvm(
        spec={
            "foo": {
                "prob": spec.Tensor(prob_shape, "float32"),
                "index": spec.Tensor(prob_shape, "int64"),
                # top_p / top_k hold one value per batch row.
                "top_p": spec.Tensor((prob_shape[0], 1), "float32"),
                "top_k": spec.Tensor((prob_shape[0], 1), "int64"),
                "uniform_sample": spec.Tensor(sample_shape, "float32"),
                "sample_indices": spec.Tensor(sample_shape, "int64"),
            }
        },
        debug=True,
    )

    tvm.ir.assert_structural_equal(mod, Expected)

    target = tvm.target.Target("cuda -libs=thrust", host="llvm")
    with target:
        mod = tir.transform.DefaultGPUSchedule()(mod)

    ex = tvm.compile(mod, target)
    dev = tvm.cuda(0)
    vm = relax.VirtualMachine(ex, dev)

    effects = vm["_initialize_effect"]()
    sorted_prob = tvm.runtime.tensor(
        np.array([[0.5, 0.4, 0.1], [0.4, 0.3, 0.3]]).astype(np.float32), dev
    )
    indices = tvm.runtime.tensor(np.array([[2, 1, 0], [2, 0, 1]]).astype(np.int64), dev)
    top_p = tvm.runtime.tensor(np.array([[0.6], [0.9]]).astype(np.float32), dev)
    top_k = tvm.runtime.tensor(np.array([[3], [2]]).astype(np.int64), dev)
    usample = tvm.runtime.tensor(np.array([[0.5], [0.6], [0.7]]).astype(np.float32), dev)
    sample_indices = tvm.runtime.tensor(np.array([[0], [1], [1]]).astype(np.int64), dev)

    inputs = [sorted_prob, indices, top_p, top_k, usample, sample_indices, effects]

    res = vm["foo"](*inputs)
    # Both rows keep their first two sorted entries (renormalized mass 0.9 and
    # 0.7). Sample 0 (u=0.5, row 0) lands in the first bucket -> index 2.
    # Samples 1 and 2 (u=0.6, 0.7, row 1) land in the second bucket -> index 0.
    tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0], [0]]).astype(np.int64))
| |
| |
@tvm.testing.requires_cuda
def test_renormalize_top_p_top_k_prob():
    """Export ``op.renormalize_top_p_top_k_prob`` and check it numerically on CUDA.

    The test compiles for an explicit ``cuda`` target and runs on ``tvm.cuda(0)``,
    so it is gated on CUDA availability rather than on "any GPU".
    """
    prob_shape = (2, 3)
    # top_p / top_k hold one value per batch row (formerly misnamed `sample_shape`).
    batch_shape = (prob_shape[0], 1)

    class Model(Module):
        def foo(
            self,
            prob: Tensor,
            sorted_prob: Tensor,
            top_p: Tensor,
            top_k: Tensor,
        ):
            z0 = op.renormalize_top_p_top_k_prob(prob, sorted_prob, top_p, top_k)
            return z0

    # fmt: off
    @I.ir_module
    class Expected:
        @T.prim_func(private=True)
        def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"), filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            for i, j in T.grid(T.int64(2), T.int64(3)):
                with T.block("filter_with_top_p_top_k"):
                    v_i, v_j = T.axis.remap("SS", [i, j])
                    T.reads(B[v_i, T.int64(0)], A[v_i, v_j])
                    T.writes(filter_with_top_p_top_k[v_i, v_j])
                    filter_with_top_p_top_k[v_i, v_j] = T.Select(B[v_i, T.int64(0)] <= A[v_i, v_j], A[v_i, v_j], T.float32(0))

        @T.prim_func(private=True)
        def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
            batch, vocab_size = T.int64(), T.int64()
            sorted_prob = T.match_buffer(A, (batch, vocab_size))
            cumsum_sorted = T.match_buffer(B, (batch, vocab_size))
            top_p = T.match_buffer(C, (batch, 1))
            top_k = T.match_buffer(D, (batch, 1), "int64")
            cutoff = T.match_buffer(E, (batch, 1))
            # with T.block("root"):
            for ax0, ax1 in T.grid(batch, vocab_size):
                with T.block("T_get_renorm_prob"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    T.reads(cumsum_sorted[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))], top_p[v_ax0, 0], top_k[v_ax0, 0], sorted_prob[v_ax0, T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)):T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + (T.max(T.max(T.int64(0), v_ax1), v_ax1 + T.int64(1)) + T.int64(1) - T.min(T.min(T.int64(0), v_ax1), v_ax1 + T.int64(1)))])
                    T.writes(cutoff[v_ax0, 0])
                    if (cumsum_sorted[v_ax0, 0] < top_p[v_ax0, 0] and top_k[v_ax0, 0] > T.int64(1)) == T.bool(False):
                        cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
                    else:
                        if (cumsum_sorted[v_ax0, v_ax1] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) < top_k[v_ax0, 0]) == T.bool(True):
                            if v_ax1 + T.int64(1) == vocab_size:
                                cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1]
                            else:
                                if (cumsum_sorted[v_ax0, v_ax1 + T.int64(1)] < top_p[v_ax0, 0] and v_ax1 + T.int64(1) + T.int64(1) < top_k[v_ax0, 0]) == T.bool(False):
                                    cutoff[v_ax0, 0] = sorted_prob[v_ax0, v_ax1 + T.int64(1)]

        @R.function
        def _initialize_effect() -> R.Tuple(R.Object):
            with R.dataflow():
                _io: R.Object = R.null_value()
                lv: R.Tuple(R.Object) = (_io,)
                gv: R.Tuple(R.Object) = lv
                R.output(gv)
            return gv

        @R.function
        def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), dtype="float32"), top_p: R.Tensor((2, 1), dtype="float32"), top_k: R.Tensor((2, 1), dtype="int64"), _io: R.Object) -> R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)):
            R.func_attr({"num_input": 5})
            cls = Expected
            with R.dataflow():
                cumsum: R.Tensor((2, 3), dtype="float32") = R.cumsum(sorted_prob, axis=1, dtype="void", exclusive=None)
                lv1 = R.call_tir(cls.get_renorm_cutoff, (sorted_prob, cumsum, top_p, top_k), out_sinfo=R.Tensor((2, 1), dtype="float32"))
                lv2 = R.call_tir(cls.filter_with_top_p_top_k, (prob, lv1), out_sinfo=R.Tensor((2, 3), dtype="float32"))
                sum: R.Tensor((2, 1), dtype="float32") = R.sum(lv2, axis=[1], keepdims=True)
                divide: R.Tensor((2, 3), dtype="float32") = R.divide(lv2, sum)
                gv1: R.Tuple(R.Tensor((2, 3), dtype="float32"), R.Tuple(R.Object)) = divide, (_io,)
                R.output(gv1)
            return gv1
    # fmt: on

    m = Model()
    mod, _ = m.export_tvm(
        spec={
            "foo": {
                "prob": spec.Tensor(prob_shape, "float32"),
                "sorted_prob": spec.Tensor(prob_shape, "float32"),
                "top_p": spec.Tensor(batch_shape, "float32"),
                "top_k": spec.Tensor(batch_shape, "int64"),
            }
        },
        debug=True,
    )

    tvm.ir.assert_structural_equal(mod, Expected)

    target = tvm.target.Target("cuda -libs=thrust", host="llvm")
    with target:
        mod = relax.transform.LegalizeOps()(mod)
        mod = tir.transform.DefaultGPUSchedule()(mod)

    ex = tvm.compile(mod, target)
    dev = tvm.cuda(0)
    vm = relax.VirtualMachine(ex, dev)

    effects = vm["_initialize_effect"]()
    prob = tvm.runtime.tensor(np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]]).astype(np.float32), dev)
    sorted_prob = tvm.runtime.tensor(
        np.array([[0.5, 0.3, 0.2], [0.4, 0.3, 0.3]]).astype(np.float32), dev
    )
    top_p = tvm.runtime.tensor(np.array([[0.6], [0.9]]).astype(np.float32), dev)
    top_k = tvm.runtime.tensor(np.array([[3], [2]]).astype(np.int64), dev)

    inputs = [prob, sorted_prob, top_p, top_k, effects]

    res = vm["foo"](*inputs)
    # Both rows end up with a cutoff of 0.3.
    # Row 0: 0.2 is zeroed and the remaining mass 0.8 is renormalized to [0, 0.375, 0.625].
    # Row 1: every entry is >= 0.3, so it is unchanged.
    tvm.testing.assert_allclose(
        res[0].numpy(), np.array([[0, 0.375, 0.625], [0.3, 0.3, 0.4]]).astype(np.float32)
    )
| |
| |
def test_sort_argsort_topk():
    """Check the Relax ops emitted for sort, argsort and topk on a symbolic-length input."""

    class Model(Module):
        def foo(self, x: Tensor):
            # Emitted in the same order as the bindings in Expected.foo.
            return (
                op.sort(x, axis=-1, descending=True),
                op.argsort(x, axis=-1, descending=False),
                op.topk(x, k=2, axis=-1),
            )

    @I.ir_module
    class Expected:
        @R.function
        def foo(x: R.Tensor(("seq_len", 64), dtype="float16")):
            R.func_attr({"num_input": 1})
            with R.dataflow():
                sort = R.sort(x, axis=-1, descending=True)
                argsort = R.argsort(x, axis=-1, descending=False, dtype="int32")
                topk = R.topk(x, k=2, axis=-1, ret_type="both", largest=True, dtype="int32")
                topk_0 = topk[0]
                topk_1 = topk[1]
                gv = sort, argsort, (topk_0, topk_1)
                R.output(gv)
            return gv

    foo_spec = {"x": spec.Tensor(("seq_len", 64), "float16")}
    exported, _ = Model().export_tvm(spec={"foo": foo_spec})

    tvm.ir.assert_structural_equal(exported, Expected)
| |
| |
if __name__ == "__main__":
    # Entry point when this test module is executed directly as a script.
    tvm.testing.main()