| # ruff: noqa: E402 |
| import pytest |
| |
| pytest.importorskip("tensorflow", reason="tensorflow not available") |
| |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except |
| # pylint: disable=import-outside-toplevel, redefined-builtin |
| """TFLite to Relax converter tests""" |
| |
| import os |
| |
| import flatbuffers |
| import numpy as np |
| import pytest |
| import tensorflow as tf |
| import tflite.Model |
| from tensorflow.keras import applications as keras_app |
| |
| import tvm |
| import tvm.relax.frontend.tflite.tflite_frontend as tflite_frontend |
| from tvm import relax |
| from tvm.relax.frontend.tflite import from_tflite |
| from tvm.script.parser import ir as I |
| from tvm.script.parser import relax as R |
| from tvm.script.parser import tirx as T |
| |
| |
| def _get_mod_from_cfunc(cfunc): |
| converter = tf.lite.TFLiteConverter.from_concrete_functions([cfunc]) |
| converter.target_spec.supported_ops = [ |
| tf.lite.OpsSet.TFLITE_BUILTINS, |
| tf.lite.OpsSet.SELECT_TF_OPS, |
| ] |
| |
| tflite_model_buf = converter.convert() |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| return mod |
| |
| |
| def verify(TestClass, expected=None): |
| if isinstance(TestClass, type): |
| cf = TestClass().func.get_concrete_function() |
| else: |
| cf = TestClass |
| mod = _get_mod_from_cfunc(cf) |
| |
| if expected: |
| tvm.ir.assert_structural_equal(mod, expected) |
| |
| # Run E2E test only on nightly |
| if "CI_ENV_NIGHTLY" not in os.environ: |
| return |
| |
| # Inputs |
| tf_inputs = [] |
| tvm_inputs = [] |
| for arg in mod["main"].params: |
| shape = tuple(shape_val.value for shape_val in arg.struct_info.shape.values) |
| data = np.random.uniform(0, 1, size=shape).astype(arg.struct_info.dtype) |
| tvm_inputs.append(data) |
| tf_inputs.append(tf.constant(data)) |
| |
| # TF Run |
| tf_output = cf(*tf_inputs) |
| |
| # TVM Run |
| tgt = tvm.target.Target("llvm") |
| ex = tvm.compile(mod, tgt) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| vm.set_input("main", *tvm_inputs) |
| vm.invoke_stateful("main") |
| tvm_output = vm.get_outputs("main") |
| |
| if isinstance(tf_output, tuple): |
| for tf_out, tvm_out in zip(tf_output, tvm_output): |
| np.testing.assert_allclose(tf_out.numpy(), tvm_out.numpy(), rtol=1e-5, atol=1e-5) |
| else: |
| np.testing.assert_allclose(tf_output.numpy(), tvm_output.numpy(), rtol=1e-5, atol=1e-5) |
| |
| |
| def _verify_random_with_inputs(cfunc, inputs): |
| """E2E verify random ops by shape/dtype and TVM seeded self-consistency.""" |
| if "CI_ENV_NIGHTLY" not in os.environ: |
| return |
| |
| mod = _get_mod_from_cfunc(cfunc) |
| tvm_inputs = [np.asarray(data) for data in inputs] |
| tf_inputs = [tf.constant(data) for data in tvm_inputs] |
| |
| tf_output = cfunc(*tf_inputs) |
| |
| tgt = tvm.target.Target("llvm") |
| ex = tvm.compile(mod, tgt) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| |
| def run_tvm(): |
| vm.set_input("main", *tvm_inputs) |
| vm.invoke_stateful("main") |
| return vm.get_outputs("main") |
| |
| tvm_output = run_tvm() |
| tvm_output_again = run_tvm() |
| |
| if not isinstance(tf_output, tuple): |
| tf_output = (tf_output,) |
| tvm_output = (tvm_output,) |
| tvm_output_again = (tvm_output_again,) |
| |
| for tf_out, tvm_out, tvm_out_again in zip(tf_output, tvm_output, tvm_output_again): |
| tf_np = tf_out.numpy() |
| tvm_np = tvm_out.numpy() |
| assert tvm_np.shape == tf_np.shape |
| assert tvm_np.dtype == tf_np.dtype |
| np.testing.assert_equal(tvm_np, tvm_out_again.numpy()) |
| |
| |
| def test_add_one_2d(): |
| class AddOne2D(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) |
| def func(self, x): |
| return x + 1 |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add(x, R.const(1.0, "float32")) |
| R.output(gv) |
| return gv |
| |
| verify(AddOne2D, Expected) |
| |
| |
| def test_add_n(): |
| class AddN(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y, z): |
| return tf.add_n([x, y, z]) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| z: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.add(x, y) |
| gv: R.Tensor((2, 2), dtype="float32") = R.add(lv, z) |
| R.output(gv) |
| return gv |
| |
| verify(AddN, Expected) |
| |
| |
| def test_cumsum(): |
| class Cumsum(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(3, 4), dtype=tf.float32), |
| tf.TensorSpec(shape=(5, 6), dtype=tf.int32), |
| ] |
| ) |
| def func(self, x, y): |
| out1 = tf.math.cumsum(x, axis=0) |
| out2 = tf.math.cumsum(y, axis=1, exclusive=True) |
| return out1, out2 |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((3, 4), dtype="float32"), |
| y: R.Tensor((5, 6), dtype="int32"), |
| ) -> R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((5, 6), dtype="int32")): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv1: R.Tensor((3, 4), dtype="float32") = R.cumsum( |
| x, axis=0, dtype="float32", exclusive=False |
| ) |
| gv2: R.Tensor((5, 6), dtype="int32") = R.cumsum( |
| y, axis=1, dtype="int32", exclusive=True |
| ) |
| gv = (gv1, gv2) |
| R.output(gv) |
| return gv |
| |
| verify(Cumsum, Expected) |
| |
| |
| def test_split(): |
| class Split(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| a, b, c = tf.split(x, 3, axis=1) |
| return tf.raw_ops.Pack(values=[a, b, c], axis=1) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 3, 10), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 10), dtype="float32"), |
| R.Tensor((1, 10), dtype="float32"), |
| R.Tensor((1, 10), dtype="float32"), |
| ) = R.split(x, indices_or_sections=3, axis=1) |
| lv1: R.Tensor((1, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((1, 1, 10), dtype="float32") = R.expand_dims(lv1, axis=[1]) |
| lv3: R.Tensor((1, 10), dtype="float32") = lv[1] |
| lv4: R.Tensor((1, 1, 10), dtype="float32") = R.expand_dims(lv3, axis=[1]) |
| lv5: R.Tensor((1, 10), dtype="float32") = lv[2] |
| lv6: R.Tensor((1, 1, 10), dtype="float32") = R.expand_dims(lv5, axis=[1]) |
| gv: R.Tensor((1, 3, 10), dtype="float32") = R.concat((lv2, lv4, lv6), axis=1) |
| R.output(gv) |
| return gv |
| |
| verify(Split, Expected) |
| |
| |
| def test_split_v_dynamic(): |
| """SPLIT_V with runtime split sizes imports shape-aware Relax IR.""" |
| |
| class TfSplitVDynamic(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(10,), dtype=tf.float32), |
| tf.TensorSpec(shape=(3,), dtype=tf.int32), |
| ] |
| ) |
| def func(self, x, size_splits): |
| return tf.split(x, size_splits, axis=0) |
| |
| cf = TfSplitVDynamic().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| assert "R.dynamic_strided_slice" in ir |
| assert "R.scatter_elements" in ir |
| |
| |
| def test_split_v_static(): |
| """SPLIT_V with static unequal size_splits lowers to Relax split.""" |
| |
| class SplitVUnequal(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 10, 4), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.split(x, [2, 3, 5], axis=1) |
| |
| @I.ir_module |
| class ExpectedUnequal: |
| @R.function |
| def main(x: R.Tensor((2, 10, 4), dtype="float32")) -> R.Tuple( |
| R.Tensor((2, 2, 4), dtype="float32"), |
| R.Tensor((2, 3, 4), dtype="float32"), |
| R.Tensor((2, 5, 4), dtype="float32"), |
| ): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((2, 2, 4), dtype="float32"), |
| R.Tensor((2, 3, 4), dtype="float32"), |
| R.Tensor((2, 5, 4), dtype="float32"), |
| ) = R.split(x, indices_or_sections=[2, 5], axis=1) |
| lv1: R.Tensor((2, 2, 4), dtype="float32") = lv[0] |
| lv2: R.Tensor((2, 3, 4), dtype="float32") = lv[1] |
| lv3: R.Tensor((2, 5, 4), dtype="float32") = lv[2] |
| gv: R.Tuple( |
| R.Tensor((2, 2, 4), dtype="float32"), |
| R.Tensor((2, 3, 4), dtype="float32"), |
| R.Tensor((2, 5, 4), dtype="float32"), |
| ) = lv1, lv2, lv3 |
| R.output(gv) |
| return gv |
| |
| verify(SplitVUnequal, ExpectedUnequal) |
| |
| |
| def test_pack(): |
| class Pack(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y): |
| return tf.raw_ops.Pack(values=[x, y], axis=0) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), dtype="float32"), |
| y: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tensor((2, 2, 3), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(x, axis=[0]) |
| lv1: R.Tensor((1, 2, 3), dtype="float32") = R.expand_dims(y, axis=[0]) |
| gv: R.Tensor((2, 2, 3), dtype="float32") = R.concat((lv, lv1), axis=0) |
| R.output(gv) |
| return gv |
| |
| verify(Pack, Expected) |
| |
| |
| def test_cast(): |
| class Cast(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.cast(x, tf.int32) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="int32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="int32") = R.astype(x, dtype="int32") |
| R.output(gv) |
| return gv |
| |
| verify(Cast, Expected) |
| |
| |
| def test_bitcast_float32_to_int32(): |
| """BITCAST same-width: float32 -> int32, shape preserved.""" |
| |
| class BitcastF32ToI32(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.bitcast(x, tf.int32) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="int32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="int32") = R.memory.view( |
| x, R.shape([1, 30]), R.dtype("int32") |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(BitcastF32ToI32, Expected) |
| |
| |
| def test_bitcast_uint8_to_int8(): |
| """BITCAST same-width 8-bit: uint8 -> int8.""" |
| |
| class BitcastU8ToI8(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(4,), dtype=tf.uint8)]) |
| def func(self, x): |
| return tf.bitcast(x, tf.int8) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((4,), dtype="uint8")) -> R.Tensor((4,), dtype="int8"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((4,), dtype="int8") = R.memory.view(x, R.shape([4]), R.dtype("int8")) |
| R.output(gv) |
| return gv |
| |
| verify(BitcastU8ToI8, Expected) |
| |
| |
| def test_bitcast_int32_to_int16_widens_shape(): |
| """BITCAST width-changing (smaller): int32[3] -> int16[3, 2].""" |
| |
| class BitcastI32ToI16(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(3,), dtype=tf.int32)]) |
| def func(self, x): |
| return tf.bitcast(x, tf.int16) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 2), dtype="int16"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((3, 2), dtype="int16") = R.memory.view( |
| x, R.shape([3, 2]), R.dtype("int16") |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(BitcastI32ToI16, Expected) |
| |
| |
| def test_bitcast_int16_to_int32_collapses_shape(): |
| """BITCAST width-changing (larger): int16[5, 2] -> int32[5].""" |
| |
| class BitcastI16ToI32(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(5, 2), dtype=tf.int16)]) |
| def func(self, x): |
| return tf.bitcast(x, tf.int32) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((5, 2), dtype="int16")) -> R.Tensor((5,), dtype="int32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((5,), dtype="int32") = R.memory.view(x, R.shape([5]), R.dtype("int32")) |
| R.output(gv) |
| return gv |
| |
| verify(BitcastI16ToI32, Expected) |
| |
| |
| def test_expand_dims(): |
| class ExpandDims(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.expand_dims(x, axis=2) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30, 1), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30, 1), dtype="float32") = R.reshape(x, R.shape([1, 30, 1])) |
| R.output(gv) |
| return gv |
| |
| verify(ExpandDims, Expected) |
| |
| |
| def test_transpose(): |
| class Transpose(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| x = tf.expand_dims(x, axis=2) |
| return tf.transpose(x, perm=[0, 2, 1]) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 1, 30), dtype="float32") = R.reshape(x, R.shape([1, 1, 30])) |
| R.output(gv) |
| return gv |
| |
| verify(Transpose, Expected) |
| |
| |
| def test_reshape(): |
| class Reshape(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.reshape(x, (1, 2, 15)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 2, 15), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 2, 15), dtype="float32") = R.reshape(x, R.shape([1, 2, 15])) |
| R.output(gv) |
| return gv |
| |
| verify(Reshape, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "input_shape, out_type", |
| [ |
| ((2, 3, 4), tf.int32), |
| ((5,), tf.int64), |
| ((1, 1, 1, 1), tf.int32), |
| ((), tf.int32), |
| ((0, 3), tf.int64), |
| ], |
| ) |
| def test_shape(input_shape, out_type): |
| """SHAPE conversion for static-rank non-quantized tensors.""" |
| |
| class Shape(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) |
| def func(self, x): |
| return tf.shape(x, out_type=out_type) |
| |
| verify(Shape) |
| |
| |
| def test_shape_dynamic_dim(): |
| """SHAPE conversion with a dynamic input dimension.""" |
| |
| class ShapeDynamic(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.shape(x, out_type=tf.int32) |
| |
| verify(ShapeDynamic) |
| |
| |
| @pytest.mark.parametrize( |
| "start, limit, delta, dtype", |
| [ |
| (0, 8, 2, tf.int32), |
| (1, 9, 2, tf.int64), |
| (0.0, 1.0, 0.2, tf.float32), |
| (8, 0, -2, tf.int32), |
| (0, 0, 1, tf.int32), |
| (0, 7, 2, tf.int32), |
| (0.0, -1.0, -0.25, tf.float32), |
| ], |
| ) |
| def test_range(start, limit, delta, dtype): |
| """RANGE conversion with non-quantized constant scalar bounds.""" |
| |
| class Range(tf.Module): |
| @tf.function(input_signature=[]) |
| def func(self): |
| return tf.range(start, limit, delta, dtype=dtype) |
| |
| verify(Range) |
| |
| |
| def test_range_dynamic_scalar_inputs_not_supported(): |
| """RANGE conversion currently rejects dynamic scalar inputs.""" |
| |
| class RangeDynamic(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(), dtype=tf.int32), |
| tf.TensorSpec(shape=(), dtype=tf.int32), |
| tf.TensorSpec(shape=(), dtype=tf.int32), |
| ] |
| ) |
| def func(self, start, limit, delta): |
| return tf.range(start, limit, delta, dtype=tf.int32) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"): |
| verify(RangeDynamic) |
| |
| |
| def test_tile_ir(): |
| """TILE conversion with explicit Relax IR structural check.""" |
| |
| class Tile(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.tile(x, [2, 1]) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 3), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((4, 3), dtype="float32") = R.tile(x, repeats=[2, 1]) |
| R.output(gv) |
| return gv |
| |
| verify(Tile, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "input_shape, multiples, dtype", |
| [ |
| ((2, 3), [2, 1], tf.float32), |
| ((1, 4, 2), [3, 1, 2], tf.float32), |
| ((2, 1, 3, 1), [1, 2, 1, 4], tf.float32), |
| ((2, 3), [1, 1], tf.float32), |
| ((3,), [2], tf.float32), |
| ((2, 3), [4, 2], tf.float32), |
| ((2, 2), [1, 3], tf.int32), |
| ], |
| ) |
| def test_tile(input_shape, multiples, dtype): |
| """TILE conversion for non-quantized input and repeat factors.""" |
| |
| class Tile(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=dtype)]) |
| def func(self, x): |
| return tf.tile(x, multiples) |
| |
| verify(Tile) |
| |
| |
| def test_concat_v2(): |
| class ConcatV2(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| a, b, c = tf.split(x, 3, axis=1) |
| axis = tf.add(tf.constant(1, dtype="int32"), tf.constant(0, dtype="int32")) |
| return tf.raw_ops.ConcatV2(values=[a, b, c], axis=axis) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 10), dtype="float32"), |
| R.Tensor((1, 10), dtype="float32"), |
| R.Tensor((1, 10), dtype="float32"), |
| ) = R.split(x, indices_or_sections=3, axis=1) |
| lv1: R.Tensor((1, 10), dtype="float32") = lv[0] |
| lv2: R.Tensor((1, 10), dtype="float32") = lv[1] |
| lv3: R.Tensor((1, 10), dtype="float32") = lv[2] |
| gv: R.Tensor((1, 30), dtype="float32") = R.concat((lv1, lv2, lv3), axis=1) |
| R.output(gv) |
| return gv |
| |
| verify(ConcatV2, Expected) |
| |
| |
| def test_multi_output(): |
| class MultiOutput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) |
| def func(self, x): |
| y = 2 * x |
| return x, y |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.multiply(x, R.const(2.0, "float32")) |
| gv: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = (x, lv) |
| R.output(gv) |
| return gv |
| |
| verify(MultiOutput, Expected) |
| |
| |
| def test_elu(): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.elu(x) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 30), dtype="float32") = R.exp(x) |
| lv1: R.Tensor((1, 30), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv) |
| lv2: R.Tensor((1, 30), dtype="float32") = R.nn.relu(lv1) |
| lv3: R.Tensor((1, 30), dtype="float32") = R.multiply(R.const(-1.0, "float32"), lv2) |
| lv4: R.Tensor((1, 30), dtype="float32") = R.nn.relu(x) |
| gv: R.Tensor((1, 30), dtype="float32") = R.add(lv3, lv4) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_gelu(): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.gelu(x) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 30), dtype="float32") = R.multiply( |
| x, R.const(0.70710676908493042, "float32") |
| ) |
| lv1: R.Tensor((1, 30), dtype="float32") = R.erf(lv) |
| lv2: R.Tensor((1, 30), dtype="float32") = R.multiply(lv1, R.const(0.5, "float32")) |
| lv3: R.Tensor((1, 30), dtype="float32") = R.add(R.const(0.5, "float32"), lv2) |
| gv: R.Tensor((1, 30), dtype="float32") = R.multiply(x, lv3) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_swish(): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.swish(x) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 30), dtype="float32") = R.sigmoid(x) |
| gv: R.Tensor((1, 30), dtype="float32") = R.multiply(x, lv) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_prelu_constant_alpha(): |
| alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, dtype=np.float32)) |
| prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init) |
| |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return prelu(x) |
| |
| verify(TfInput) |
| |
| |
| def test_fill(): |
| class TfInput(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(1, 30), dtype=tf.float32), |
| tf.TensorSpec(shape=(), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y): |
| fill_out = tf.fill((1, 30), y) |
| return x + fill_out |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 30), dtype="float32"), y: R.Tensor((), dtype="float32") |
| ) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="float32") = R.add(x, y) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_fill_dynamic_dims(): |
| """FILL with runtime dims legalizes and compiles.""" |
| |
| class TfFillDynamic(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2,), dtype=tf.int32), |
| tf.TensorSpec(shape=(), dtype=tf.float32), |
| ] |
| ) |
| def func(self, dims, value): |
| return tf.fill(dims, value) |
| |
| cf = TfFillDynamic().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| assert "R.tensor_to_shape" in ir |
| assert "R.full" in ir |
| tvm.compile(mod, tvm.target.Target("llvm")) |
| verify(cf) |
| |
| |
| def test_random_uniform_dynamic_shape(): |
| """RANDOM_UNIFORM imports dynamic shape and validates random output metadata.""" |
| |
| class TfRandomUniform(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2,), dtype=tf.int32)]) |
| def func(self, shape): |
| return tf.raw_ops.RandomUniform(shape=shape, dtype=tf.float32, seed=7, seed2=11) |
| |
| cf = TfRandomUniform().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| assert "R.tensor_to_shape" in ir |
| assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir |
| |
| _verify_random_with_inputs(cf, [np.array([2, 3], dtype="int32")]) |
| |
| |
| def test_random_standard_normal_dynamic_shape(): |
| """RANDOM_STANDARD_NORMAL imports dynamic shape and validates random output metadata.""" |
| |
| class TfRandomStandardNormal(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2,), dtype=tf.int32)]) |
| def func(self, shape): |
| return tf.raw_ops.RandomStandardNormal(shape=shape, dtype=tf.float32, seed=3, seed2=5) |
| |
| cf = TfRandomStandardNormal().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| assert "R.tensor_to_shape" in ir |
| assert 'R.call_dps_packed("tvm.contrib.random.normal"' in ir |
| |
| _verify_random_with_inputs(cf, [np.array([2, 4], dtype="int32")]) |
| |
| |
| def test_multinomial_dynamic_num_samples(): |
| """MULTINOMIAL lowers through seeded uniform sampling with dynamic num_samples.""" |
| |
| class TfMultinomial(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| tf.TensorSpec(shape=(), dtype=tf.int32), |
| ] |
| ) |
| def func(self, logits, num_samples): |
| return tf.raw_ops.Multinomial( |
| logits=logits, |
| num_samples=num_samples, |
| output_dtype=tf.int64, |
| seed=13, |
| seed2=17, |
| ) |
| |
| cf = TfMultinomial().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| assert "R.nn.softmax" in ir |
| assert "R.multinomial_from_uniform" in ir |
| assert "R.tensor_to_shape" in ir |
| assert "multinomial_num_samples" in ir |
| assert 'R.call_dps_packed("tvm.contrib.random.uniform"' in ir |
| |
| _verify_random_with_inputs( |
| cf, |
| [ |
| np.array([[2.0, 1.0, 0.5], [0.1, 0.2, 3.0]], dtype="float32"), |
| np.array(4, dtype="int32"), |
| ], |
| ) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.add, R.add), |
| (tf.subtract, R.subtract), |
| (tf.multiply, R.multiply), |
| (tf.divide, R.divide), |
| (tf.math.floormod, R.floor_mod), |
| (tf.math.floordiv, R.floor_divide), |
| (tf.math.atan2, R.atan2), |
| ], |
| ) |
| def test_binary(tf_op, relax_op): |
| class Binary(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y): |
| return tf_op(x, y) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), y: R.Tensor((2, 2), dtype="float32") |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y) |
| R.output(gv) |
| return gv |
| |
| verify(Binary, Expected) |
| |
| |
| def test_pow(): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.math.pow(x, 4) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="float32") = R.power(x, R.const(4.0, "float32")) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_square(): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.math.square(x) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="float32") = R.power(x, R.const(2.0, "float32")) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_broadcast_args(): |
| class TfInput(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(3,), dtype=tf.int32), |
| tf.TensorSpec(shape=(3,), dtype=tf.int32), |
| ] |
| ) |
| def func(self, s0, s1): |
| return tf.broadcast_dynamic_shape(s0, s1) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(s0: R.Tensor((3,), dtype="int32"), s1: R.Tensor((3,), dtype="int32")) -> R.Tensor( |
| (3,), dtype="int32" |
| ): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((0,), dtype="int32") = R.full( |
| R.shape([0]), R.const(1, "int32"), dtype="int32" |
| ) |
| lv1: R.Tensor((3,), dtype="int32") = R.concat((lv, s0), axis=0) |
| lv2: R.Tensor((3,), dtype="bool") = R.equal(lv1, R.const(1, "int32")) |
| lv3: R.Tensor((0,), dtype="int32") = R.full( |
| R.shape([0]), R.const(1, "int32"), dtype="int32" |
| ) |
| lv4: R.Tensor((3,), dtype="int32") = R.concat((lv3, s1), axis=0) |
| lv5: R.Tensor((3,), dtype="bool") = R.equal(lv4, R.const(1, "int32")) |
| lv6: R.Tensor((3,), dtype="int32") = R.maximum(lv1, lv4) |
| lv7: R.Tensor((3,), dtype="int32") = R.where(lv5, lv1, lv6) |
| gv: R.Tensor((3,), dtype="int32") = R.where(lv2, lv4, lv7) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_broadcast_args_diff_length(): |
| """BROADCAST_ARGS with shape inputs of different lengths.""" |
| |
| class TfInput(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(1,), dtype=tf.int32), |
| tf.TensorSpec(shape=(3,), dtype=tf.int32), |
| ] |
| ) |
| def func(self, s0, s1): |
| return tf.broadcast_dynamic_shape(s0, s1) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(s0: R.Tensor((1,), dtype="int32"), s1: R.Tensor((3,), dtype="int32")) -> R.Tensor( |
| (3,), dtype="int32" |
| ): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2,), dtype="int32") = R.full( |
| R.shape([2]), R.const(1, "int32"), dtype="int32" |
| ) |
| lv1: R.Tensor((3,), dtype="int32") = R.concat((lv, s0), axis=0) |
| lv2: R.Tensor((3,), dtype="bool") = R.equal(lv1, R.const(1, "int32")) |
| lv3: R.Tensor((0,), dtype="int32") = R.full( |
| R.shape([0]), R.const(1, "int32"), dtype="int32" |
| ) |
| lv4: R.Tensor((3,), dtype="int32") = R.concat((lv3, s1), axis=0) |
| lv5: R.Tensor((3,), dtype="bool") = R.equal(lv4, R.const(1, "int32")) |
| lv6: R.Tensor((3,), dtype="int32") = R.maximum(lv1, lv4) |
| lv7: R.Tensor((3,), dtype="int32") = R.where(lv5, lv1, lv6) |
| gv: R.Tensor((3,), dtype="int32") = R.where(lv2, lv4, lv7) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.nn.relu, R.nn.relu), |
| (tf.nn.relu6, R.nn.relu6), |
| (tf.math.floor, R.floor), |
| (tf.math.ceil, R.ceil), |
| (tf.math.tanh, R.tanh), |
| (tf.math.sigmoid, R.sigmoid), |
| (tf.math.abs, R.abs), |
| (tf.math.cos, R.cos), |
| (tf.math.sin, R.sin), |
| (tf.math.exp, R.exp), |
| (tf.math.log, R.log), |
| (tf.math.negative, R.negative), |
| (tf.round, R.round), |
| (tf.math.rsqrt, R.rsqrt), |
| (tf.nn.softmax, R.nn.softmax), |
| (tf.math.sqrt, R.sqrt), |
| (tf.nn.log_softmax, R.nn.log_softmax), |
| ], |
| ) |
| def test_element_wise(tf_op, relax_op): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf_op(x) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="float32") = relax_op(x) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.math.less, R.less), |
| (tf.math.less_equal, R.less_equal), |
| (tf.math.greater, R.greater), |
| (tf.math.greater_equal, R.greater_equal), |
| (tf.math.equal, R.equal), |
| (tf.math.not_equal, R.not_equal), |
| ], |
| ) |
| def test_split_compare(tf_op, relax_op): |
| class Compare(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| a, b = tf.split(x, 2, axis=1) |
| return tf_op(a, b, name=None) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 15), dtype="bool"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 15), dtype="float32"), |
| R.Tensor((1, 15), dtype="float32"), |
| ) = R.split(x, indices_or_sections=2, axis=1) |
| lv1: R.Tensor((1, 15), dtype="float32") = lv[0] |
| lv2: R.Tensor((1, 15), dtype="float32") = lv[1] |
| gv: R.Tensor((1, 15), dtype="bool") = relax_op(lv1, lv2) |
| R.output(gv) |
| return gv |
| |
| verify(Compare, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.math.logical_not, R.logical_not), |
| ], |
| ) |
| def test_logical_unary(tf_op, relax_op): |
| class Logical(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 2), dtype=tf.bool), |
| ] |
| ) |
| def func(self, x): |
| return tf_op(x) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="bool"), |
| ) -> R.Tensor((2, 2), dtype="bool"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="bool") = relax_op(x) |
| R.output(gv) |
| return gv |
| |
| verify(Logical, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.math.logical_or, R.logical_or), |
| (tf.math.logical_and, R.logical_and), |
| ], |
| ) |
| def test_logical(tf_op, relax_op): |
| class Logical(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 2), dtype=tf.bool), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.bool), |
| ] |
| ) |
| def func(self, x, y): |
| return tf_op(x, y) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="bool"), y: R.Tensor((2, 2), dtype="bool")) -> R.Tensor( |
| (2, 2), dtype="bool" |
| ): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="bool") = relax_op(x, y) |
| R.output(gv) |
| return gv |
| |
| verify(Logical, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.add, R.add), |
| (tf.subtract, R.subtract), |
| (tf.multiply, R.multiply), |
| (tf.divide, R.divide), |
| (tf.math.floormod, R.floor_mod), |
| (tf.math.maximum, R.maximum), |
| (tf.math.minimum, R.minimum), |
| ], |
| ) |
| def test_split_binary(tf_op, relax_op): |
| class Binary(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| a, b = tf.split(x, 2, axis=1) |
| return tf_op(a, b, name=None) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 15), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((1, 15), dtype="float32"), |
| R.Tensor((1, 15), dtype="float32"), |
| ) = R.split(x, indices_or_sections=2, axis=1) |
| lv1: R.Tensor((1, 15), dtype="float32") = lv[0] |
| lv2: R.Tensor((1, 15), dtype="float32") = lv[1] |
| gv: R.Tensor((1, 15), dtype="float32") = relax_op(lv1, lv2) |
| R.output(gv) |
| return gv |
| |
| verify(Binary, Expected) |
| |
| |
| def test_squared_difference(): |
| class SquaredDifference(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y): |
| return tf.math.squared_difference(x, y) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") |
| ) -> R.Tensor((2, 3), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 3), dtype="float32") = R.subtract(x, y) |
| gv: R.Tensor((2, 3), dtype="float32") = R.power(lv, R.const(2.0, "float32")) |
| R.output(gv) |
| return gv |
| |
| verify(SquaredDifference, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op, axis, out_shape", |
| [ |
| (tf.math.argmax, R.argmax, 0, (30,)), |
| (tf.math.argmin, R.argmin, 1, (5,)), |
| ], |
| ) |
| def test_reduce(tf_op, relax_op, axis, out_shape): |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(5, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf_op(x, axis=axis) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((5, 30), dtype="float32")) -> R.Tensor(out_shape, dtype="int64"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor(out_shape, dtype="int64") = relax_op(x, axis=axis, keepdims=False) |
| R.output(gv) |
| return gv |
| |
| verify(TfInput, Expected) |
| |
| |
| def test_fully_connected(): |
| class FullyConnected(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8), dtype=tf.float32)]) |
| def func(self, x): |
| weight = tf.constant(np.arange(24, dtype=np.float32).reshape((3, 8))) |
| bias = tf.constant(np.array([0.5, 1.0, -1.0], dtype=np.float32)) |
| out = tf.matmul(x, weight, transpose_b=True) |
| return tf.nn.bias_add(out, bias) |
| |
| verify(FullyConnected) |
| |
| |
| def test_depthwise_conv2d(): |
| class DepthwiseConv2D(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(3, 3, 2, 1), dtype=tf.float32), |
| ] |
| ) |
| def func(self, data, kernel): |
| return tf.nn.depthwise_conv2d( |
| input=data, |
| filter=kernel, |
| strides=[1, 1, 1, 1], |
| padding="SAME", |
| ) |
| |
| verify(DepthwiseConv2D) |
| |
| |
| def test_transpose_conv(): |
| class TransposeConv(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(3, 3, 3, 2), dtype=tf.float32), |
| ] |
| ) |
| def func(self, data, kernel): |
| output_shape = tf.constant([1, 8, 8, 3], dtype=tf.int32) |
| return tf.nn.conv2d_transpose( |
| input=data, |
| filters=kernel, |
| output_shape=output_shape, |
| strides=[1, 1, 1, 1], |
| padding="SAME", |
| ) |
| |
| verify(TransposeConv) |
| |
| |
| def test_l2_pool2d(): |
| class L2Pool2D(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)]) |
| def func(self, data): |
| squared = tf.math.square(data) |
| pooled = tf.nn.avg_pool2d(squared, ksize=[2, 2], strides=[1, 1], padding="SAME") |
| return tf.math.sqrt(pooled) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(data: R.Tensor((1, 8, 8, 2), dtype="float32")) -> R.Tensor( |
| (1, 8, 8, 2), dtype="float32" |
| ): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| squared = R.power(data, R.const(2.0, "float32")) |
| pooled = R.nn.avg_pool2d( |
| squared, |
| pool_size=[2, 2], |
| strides=[1, 1], |
| padding=[0, 0, 1, 1], |
| layout="NHWC", |
| ) |
| gv = R.sqrt(pooled) |
| R.output(gv) |
| return gv |
| |
| verify(L2Pool2D, Expected) |
| |
| |
| def test_l2_normalization(): |
| class L2Normalization(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.l2_normalize(x, axis=-1) |
| |
| verify(L2Normalization) |
| |
| |
| def test_local_response_normalization(): |
| class LocalResponseNormalization(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 4), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.local_response_normalization( |
| x, |
| depth_radius=2, |
| bias=1.0, |
| alpha=1e-4, |
| beta=0.75, |
| ) |
| |
| verify(LocalResponseNormalization) |
| |
| |
| def test_slice(): |
| class Slice(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.slice(x, begin=[1, 1], size=[2, 2]) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.strided_slice( |
| x, axes=[0, 1], begin=[1, 1], end=[3, 3] |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Slice, Expected) |
| |
| |
| def test_strided_slice_stride(): |
| class StridedSliceStride(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(4, 6), dtype=tf.float32)]) |
| def func(self, x): |
| return x[0:2, 1:5:2] |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((4, 6), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.strided_slice( |
| x, |
| axes=[0, 1], |
| begin=[0, 1], |
| end=[2, 5], |
| strides=[1, 2], |
| assume_inbound=False, |
| ) |
| gv: R.Tensor((2, 2), dtype="float32") = R.reshape(lv, R.shape([2, 2])) |
| R.output(gv) |
| return gv |
| |
| verify(StridedSliceStride, Expected) |
| |
| |
| def test_strided_slice_negative_stride(): |
| class StridedSliceNegativeStride(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(4,), dtype=tf.float32)]) |
| def func(self, x): |
| return x[::-1] |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((4,), dtype="float32") = R.strided_slice( |
| x, axes=[0], begin=[4], end=[-5], strides=[-1], assume_inbound=False |
| ) |
| gv: R.Tensor((4,), dtype="float32") = R.reshape(lv, R.shape([4])) |
| R.output(gv) |
| return gv |
| |
| verify(StridedSliceNegativeStride, Expected) |
| |
| |
| def test_reverse_v2(): |
| class ReverseV2(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.reverse(x, axis=[1]) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 3), dtype="float32") = R.flip(x, axis=1) |
| R.output(gv) |
| return gv |
| |
| verify(ReverseV2, Expected) |
| |
| |
| def test_gather(): |
| class Gather(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32), |
| tf.TensorSpec(shape=(2,), dtype=tf.int64), |
| ] |
| ) |
| def func(self, x, indices): |
| return tf.gather(x, indices, axis=1) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 3, 4), dtype="float32"), |
| indices: R.Tensor((2,), dtype="int64"), |
| ) -> R.Tensor((2, 2, 4), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2,), dtype="int32") = R.astype(indices, dtype="int32") |
| gv: R.Tensor((2, 2, 4), dtype="float32") = R.take(x, lv, axis=1, mode="fast") |
| R.output(gv) |
| return gv |
| |
| verify(Gather, Expected) |
| |
| |
| def test_gather_nd(): |
| class GatherND(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.int32), |
| ] |
| ) |
| def func(self, x, indices): |
| return tf.gather_nd(x, indices) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 3, 4), dtype="float32"), |
| indices: R.Tensor((2, 2), dtype="int32"), |
| ) -> R.Tensor((2, 4), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="int32") = R.permute_dims(indices, axes=[-1, 0]) |
| lv1: R.Tensor((2, 2), dtype="int64") = R.astype(lv, dtype="int64") |
| gv: R.Tensor((2, 4), dtype="float32") = R.gather_nd(x, lv1, batch_dims=0) |
| R.output(gv) |
| return gv |
| |
| verify(GatherND, Expected) |
| |
| |
| def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding): |
| class Conv2DModule(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=data_shape, dtype=tf.float32), |
| tf.TensorSpec(shape=kernel_shape, dtype=tf.float32), |
| ] |
| ) |
| def func(self, data, kernel): |
| return tf.nn.conv2d( |
| input=data, |
| filters=kernel, |
| data_format=data_format, |
| strides=strides, |
| padding=padding, |
| ) |
| |
| return Conv2DModule |
| |
| |
| def test_conv2d_same(): |
| Conv2DModule = _make_conv2d_module( |
| (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME" |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 128, 128, 32), dtype="float32"), |
| kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), |
| ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( |
| kernel, axes=[3, 0, 1, 2] |
| ) |
| lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( |
| lv, axes=[1, 2, 3, 0] |
| ) |
| lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d( |
| data, |
| lv1, |
| strides=[1, 1], |
| padding=[1, 1, 1, 1], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add( |
| lv2, R.const(np.zeros((32,), dtype="float32")) |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Conv2DModule, Expected) |
| |
| |
| def test_conv2d_valid(): |
| Conv2DModule = _make_conv2d_module( |
| (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID" |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 128, 128, 32), dtype="float32"), |
| kernel: R.Tensor((3, 3, 32, 32), dtype="float32"), |
| ) -> R.Tensor((1, 126, 126, 32), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims( |
| kernel, axes=[3, 0, 1, 2] |
| ) |
| lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims( |
| lv, axes=[1, 2, 3, 0] |
| ) |
| lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d( |
| data, |
| lv1, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add( |
| lv2, R.const(np.zeros((32,), dtype="float32")) |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Conv2DModule, Expected) |
| |
| |
| def _make_conv3d_module(data_shape, kernel_shape, strides, padding): |
| class Conv3DModule(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=data_shape, dtype=tf.float32), |
| tf.TensorSpec(shape=kernel_shape, dtype=tf.float32), |
| ] |
| ) |
| def func(self, data, kernel): |
| return tf.nn.conv3d( |
| input=data, |
| filters=kernel, |
| strides=strides, |
| padding=padding, |
| ) |
| |
| return Conv3DModule |
| |
| |
| def test_conv3d_valid(): |
| Conv3DModule = _make_conv3d_module((1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "VALID") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"), |
| kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"), |
| ) -> R.Tensor((1, 6, 6, 6, 16), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 6, 6, 6, 16), dtype="float32") = R.nn.conv3d( |
| data, |
| kernel, |
| strides=[1, 1, 1], |
| padding=[0, 0, 0, 0, 0, 0], |
| dilation=[1, 1, 1], |
| groups=1, |
| data_layout="NDHWC", |
| kernel_layout="DHWIO", |
| out_layout="NDHWC", |
| out_dtype="void", |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Conv3DModule, Expected) |
| |
| |
| def test_conv3d_same(): |
| Conv3DModule = _make_conv3d_module((1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "SAME") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"), |
| kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"), |
| ) -> R.Tensor((1, 8, 8, 8, 16), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 8, 8, 8, 16), dtype="float32") = R.nn.conv3d( |
| data, |
| kernel, |
| strides=[1, 1, 1], |
| padding=[1, 1, 1, 1, 1, 1], |
| dilation=[1, 1, 1], |
| groups=1, |
| data_layout="NDHWC", |
| kernel_layout="DHWIO", |
| out_layout="NDHWC", |
| out_dtype="void", |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Conv3DModule, Expected) |
| |
| |
| def _make_conv3d_transpose_module(data_shape, kernel_shape, strides, padding): |
| # Compute the expected output_shape for tf.nn.conv3d_transpose. |
| # data_shape: (N, D, H, W, C_in), kernel_shape: (KD, KH, KW, C_out, C_in) |
| # strides: (1, sD, sH, sW, 1) |
| batch = data_shape[0] |
| out_channels = kernel_shape[3] |
| out_spatial = [] |
| for i in range(3): # D, H, W |
| in_size = data_shape[1 + i] |
| k_size = kernel_shape[i] |
| s = strides[1 + i] |
| if padding == "VALID": |
| out_spatial.append((in_size - 1) * s + k_size) |
| else: # SAME |
| out_spatial.append(in_size * s) |
| computed_output_shape = [batch, *out_spatial, out_channels] |
| |
| class Conv3DTransposeModule(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=data_shape, dtype=tf.float32), |
| tf.TensorSpec(shape=kernel_shape, dtype=tf.float32), |
| ] |
| ) |
| def func(self, data, kernel): |
| return tf.nn.conv3d_transpose( |
| input=data, |
| filters=kernel, |
| output_shape=computed_output_shape, |
| strides=strides, |
| padding=padding, |
| ) |
| |
| return Conv3DTransposeModule |
| |
| |
| def test_conv3d_transpose_valid(): |
| Conv3DTransposeModule = _make_conv3d_transpose_module( |
| (1, 8, 8, 8, 3), (3, 3, 3, 8, 3), (1, 1, 1, 1, 1), "VALID" |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"), |
| kernel: R.Tensor((3, 3, 3, 8, 3), dtype="float32"), |
| ) -> R.Tensor((1, 10, 10, 10, 8), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 10, 10, 10, 8), dtype="float32") = R.nn.conv3d_transpose( |
| data, |
| kernel, |
| strides=[1, 1, 1], |
| padding=[0, 0, 0, 0, 0, 0], |
| output_padding=[0, 0, 0], |
| dilation=[1, 1, 1], |
| groups=1, |
| data_layout="NDHWC", |
| kernel_layout="DHWOI", |
| out_layout="NDHWC", |
| out_dtype="void", |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Conv3DTransposeModule, Expected) |
| |
| |
| def test_conv3d_transpose_same(): |
| Conv3DTransposeModule = _make_conv3d_transpose_module( |
| (1, 8, 8, 8, 3), (3, 3, 3, 8, 3), (1, 1, 1, 1, 1), "SAME" |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"), |
| kernel: R.Tensor((3, 3, 3, 8, 3), dtype="float32"), |
| ) -> R.Tensor((1, 8, 8, 8, 8), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 8, 8, 8, 8), dtype="float32") = R.nn.conv3d_transpose( |
| data, |
| kernel, |
| strides=[1, 1, 1], |
| padding=[1, 1, 1, 1, 1, 1], |
| output_padding=[0, 0, 0], |
| dilation=[1, 1, 1], |
| groups=1, |
| data_layout="NDHWC", |
| kernel_layout="DHWOI", |
| out_layout="NDHWC", |
| out_dtype="void", |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Conv3DTransposeModule, Expected) |
| |
| |
| def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding): |
| class Pool2DModule(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=data_shape, dtype=tf.float32), |
| ] |
| ) |
| def func(self, data): |
| return pool( |
| input=data, |
| ksize=ksize, |
| data_format=data_format, |
| strides=strides, |
| padding=padding, |
| ) |
| |
| return Pool2DModule |
| |
| |
| def test_avg_pool2d_same(): |
| Pool2DModule = _make_pool2d_module( |
| tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME" |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 128, 128, 32), dtype="float32"), |
| ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d( |
| data, |
| pool_size=[2, 2], |
| strides=[1, 1], |
| dilation=[1, 1], |
| padding=[0, 0, 1, 1], |
| ceil_mode=False, |
| count_include_pad=False, |
| layout="NHWC", |
| out_layout="NHWC", |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Pool2DModule, Expected) |
| |
| |
| def test_avg_pool2d_valid(): |
| Pool2DModule = _make_pool2d_module( |
| tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID" |
| ) |
| verify(Pool2DModule) |
| |
| |
| def test_max_pool2d_same(): |
| Pool2DModule = _make_pool2d_module( |
| tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME" |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 128, 128, 32), dtype="float32"), |
| ) -> R.Tensor((1, 128, 128, 32), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d( |
| data, |
| pool_size=[2, 2], |
| strides=[1, 1], |
| dilation=[1, 1], |
| padding=[0, 0, 1, 1], |
| ceil_mode=False, |
| layout="NHWC", |
| out_layout="NHWC", |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Pool2DModule, Expected) |
| |
| |
| def test_max_pool2d_valid(): |
| Pool2DModule = _make_pool2d_module( |
| tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID" |
| ) |
| verify(Pool2DModule) |
| |
| |
| @pytest.mark.parametrize( |
| "net, shape", |
| [ |
| # Limiting the tests for CI |
| (keras_app.Xception, (1, 299, 299, 3)), |
| # (keras_app.VGG16, (1, 224, 224, 3)), |
| # (keras_app.VGG19, (1, 224, 224, 3)), |
| (keras_app.ResNet50, (1, 224, 224, 3)), |
| # (keras_app.ResNet50V2, (1, 224, 224, 3)), |
| # (keras_app.ResNet101, (1, 224, 224, 3)), |
| # (keras_app.ResNet101V2, (1, 224, 224, 3)), |
| # (keras_app.ResNet152, (1, 224, 224, 3)), |
| # (keras_app.ResNet152V2, (1, 224, 224, 3)), |
| (keras_app.InceptionResNetV2, (1, 299, 299, 3)), |
| # (keras_app.MobileNet, (1, 224, 224, 3)), |
| (keras_app.MobileNetV2, (1, 224, 224, 3)), |
| (keras_app.DenseNet121, (1, 224, 224, 3)), |
| # (keras_app.DenseNet169, (1, 224, 224, 3)), |
| # (keras_app.DenseNet201, (1, 224, 224, 3)), |
| (keras_app.NASNetMobile, (1, 224, 224, 3)), |
| # (keras_app.NASNetLarge, (1, 331, 331, 3)), |
| (keras_app.EfficientNetB0, (1, 224, 224, 3)), |
| # (keras_app.EfficientNetB1, (1, 240, 240, 3)), |
| # (keras_app.EfficientNetB2, (1, 260, 260, 3)), |
| # (keras_app.EfficientNetB3, (1, 300, 300, 3)), |
| # (keras_app.EfficientNetB4, (1, 380, 380, 3)), |
| # (keras_app.EfficientNetB5, (1, 456, 456, 3)), |
| # (keras_app.EfficientNetB6, (1, 528, 528, 3)), |
| # (keras_app.EfficientNetB7, (1, 600, 600, 3)), |
| (keras_app.EfficientNetV2B0, (1, 224, 224, 3)), |
| # (keras_app.EfficientNetV2B1, (1, 240, 240, 3)), |
| # (keras_app.EfficientNetV2B2, (1, 260, 260, 3)), |
| # (keras_app.EfficientNetV2B3, (1, 300, 300, 3)), |
| # (keras_app.EfficientNetV2S, (1, 384, 384, 3)), |
| # (keras_app.EfficientNetV2M, (1, 480, 480, 3)), |
| # (keras_app.EfficientNetV2L, (1, 480, 480, 3)), |
| # (keras_app.ConvNeXtTiny, (1, 224, 224, 3)), |
| # (keras_app.ConvNeXtSmall, (1, 224, 224, 3)), |
| # (keras_app.ConvNeXtBase, (1, 224, 224, 3)), |
| # (keras_app.ConvNeXtLarge, (1, 224, 224, 3)), |
| # (keras_app.ConvNeXtXLarge, (1, 224, 224, 3)), |
| ], |
| ) |
| def test_networks(net, shape): |
| # Run network tests only in nightly builds |
| if "CI_ENV_NIGHTLY" not in os.environ: |
| return |
| |
| class NetworkModule(tf.Module): |
| def __init__(self): |
| self.model = net(weights=None, include_top=True) |
| |
| @tf.function |
| def func(self, data): |
| return self.model(data, training=False) |
| |
| model = NetworkModule() |
| concrete_func = model.func.get_concrete_function(tf.TensorSpec(shape=shape, dtype=tf.float32)) |
| |
| verify(concrete_func) |
| |
| |
| def test_broadcast_to(): |
| class Model(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 2), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.broadcast_to(x, [3, 2, 2]) |
| |
| verify(Model) |
| |
| class ModelScalarAndInt(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.int32)]) |
| def func(self, x): |
| return tf.broadcast_to(x, [4, 4]) |
| |
| verify(ModelScalarAndInt) |
| |
| |
| def test_embedding_lookup(): |
| class Model(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(3,), dtype=tf.int32)]) |
| def func(self, indices): |
| params = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.float32) |
| return tf.nn.embedding_lookup(params, indices) |
| |
| verify(Model) |
| |
| class ModelMultidim(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.int32)]) |
| def func(self, indices): |
| params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=tf.float32) |
| return tf.nn.embedding_lookup(params, indices) |
| |
| verify(ModelMultidim) |
| |
| |
| def test_select_v2(): |
| class Model(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 2), dtype=tf.bool), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| ] |
| ) |
| def func(self, condition, x, y): |
| return tf.where(condition, x, y) |
| |
| verify(Model) |
| |
| class ModelBroadcasting(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 1), dtype=tf.bool), |
| tf.TensorSpec(shape=(2, 2), dtype=tf.float32), |
| tf.TensorSpec(shape=(), dtype=tf.float32), |
| ] |
| ) |
| def func(self, condition, x, y): |
| return tf.where(condition, x, y) |
| |
| verify(ModelBroadcasting) |
| |
| |
| def test_scatter_nd(): |
| class Model(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(4, 1), dtype=tf.int32), |
| tf.TensorSpec(shape=(4,), dtype=tf.float32), |
| tf.TensorSpec(shape=(1,), dtype=tf.int32), |
| ] |
| ) |
| def func(self, indices, updates, shape): |
| return tf.scatter_nd(indices, updates, shape) |
| |
| verify(Model) |
| |
| |
| def test_segment_sum(): |
| """SEGMENT_SUM lowers to scatter_nd with add reduction.""" |
| |
| class Model(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)]) |
| def func(self, data): |
| return tf.raw_ops.SegmentSum( |
| data=data, segment_ids=tf.constant([0, 0, 1, 2], dtype=tf.int32) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((3, 2), dtype="float32") = R.zeros(R.shape([3, 2]), dtype="float32") |
| lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims( |
| R.const([0, 0, 1, 2], "int32"), axis=[1] |
| ) |
| gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(lv, lv1, data, reduction="add") |
| R.output(gv) |
| return gv |
| |
| verify(Model, Expected) |
| |
| |
| def test_unsorted_segment_min(): |
| """UNSORTED_SEGMENT_MIN lowers to scatter_nd with min reduction.""" |
| |
| class Model(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)]) |
| def func(self, data): |
| return tf.raw_ops.UnsortedSegmentMin( |
| data=data, |
| segment_ids=tf.constant([2, 0, 2, 1], dtype=tf.int32), |
| num_segments=tf.constant(3, dtype=tf.int32), |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((3, 2), dtype="float32") = R.full( |
| R.shape([3, 2]), R.const(np.finfo(np.float32).max, "float32"), dtype="float32" |
| ) |
| lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims( |
| R.const([2, 0, 2, 1], "int32"), axis=[1] |
| ) |
| gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(lv, lv1, data, reduction="min") |
| R.output(gv) |
| return gv |
| |
| verify(Model, Expected) |
| |
| |
| def test_unsorted_segment_prod(): |
| """UNSORTED_SEGMENT_PROD lowers to scatter_nd with mul reduction.""" |
| |
| class Model(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(4, 2), dtype=tf.float32)]) |
| def func(self, data): |
| return tf.raw_ops.UnsortedSegmentProd( |
| data=data, |
| segment_ids=tf.constant([1, 0, 1, 2], dtype=tf.int32), |
| num_segments=tf.constant(3, dtype=tf.int32), |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((3, 2), dtype="float32") = R.full( |
| R.shape([3, 2]), R.const(1, "float32"), dtype="float32" |
| ) |
| lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims( |
| R.const([1, 0, 1, 2], "int32"), axis=[1] |
| ) |
| gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(lv, lv1, data, reduction="mul") |
| R.output(gv) |
| return gv |
| |
| verify(Model, Expected) |
| |
| |
| def test_batch_matmul(): |
| class BatchMatMul(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 4, 5), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y): |
| return tf.matmul(x, y) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 3, 4), dtype="float32"), |
| y: R.Tensor((2, 4, 5), dtype="float32"), |
| ) -> R.Tensor((2, 3, 5), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(x, y, out_dtype="void") |
| gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv, R.shape([2, 3, 5])) |
| R.output(gv) |
| return gv |
| |
| verify(BatchMatMul, Expected) |
| |
| |
| def test_batch_matmul_adj(): |
| class BatchMatMulAdj(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 4, 3), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 5, 4), dtype=tf.float32), |
| ] |
| ) |
| def func(self, x, y): |
| return tf.matmul(x, y, transpose_a=True, transpose_b=True) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 4, 3), dtype="float32"), |
| y: R.Tensor((2, 5, 4), dtype="float32"), |
| ) -> R.Tensor((2, 3, 5), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 3, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 1]) |
| lv1: R.Tensor((2, 4, 5), dtype="float32") = R.permute_dims(y, axes=[0, 2, 1]) |
| lv2: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="void") |
| gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv2, R.shape([2, 3, 5])) |
| R.output(gv) |
| return gv |
| |
| verify(BatchMatMulAdj, Expected) |
| |
| |
| def _verify_nms_v4(mod, tf_func, boxes_np, scores_np): |
| """E2E verify for NMS V4: only run on nightly, compare valid outputs only.""" |
| if "CI_ENV_NIGHTLY" not in os.environ: |
| return |
| |
| tf_indices, tf_valid = tf_func(tf.constant(boxes_np), tf.constant(scores_np)) |
| n_valid = int(tf_valid.numpy()) |
| |
| tgt = tvm.target.Target("llvm") |
| ex = tvm.compile(mod, tgt) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| vm.set_input("main", boxes_np, scores_np) |
| vm.invoke_stateful("main") |
| tvm_indices, tvm_valid = vm.get_outputs("main") |
| |
| assert int(tvm_valid.numpy()) == n_valid |
| np.testing.assert_array_equal( |
| tf_indices.numpy()[:n_valid], |
| tvm_indices.numpy()[:n_valid], |
| ) |
| |
| |
| def _build_nms_v4_mod(num_boxes, max_output_size, iou_threshold, score_threshold): |
| """Convert a NonMaxSuppressionV4 TFLite model to a Relax module. |
| |
| Scalar params must be Python literals (not tf.constant) so TFLite can |
| statically infer output shapes during conversion. |
| """ |
| |
| class NMSv4Module(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(num_boxes, 4), dtype=tf.float32), |
| tf.TensorSpec(shape=(num_boxes,), dtype=tf.float32), |
| ] |
| ) |
| def func(self, boxes, scores): |
| indices, valid = tf.raw_ops.NonMaxSuppressionV4( |
| boxes=boxes, |
| scores=scores, |
| max_output_size=max_output_size, |
| iou_threshold=iou_threshold, |
| score_threshold=score_threshold, |
| pad_to_max_output_size=True, |
| ) |
| return indices, valid |
| |
| instance = NMSv4Module() |
| cf = instance.func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| return mod, instance.func |
| |
| |
| def _verify_nms_v5(mod, tf_func, boxes_np, scores_np, soft_nms_sigma=0.0): |
| """E2E verify for NMS: only run on nightly, compare valid outputs only.""" |
| if "CI_ENV_NIGHTLY" not in os.environ: |
| return |
| |
| tf_indices, tf_scores, tf_valid = tf_func(tf.constant(boxes_np), tf.constant(scores_np)) |
| n_valid = int(tf_valid.numpy()) |
| |
| tgt = tvm.target.Target("llvm") |
| ex = tvm.compile(mod, tgt) |
| vm = relax.VirtualMachine(ex, tvm.cpu()) |
| vm.set_input("main", boxes_np, scores_np) |
| vm.invoke_stateful("main") |
| tvm_indices, tvm_scores, tvm_valid = vm.get_outputs("main") |
| |
| assert int(tvm_valid.numpy()) == n_valid |
| np.testing.assert_array_equal( |
| tf_indices.numpy()[:n_valid], |
| tvm_indices.numpy()[:n_valid], |
| ) |
| np.testing.assert_allclose( |
| tf_scores.numpy()[:n_valid], |
| tvm_scores.numpy()[:n_valid], |
| rtol=1e-5, |
| atol=1e-5, |
| ) |
| if soft_nms_sigma > 0.0: |
| np.testing.assert_allclose( |
| tf_scores.numpy(), |
| tvm_scores.numpy(), |
| rtol=1e-5, |
| atol=1e-5, |
| ) |
| np.testing.assert_array_less(-1e-6, tvm_scores.numpy()[n_valid:]) |
| |
| |
| def _build_nms_v5_mod( |
| num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma=0.0 |
| ): |
| """Convert a NonMaxSuppressionV5 TFLite model to a Relax module. |
| |
| Scalar params must be Python literals (not tf.constant) so TFLite can |
| statically infer output shapes during conversion. |
| """ |
| |
| class NMSv5Module(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(num_boxes, 4), dtype=tf.float32), |
| tf.TensorSpec(shape=(num_boxes,), dtype=tf.float32), |
| ] |
| ) |
| def func(self, boxes, scores): |
| indices, out_scores, valid = tf.raw_ops.NonMaxSuppressionV5( |
| boxes=boxes, |
| scores=scores, |
| max_output_size=max_output_size, |
| iou_threshold=iou_threshold, |
| score_threshold=score_threshold, |
| soft_nms_sigma=soft_nms_sigma, |
| pad_to_max_output_size=True, |
| ) |
| return indices, out_scores, valid |
| |
| instance = NMSv5Module() |
| cf = instance.func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| return mod, instance.func |
| |
| |
| class _StubDetectionPostprocessTensor: |
| def __init__(self, shape, name): |
| self._shape = list(shape) |
| self._name = name |
| |
| def Shape(self, index): |
| return self._shape[index] |
| |
| def Name(self): |
| return self._name |
| |
| def Type(self): |
| return 0 |
| |
| |
| class _StubDetectionPostprocessOp: |
| def __init__(self, custom_options): |
| self._custom_options = _encode_detection_postprocess_custom_options(custom_options) |
| |
| def CustomOptionsAsNumpy(self): |
| return np.frombuffer(self._custom_options, dtype="uint8") |
| |
| |
| _DETECTION_POSTPROCESS_ANCHORS = np.array( |
| [ |
| [0.5, 0.5, 1.0, 1.0], |
| [0.5, 0.2, 1.0, 1.0], |
| [0.1, 0.1, 0.5, 0.5], |
| [0.8, 0.8, 0.2, 0.2], |
| ], |
| dtype="float32", |
| ) |
| |
| |
| def _encode_detection_postprocess_custom_options(custom_options): |
| from flatbuffers import flexbuffers |
| |
| builder = flexbuffers.Builder() |
| with builder.Map(): |
| for key, value in custom_options.items(): |
| if isinstance(value, bool): |
| builder.Bool(key, value) |
| elif isinstance(value, int): |
| builder.Int(key, value) |
| else: |
| builder.Float(key, float(value)) |
| return bytes(builder.Finish()) |
| |
| |
| def _make_detection_postprocess_tensor_wrapper(tensor_idx, shape, name): |
| return tflite_frontend.TensorWrapper( |
| tensor_idx, |
| _StubDetectionPostprocessTensor(shape, name), |
| None, |
| ) |
| |
| |
| def _build_detection_postprocess_mod( |
| *, |
| num_classes=1, |
| max_detections=4, |
| detections_per_class=4, |
| use_regular_nms=False, |
| nms_iou_threshold=0.5, |
| nms_score_threshold=0.3, |
| x_scale=10.0, |
| y_scale=10.0, |
| w_scale=5.0, |
| h_scale=5.0, |
| batch_size=2, |
| num_anchors=4, |
| input_num_classes=None, |
| ): |
| custom_options = { |
| "num_classes": num_classes, |
| "max_detections": max_detections, |
| "detections_per_class": detections_per_class, |
| "nms_iou_threshold": nms_iou_threshold, |
| "nms_score_threshold": nms_score_threshold, |
| "x_scale": x_scale, |
| "y_scale": y_scale, |
| "w_scale": w_scale, |
| "h_scale": h_scale, |
| "use_regular_nms": use_regular_nms, |
| } |
| return _convert_detection_postprocess_with_options( |
| custom_options, |
| batch_size=batch_size, |
| num_anchors=num_anchors, |
| num_classes=num_classes, |
| input_num_classes=input_num_classes, |
| ) |
| |
| |
| def _convert_detection_postprocess_with_options( |
| custom_options, |
| *, |
| batch_size=2, |
| num_anchors=4, |
| num_classes=1, |
| input_num_classes=None, |
| build_module=True, |
| ): |
| input_num_classes = num_classes if input_num_classes is None else input_num_classes |
| loc = relax.Var("loc", relax.TensorStructInfo((batch_size, num_anchors, 4), "float32")) |
| cls = relax.Var( |
| "cls", relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32") |
| ) |
| inputs = [ |
| _make_detection_postprocess_tensor_wrapper(0, (batch_size, num_anchors, 4), "loc"), |
| _make_detection_postprocess_tensor_wrapper( |
| 1, (batch_size, num_anchors, input_num_classes), "cls" |
| ), |
| _make_detection_postprocess_tensor_wrapper(2, (num_anchors, 4), "anchors"), |
| ] |
| converter = tflite_frontend.OperatorConverter.__new__(tflite_frontend.OperatorConverter) |
| converter.bb = relax.BlockBuilder() |
| converter.exp_tab = tflite_frontend.ExprTable() |
| converter.get_input_tensors = lambda op: inputs |
| converter.get_expr = lambda tensor_idx: {0: loc, 1: cls}[tensor_idx] |
| converter.get_tensor_value = lambda tensor: ( |
| _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx == 2 else None |
| ) |
| converter.get_tensor_type_str = lambda tensor_type: "float32" |
| op = _StubDetectionPostprocessOp(custom_options) |
| if not build_module: |
| return converter.convert_detection_postprocess(op) |
| bb = converter.bb |
| with bb.function("main", [loc, cls]): |
| with bb.dataflow(): |
| output = converter.convert_detection_postprocess(op) |
| gv = bb.emit_output(output) |
| bb.emit_func_output(gv) |
| return bb.get() |
| |
| |
| def _make_valid_boxes(rng, n): |
| """Generate n random boxes with y1<=y2, x1<=x2 using the given RNG.""" |
| raw = rng.random((n, 4), dtype=np.float32) |
| return np.stack( |
| [ |
| np.minimum(raw[:, 0], raw[:, 2]), # y1 |
| np.minimum(raw[:, 1], raw[:, 3]), # x1 |
| np.maximum(raw[:, 0], raw[:, 2]), # y2 |
| np.maximum(raw[:, 1], raw[:, 3]), # x2 |
| ], |
| axis=1, |
| ).astype(np.float32) |
| |
| |
| _NMS_V5_CASES = [ |
| pytest.param( |
| 6, |
| 3, |
| 0.5, |
| 0.0, |
| np.array( |
| [ |
| [0.0, 0.0, 1.0, 1.0], |
| [0.0, 0.0, 1.0, 1.0], |
| [0.0, 0.1, 1.0, 1.1], |
| [0.0, 0.0, 1.0, 0.9], |
| [0.5, 0.5, 1.5, 1.5], |
| [0.0, 0.0, 0.3, 0.3], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32), |
| id="basic", |
| ), |
| pytest.param( |
| 8, |
| 4, |
| 0.5, |
| 0.4, |
| _make_valid_boxes(np.random.default_rng(42), 8), |
| np.random.default_rng(42).random(8, dtype=np.float32), |
| id="score_threshold", |
| ), |
| pytest.param( |
| 5, |
| 3, |
| 0.5, |
| 0.99, |
| _make_valid_boxes(np.random.default_rng(0), 5), |
| np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32), |
| id="all_suppressed", |
| ), |
| pytest.param( |
| 6, |
| 6, |
| 0.1, |
| 0.0, |
| np.array( |
| [ |
| [0.0, 0.0, 0.4, 0.4], |
| [0.5, 0.5, 0.9, 0.9], |
| [0.1, 0.1, 0.5, 0.5], |
| [0.6, 0.6, 1.0, 1.0], |
| [0.0, 0.5, 0.4, 0.9], |
| [0.5, 0.0, 0.9, 0.4], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.85, 0.7, 0.65, 0.6, 0.55], dtype=np.float32), |
| id="iou_threshold", |
| ), |
| pytest.param( |
| 4, |
| 10, |
| 0.5, |
| 0.0, |
| np.array( |
| [ |
| [0.0, 0.0, 0.3, 0.3], |
| [0.5, 0.5, 0.8, 0.8], |
| [0.1, 0.1, 0.4, 0.4], |
| [0.6, 0.6, 0.9, 0.9], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.85, 0.7, 0.65], dtype=np.float32), |
| id="max_output_size_larger_than_boxes", |
| ), |
| ] |
| |
| |
| _NMS_V5_SOFT_CASES = [ |
| pytest.param( |
| 6, |
| 6, |
| 0.5, |
| 0.0, |
| 0.5, |
| np.array( |
| [ |
| [0.0, 0.0, 1.0, 1.0], |
| [0.0, 0.0, 1.0, 1.0], |
| [0.0, 0.1, 1.0, 1.1], |
| [0.0, 0.0, 1.0, 0.9], |
| [0.5, 0.5, 1.5, 1.5], |
| [0.0, 0.0, 0.3, 0.3], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32), |
| id="soft_nms_basic", |
| ), |
| pytest.param( |
| 5, |
| 5, |
| 0.5, |
| 0.0, |
| 0.3, |
| np.array( |
| [ |
| [0.0, 0.0, 1.0, 1.0], |
| [0.1, 0.1, 1.1, 1.1], |
| [0.2, 0.2, 1.2, 1.2], |
| [0.3, 0.3, 1.3, 1.3], |
| [2.0, 2.0, 3.0, 3.0], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.8, 0.7, 0.6, 0.5], dtype=np.float32), |
| id="soft_nms_tight_sigma", |
| ), |
| pytest.param( |
| 3, |
| 3, |
| 0.5, |
| 0.3, |
| 0.1, |
| np.array( |
| [ |
| [0.0, 0.0, 1.0, 1.0], |
| [0.2, 0.2, 1.2, 1.2], |
| [2.0, 2.0, 3.0, 3.0], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.8, 0.75], dtype=np.float32), |
| id="soft_nms_threshold_hole", |
| ), |
| pytest.param( |
| 3, |
| 3, |
| 0.5, |
| 0.0, |
| 0.1, |
| np.array( |
| [ |
| [0.0, 0.0, 1.0, 1.0], |
| [0.2, 0.2, 1.2, 1.2], |
| [2.0, 2.0, 3.0, 3.0], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.85, 0.8], dtype=np.float32), |
| id="soft_nms_reorder", |
| ), |
| ] |
| |
| |
| @pytest.mark.parametrize( |
| "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores", |
| _NMS_V5_CASES, |
| ) |
| def test_nms_v5(num_boxes, max_output_size, iou_threshold, score_threshold, boxes, scores): |
| """NON_MAX_SUPPRESSION_V5: conversion smoke test + E2E correctness (nightly only).""" |
| mod, tf_func = _build_nms_v5_mod(num_boxes, max_output_size, iou_threshold, score_threshold) |
| _verify_nms_v5(mod, tf_func, boxes, scores) |
| |
| |
| @pytest.mark.parametrize( |
| "num_boxes,max_output_size,iou_threshold,score_threshold,soft_nms_sigma,boxes,scores", |
| _NMS_V5_SOFT_CASES, |
| ) |
| def test_nms_v5_soft( |
| num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma, boxes, scores |
| ): |
| """NON_MAX_SUPPRESSION_V5 with soft_nms_sigma: conversion smoke test + E2E correctness.""" |
| mod, tf_func = _build_nms_v5_mod( |
| num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma |
| ) |
| _verify_nms_v5(mod, tf_func, boxes, scores, soft_nms_sigma=soft_nms_sigma) |
| |
| |
| def test_nms_v5_ir(): |
| """Verify the emitted Relax IR has correct structure for NON_MAX_SUPPRESSION_V5.""" |
| num_boxes = 6 |
| max_output_size = 3 |
| mod, _ = _build_nms_v5_mod( |
| num_boxes=num_boxes, |
| max_output_size=max_output_size, |
| iou_threshold=0.5, |
| score_threshold=0.0, |
| ) |
| |
| ir = mod.script() |
| |
| # Validate correct sorting/id indices are passed to valid_counts |
| assert "score_index=0" in ir |
| assert "id_index=-1" in ir |
| # NMS size limit validation |
| assert f"max_output_size={max_output_size}" in ir |
| # Valid output shape must be () statically |
| assert 'R.Tensor((), dtype="int32")' in ir |
| # Bounding boxes / scores tensor bounds checks |
| assert f"R.Tensor(({max_output_size},)" in ir |
| |
| |
| def test_nms_v5_soft_ir(): |
| """Verify the emitted Relax IR passes soft_nms_sigma for NON_MAX_SUPPRESSION_V5.""" |
| num_boxes = 6 |
| max_output_size = 3 |
| mod, _ = _build_nms_v5_mod( |
| num_boxes=num_boxes, |
| max_output_size=max_output_size, |
| iou_threshold=0.5, |
| score_threshold=0.0, |
| soft_nms_sigma=0.5, |
| ) |
| |
| ir = mod.script() |
| |
| # soft_nms_sigma must appear in the IR |
| assert "soft_nms_sigma=0.5" in ir |
| # score_threshold must also be forwarded |
| assert "score_threshold=0.0" in ir |
| # Soft-NMS padded scores must be clipped to non-negative values. |
| assert "R.clip(" in ir |
| |
| |
| _NMS_V4_CASES = [ |
| pytest.param( |
| 6, |
| 3, |
| 0.5, |
| 0.0, |
| np.array( |
| [ |
| [0.0, 0.0, 1.0, 1.0], |
| [0.0, 0.0, 1.0, 1.0], |
| [0.0, 0.1, 1.0, 1.1], |
| [0.0, 0.0, 1.0, 0.9], |
| [0.5, 0.5, 1.5, 1.5], |
| [0.0, 0.0, 0.3, 0.3], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32), |
| id="basic", |
| ), |
| pytest.param( |
| 8, |
| 4, |
| 0.5, |
| 0.4, |
| _make_valid_boxes(np.random.default_rng(42), 8), |
| np.random.default_rng(42).random(8, dtype=np.float32), |
| id="score_threshold", |
| ), |
| pytest.param( |
| 5, |
| 3, |
| 0.5, |
| 0.99, |
| _make_valid_boxes(np.random.default_rng(0), 5), |
| np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32), |
| id="all_suppressed", |
| ), |
| pytest.param( |
| 4, |
| 10, |
| 0.5, |
| 0.0, |
| np.array( |
| [ |
| [0.0, 0.0, 0.3, 0.3], |
| [0.5, 0.5, 0.8, 0.8], |
| [0.1, 0.1, 0.4, 0.4], |
| [0.6, 0.6, 0.9, 0.9], |
| ], |
| dtype=np.float32, |
| ), |
| np.array([0.9, 0.85, 0.7, 0.65], dtype=np.float32), |
| id="max_output_size_larger_than_boxes", |
| ), |
| ] |
| |
| |
| @pytest.mark.parametrize( |
| "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores", |
| _NMS_V4_CASES, |
| ) |
| def test_nms_v4(num_boxes, max_output_size, iou_threshold, score_threshold, boxes, scores): |
| """NON_MAX_SUPPRESSION_V4: conversion smoke test + E2E correctness (nightly only).""" |
| mod, tf_func = _build_nms_v4_mod(num_boxes, max_output_size, iou_threshold, score_threshold) |
| _verify_nms_v4(mod, tf_func, boxes, scores) |
| |
| |
| def test_nms_v4_ir(): |
| """Verify the emitted Relax IR has correct structure for NON_MAX_SUPPRESSION_V4.""" |
| num_boxes = 6 |
| max_output_size = 3 |
| mod, _ = _build_nms_v4_mod( |
| num_boxes=num_boxes, |
| max_output_size=max_output_size, |
| iou_threshold=0.5, |
| score_threshold=0.0, |
| ) |
| |
| ir = mod.script() |
| |
| # Validate correct sorting/id indices are passed to valid_counts |
| assert "score_index=0" in ir |
| assert "id_index=-1" in ir |
| # NMS size limit validation |
| assert f"max_output_size={max_output_size}" in ir |
| # Valid output shape must be () statically |
| assert 'R.Tensor((), dtype="int32")' in ir |
| # Selected indices tensor bounds check |
| assert f"R.Tensor(({max_output_size},)" in ir |
| # V4 must use hard-NMS (soft_nms_sigma left at default 0.0) |
| assert "soft_nms_sigma=0.0" in ir |
| |
| |
| _DETECTION_POSTPROCESS_SMOKE_CASES = [ |
| pytest.param( |
| { |
| "num_classes": 2, |
| "input_num_classes": 3, |
| "max_detections": 2, |
| "detections_per_class": 2, |
| "use_regular_nms": False, |
| "nms_iou_threshold": 0.5, |
| "nms_score_threshold": 0.5, |
| "batch_size": 1, |
| "num_anchors": 4, |
| }, |
| 2, |
| False, |
| id="basic_fast_nms", |
| ), |
| pytest.param( |
| { |
| "num_classes": 2, |
| "input_num_classes": 3, |
| "max_detections": 3, |
| "detections_per_class": 2, |
| "use_regular_nms": True, |
| "nms_iou_threshold": 0.45, |
| "nms_score_threshold": 0.25, |
| "batch_size": 2, |
| "num_anchors": 4, |
| }, |
| 1, |
| True, |
| id="regular_nms_multi_batch", |
| ), |
| ] |
| |
| |
| _DETECTION_POSTPROCESS_SHAPE_CASES = [ |
| pytest.param( |
| { |
| "num_classes": 2, |
| "input_num_classes": 5, |
| "max_detections": 2, |
| "detections_per_class": 2, |
| "use_regular_nms": False, |
| "nms_iou_threshold": 0.5, |
| "nms_score_threshold": 0.5, |
| "batch_size": 1, |
| "num_anchors": 4, |
| }, |
| id="wider_input_classes", |
| ), |
| pytest.param( |
| { |
| "num_classes": 2, |
| "input_num_classes": 3, |
| "max_detections": 4, |
| "detections_per_class": 4, |
| "use_regular_nms": False, |
| "nms_iou_threshold": 0.5, |
| "nms_score_threshold": 0.5, |
| "batch_size": 1, |
| "num_anchors": 4, |
| }, |
| id="larger_max_detections", |
| ), |
| ] |
| |
| |
| @pytest.mark.parametrize( |
| "build_kwargs,expected_topk_count,expected_keep_background", |
| _DETECTION_POSTPROCESS_SMOKE_CASES, |
| ) |
| def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected_keep_background): |
| mod = _build_detection_postprocess_mod(**build_kwargs) |
| ir = mod.script() |
| |
| assert "R.vision.multibox_transform_loc" in ir |
| assert "R.vision.all_class_non_max_suppression" in ir |
| assert 'output_format="tensorflow"' in ir |
| assert "R.where" in ir |
| assert "R.gather_elements" in ir |
| assert "R.gather_nd" in ir |
| assert ir.count("R.topk(") == expected_topk_count |
| assert f"keep_background={expected_keep_background}" in ir |
| expected_batch = build_kwargs["batch_size"] |
| expected_max_detections = build_kwargs["max_detections"] |
| tvm.ir.assert_structural_equal( |
| mod["main"].ret_struct_info, |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((expected_batch, expected_max_detections, 4), "float32"), |
| relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), |
| relax.TensorStructInfo((expected_batch, expected_max_detections), "float32"), |
| relax.TensorStructInfo((expected_batch,), "float32"), |
| ] |
| ), |
| ) |
| |
| legalized = relax.transform.LegalizeOps()(mod) |
| legalized_ir = legalized.script() |
| assert "R.vision.all_class_non_max_suppression(" not in legalized_ir |
| assert "R.call_tir(" in legalized_ir |
| tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info, mod["main"].ret_struct_info) |
| |
| |
| @pytest.mark.parametrize("build_kwargs", _DETECTION_POSTPROCESS_SHAPE_CASES) |
| def test_detection_postprocess_shape_variations(build_kwargs): |
| mod = _build_detection_postprocess_mod(**build_kwargs) |
| batch_size = build_kwargs["batch_size"] |
| num_anchors = build_kwargs["num_anchors"] |
| input_num_classes = build_kwargs["input_num_classes"] |
| max_detections = build_kwargs["max_detections"] |
| |
| tvm.ir.assert_structural_equal( |
| mod["main"].params[1].struct_info, |
| relax.TensorStructInfo((batch_size, num_anchors, input_num_classes), "float32"), |
| ) |
| tvm.ir.assert_structural_equal( |
| mod["main"].ret_struct_info, |
| relax.TupleStructInfo( |
| [ |
| relax.TensorStructInfo((batch_size, max_detections, 4), "float32"), |
| relax.TensorStructInfo((batch_size, max_detections), "float32"), |
| relax.TensorStructInfo((batch_size, max_detections), "float32"), |
| relax.TensorStructInfo((batch_size,), "float32"), |
| ] |
| ), |
| ) |
| |
| |
| def _make_resize_expected( |
| input_shape, output_size, method, coordinate_transformation_mode, rounding_method |
| ): |
| """Build an Expected IRModule programmatically to avoid TVMScript variable scope limitations.""" |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", relax.TensorStructInfo(input_shape, "float32")) |
| with bb.function("main", [x]): |
| with bb.dataflow(): |
| gv = bb.emit_output( |
| relax.op.image.resize2d( |
| x, |
| size=relax.ShapeExpr([output_size[0], output_size[1]]), |
| roi=[0.0, 0.0, 0.0, 0.0], |
| layout="NHWC", |
| method=method, |
| coordinate_transformation_mode=coordinate_transformation_mode, |
| rounding_method=rounding_method, |
| cubic_alpha=-0.75, |
| cubic_exclude=0, |
| extrapolation_value=0.0, |
| out_dtype="void", |
| ) |
| ) |
| bb.emit_func_output(gv) |
| mod = bb.get() |
| mod["main"] = mod["main"].with_attr("num_input", 1) |
| return mod |
| |
| |
| @pytest.mark.parametrize( |
| "input_shape, output_size, tf_op, coordinate_transformation_mode", |
| [ |
| ( |
| (1, 4, 4, 1), |
| [8, 8], |
| lambda x: tf.image.resize(x, [8, 8], method="bilinear"), |
| "half_pixel", |
| ), |
| ( |
| (1, 8, 8, 3), |
| [4, 4], |
| lambda x: tf.image.resize(x, [4, 4], method="bilinear"), |
| "half_pixel", |
| ), |
| ( |
| (1, 4, 4, 1), |
| [7, 7], |
| lambda x: tf.compat.v1.image.resize_bilinear(x, [7, 7], align_corners=True), |
| "align_corners", |
| ), |
| ( |
| (1, 4, 4, 2), |
| [8, 8], |
| lambda x: tf.compat.v1.image.resize_bilinear(x, [8, 8], half_pixel_centers=True), |
| "half_pixel", |
| ), |
| ( |
| (2, 6, 6, 16), |
| [12, 12], |
| lambda x: tf.image.resize(x, [12, 12], method="bilinear"), |
| "half_pixel", |
| ), |
| ( |
| (1, 5, 5, 3), |
| [5, 5], |
| lambda x: tf.image.resize(x, [5, 5], method="bilinear"), |
| "half_pixel", |
| ), |
| ( |
| (1, 4, 8, 1), |
| [8, 16], |
| lambda x: tf.image.resize(x, [8, 16], method="bilinear"), |
| "half_pixel", |
| ), |
| ], |
| ) |
| def test_resize_bilinear(input_shape, output_size, tf_op, coordinate_transformation_mode): |
| class ResizeBilinear(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) |
| def func(self, x): |
| return tf_op(x) |
| |
| expected = _make_resize_expected( |
| input_shape, output_size, "linear", coordinate_transformation_mode, "" |
| ) |
| verify(ResizeBilinear, expected) |
| |
| |
| @pytest.mark.parametrize( |
| "input_shape, output_size, tf_op, coordinate_transformation_mode, rounding_method", |
| [ |
| ( |
| (1, 2, 2, 1), |
| [4, 4], |
| lambda x: tf.image.resize(x, [4, 4], method="nearest"), |
| "half_pixel", |
| "round_prefer_ceil", |
| ), |
| ( |
| (1, 8, 8, 3), |
| [4, 4], |
| lambda x: tf.image.resize(x, [4, 4], method="nearest"), |
| "half_pixel", |
| "round_prefer_ceil", |
| ), |
| ( |
| (1, 4, 4, 1), |
| [7, 7], |
| lambda x: tf.compat.v1.image.resize_nearest_neighbor(x, [7, 7], align_corners=True), |
| "align_corners", |
| "", |
| ), |
| ( |
| (4, 3, 3, 8), |
| [6, 6], |
| lambda x: tf.image.resize(x, [6, 6], method="nearest"), |
| "half_pixel", |
| "round_prefer_ceil", |
| ), |
| ( |
| (1, 4, 8, 1), |
| [8, 16], |
| lambda x: tf.image.resize(x, [8, 16], method="nearest"), |
| "half_pixel", |
| "round_prefer_ceil", |
| ), |
| ( |
| (1, 3, 3, 2), |
| [3, 3], |
| lambda x: tf.image.resize(x, [3, 3], method="nearest"), |
| "half_pixel", |
| "round_prefer_ceil", |
| ), |
| ], |
| ) |
| def test_resize_nearest_neighbor( |
| input_shape, output_size, tf_op, coordinate_transformation_mode, rounding_method |
| ): |
| class ResizeNearest(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) |
| def func(self, x): |
| return tf_op(x) |
| |
| expected = _make_resize_expected( |
| input_shape, |
| output_size, |
| "nearest_neighbor", |
| coordinate_transformation_mode, |
| rounding_method, |
| ) |
| verify(ResizeNearest, expected) |
| |
| |
| def _make_reduce_expected(relax_op, input_shape, axes, keepdims, dtype): |
| if axes is None: |
| axes = list(range(len(input_shape))) |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", relax.TensorStructInfo(input_shape, dtype)) |
| with bb.function("main", [x]): |
| with bb.dataflow(): |
| gv = bb.emit_output(relax_op(x, axis=axes, keepdims=keepdims)) |
| bb.emit_func_output(gv) |
| mod = bb.get() |
| mod["main"] = mod["main"].with_attr("num_input", 1) |
| return mod |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.reduce_sum, relax.op.sum), |
| (tf.reduce_mean, relax.op.mean), |
| (tf.reduce_max, relax.op.max), |
| (tf.reduce_min, relax.op.min), |
| (tf.reduce_prod, relax.op.prod), |
| ], |
| ) |
| @pytest.mark.parametrize( |
| "input_shape, axes", |
| [ |
| ((1, 8, 8, 3), 1), |
| ((1, 8, 8, 3), [1, 2]), |
| ((1, 8, 8, 3), -1), |
| ((1, 8, 8, 3), None), |
| ((30,), 0), |
| ((2, 5, 2), [0, 2]), |
| ], |
| ) |
| @pytest.mark.parametrize("keepdims", [True, False]) |
| @pytest.mark.parametrize("dtype", [tf.float32, tf.int32]) |
| def test_reduction_ops(tf_op, relax_op, input_shape, axes, keepdims, dtype): |
| class ReduceModule(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=dtype)]) |
| def func(self, x): |
| return tf_op(x, axis=axes, keepdims=keepdims) |
| |
| relax_dtype = "float32" if dtype == tf.float32 else "int32" |
| expected = _make_reduce_expected(relax_op, input_shape, axes, keepdims, relax_dtype) |
| verify(ReduceModule, expected) |
| |
| |
| def _make_reduce_bool_expected(relax_op, input_shape, axes, keepdims): |
| if axes is None: |
| axes = list(range(len(input_shape))) |
| bb = relax.BlockBuilder() |
| x = relax.Var("x", relax.TensorStructInfo(input_shape, "bool")) |
| with bb.function("main", [x]): |
| with bb.dataflow(): |
| cast_in = bb.emit(relax.op.astype(x, "int8")) |
| reduced = bb.emit(relax_op(cast_in, axis=axes, keepdims=keepdims)) |
| gv = bb.emit_output(relax.op.astype(reduced, "bool")) |
| bb.emit_func_output(gv) |
| mod = bb.get() |
| mod["main"] = mod["main"].with_attr("num_input", 1) |
| return mod |
| |
| |
| @pytest.mark.parametrize( |
| "tf_op, relax_op", |
| [ |
| (tf.reduce_any, relax.op.max), |
| (tf.reduce_all, relax.op.min), |
| ], |
| ) |
| @pytest.mark.parametrize( |
| "input_shape, axes", |
| [ |
| ((1, 8, 8, 3), 1), |
| ((1, 8, 8, 3), [1, 2]), |
| ((1, 8, 8, 3), -1), |
| ((1, 8, 8, 3), None), |
| ((30,), 0), |
| ((2, 5, 2), [0, 2]), |
| ], |
| ) |
| @pytest.mark.parametrize("keepdims", [True, False]) |
| def test_reduction_bool_ops(tf_op, relax_op, input_shape, axes, keepdims): |
| class ReduceBoolModule(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.bool)]) |
| def func(self, x): |
| return tf_op(x, axis=axes, keepdims=keepdims) |
| |
| expected = _make_reduce_bool_expected(relax_op, input_shape, axes, keepdims) |
| verify(ReduceBoolModule, expected) |
| |
| # Regression guard: compile to catch a bool max/min lowering path. |
| tvm.compile(expected, tvm.target.Target("llvm")) |
| |
| |
| def test_pad(): |
| class Pad(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.pad(x, [[1, 1], [2, 2]]) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 7), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((4, 7), dtype="float32") = R.nn.pad( |
| x, pad_width=[1, 1, 2, 2], pad_value=0.0, pad_mode="constant" |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(Pad, Expected) |
| |
| |
| def test_pad_v2(): |
| class PadV2(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.pad(x, [[1, 1], [2, 2]], constant_values=5.0) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((4, 7), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((4, 7), dtype="float32") = R.nn.pad( |
| x, pad_width=[1, 1, 2, 2], pad_value=5.0, pad_mode="constant" |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(PadV2, Expected) |
| |
| |
| def test_mirror_pad(): |
| class MirrorPad(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.pad(x, [[1, 1], [2, 2]], mode="REFLECT") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((5, 8), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((5, 8), dtype="float32") = R.nn.pad( |
| x, pad_width=[1, 1, 2, 2], pad_value=0.0, pad_mode="reflect" |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(MirrorPad, Expected) |
| |
| |
| def test_topk_v2(): |
| class TopKV2(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(5,), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.math.top_k(x, k=3).values |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((5,), dtype="float32")) -> R.Tensor((3,), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tuple(R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="int32")) = ( |
| R.topk(x, k=3, axis=-1, ret_type="both", largest=True, dtype="int32") |
| ) |
| gv: R.Tensor((3,), dtype="float32") = lv[0] |
| R.output(gv) |
| return gv |
| |
| verify(TopKV2, Expected) |
| |
| |
| def test_one_hot(): |
| class OneHot(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(3,), dtype=tf.int32)]) |
| def func(self, x): |
| return tf.one_hot(x, depth=4) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((3,), dtype="int32")) -> R.Tensor((3, 4), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((3, 4), dtype="float32") = R.one_hot( |
| x, |
| R.prim_value(T.float32(1.0)), |
| R.prim_value(T.float32(0.0)), |
| depth=4, |
| axis=-1, |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(OneHot, Expected) |
| |
| |
| def test_select(): |
| class Select(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2, 3), dtype=tf.bool), |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| tf.TensorSpec(shape=(2, 3), dtype=tf.float32), |
| ] |
| ) |
| def func(self, cond, x, y): |
| return tf.where(cond, x, y) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| cond: R.Tensor((2, 3), dtype="bool"), |
| x: R.Tensor((2, 3), dtype="float32"), |
| y: R.Tensor((2, 3), dtype="float32"), |
| ) -> R.Tensor((2, 3), dtype="float32"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 3), dtype="float32") = R.where(cond, x, y) |
| R.output(gv) |
| return gv |
| |
| verify(Select, Expected) |
| |
| |
| def test_depth_to_space(): |
| class DepthToSpace(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 2, 4, 8), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.depth_to_space(x, block_size=2) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 2, 4, 8), dtype="float32"), |
| ) -> R.Tensor((1, 4, 8, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 4, 2, 2, 2), dtype="float32") = R.reshape( |
| x, R.shape([1, 2, 4, 2, 2, 2]) |
| ) |
| lv1: R.Tensor((1, 2, 2, 4, 2, 2), dtype="float32") = R.permute_dims( |
| lv, axes=[0, 1, 3, 2, 4, 5] |
| ) |
| gv: R.Tensor((1, 4, 8, 2), dtype="float32") = R.reshape(lv1, R.shape([1, 4, 8, 2])) |
| R.output(gv) |
| return gv |
| |
| verify(DepthToSpace, Expected) |
| |
| |
| def test_space_to_depth(): |
| class SpaceToDepth(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 4, 4, 2), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.space_to_depth(x, block_size=2) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 4, 4, 2), dtype="float32"), |
| ) -> R.Tensor((1, 2, 2, 8), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 2, 2, 2, 2, 2), dtype="float32") = R.reshape( |
| x, R.shape([1, 2, 2, 2, 2, 2]) |
| ) |
| lv1: R.Tensor((1, 2, 2, 2, 2, 2), dtype="float32") = R.permute_dims( |
| lv, axes=[0, 1, 3, 2, 4, 5] |
| ) |
| gv: R.Tensor((1, 2, 2, 8), dtype="float32") = R.reshape(lv1, R.shape([1, 2, 2, 8])) |
| R.output(gv) |
| return gv |
| |
| verify(SpaceToDepth, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "input_shape, block_shape, paddings, expected_out_shape", |
| [ |
| ((1, 2, 2, 1), [2, 2], [[0, 0], [0, 0]], (4, 1, 1, 1)), |
| ((1, 2, 3, 1), [2, 2], [[0, 0], [1, 0]], (4, 1, 2, 1)), |
| ], |
| ) |
| def test_space_to_batch_nd(input_shape, block_shape, paddings, expected_out_shape): |
| """SPACE_TO_BATCH_ND imports to Relax and preserves expected output shape.""" |
| |
| class SpaceToBatchND(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) |
| def func(self, x): |
| return tf.space_to_batch_nd( |
| x, |
| tf.constant(block_shape, dtype=tf.int32), |
| tf.constant(paddings, dtype=tf.int32), |
| ) |
| |
| cf = SpaceToBatchND().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| |
| assert "space_to_batch_nd" in ir |
| assert len(mod["main"].params) == 1 |
| tvm.ir.assert_structural_equal( |
| mod["main"].ret_struct_info, |
| relax.TensorStructInfo(expected_out_shape, "float32"), |
| ) |
| |
| if "CI_ENV_NIGHTLY" in os.environ: |
| verify(SpaceToBatchND) |
| |
| |
| @pytest.mark.parametrize( |
| "input_shape, block_shape, crops, expected_out_shape", |
| [ |
| ((4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]], (1, 2, 2, 1)), |
| ((4, 1, 2, 1), [2, 2], [[0, 0], [1, 0]], (1, 2, 3, 1)), |
| ], |
| ) |
| def test_batch_to_space_nd(input_shape, block_shape, crops, expected_out_shape): |
| """BATCH_TO_SPACE_ND imports to Relax and preserves expected output shape.""" |
| |
| class BatchToSpaceND(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=tf.float32)]) |
| def func(self, x): |
| return tf.raw_ops.BatchToSpaceND( |
| input=x, |
| block_shape=tf.constant(block_shape, dtype=tf.int32), |
| crops=tf.constant(crops, dtype=tf.int32), |
| ) |
| |
| cf = BatchToSpaceND().func.get_concrete_function() |
| mod = _get_mod_from_cfunc(cf) |
| ir = mod.script() |
| |
| assert "batch_to_space_nd" in ir |
| assert len(mod["main"].params) == 1 |
| tvm.ir.assert_structural_equal( |
| mod["main"].ret_struct_info, |
| relax.TensorStructInfo(expected_out_shape, "float32"), |
| ) |
| |
| if "CI_ENV_NIGHTLY" in os.environ: |
| verify(BatchToSpaceND) |
| |
| |
| def test_leaky_relu(): |
| class LeakyReLU(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.nn.leaky_relu(x, alpha=0.2) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="float32") = R.nn.leakyrelu( |
| x, alpha=0.20000000298023224 |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(LeakyReLU, Expected) |
| |
| |
| def test_hard_swish(): |
| class HardSwish(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return x * tf.nn.relu6(x + 3) / 6 |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 30), dtype="float32") = R.add(x, R.const(3.0, dtype="float32")) |
| lv1: R.Tensor((1, 30), dtype="float32") = R.clip( |
| lv, R.prim_value(T.float64(0.0)), R.prim_value(T.float64(6.0)) |
| ) |
| lv2: R.Tensor((1, 30), dtype="float32") = R.multiply(x, lv1) |
| gv: R.Tensor((1, 30), dtype="float32") = R.divide( |
| lv2, R.const(6.0, dtype="float32") |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(HardSwish, Expected) |
| |
| |
| def test_relu_n1_to_1(): |
| class ReLU_N1_to_1(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return tf.clip_by_value(x, -1.0, 1.0) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 30), dtype="float32") = R.clip(x, min=-1, max=1) |
| R.output(gv) |
| return gv |
| |
| verify(ReLU_N1_to_1, Expected) |
| |
| |
| def test_prelu_basic(): |
| alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, dtype=np.float32)) |
| prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init) |
| |
| class TfInput(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) |
| def func(self, x): |
| return prelu(x) |
| |
| verify(TfInput) |
| |
| |
| @pytest.mark.parametrize( |
| "shared_axes", |
| [ |
| pytest.param([1, 2], id="channelwise_shared_axes"), |
| pytest.param([1, 2, 3], id="scalar_shared_axes"), |
| pytest.param(None, id="elementwise_no_shared_axes"), |
| ], |
| ) |
| def test_prelu(shared_axes): |
| inputs = tf.keras.Input(shape=(4, 4, 3), batch_size=1, dtype=tf.float32) |
| prelu_kwargs = { |
| "alpha_initializer": tf.initializers.constant(0.25), |
| } |
| if shared_axes is not None: |
| prelu_kwargs["shared_axes"] = shared_axes |
| outputs = tf.keras.layers.PReLU(**prelu_kwargs)(inputs) |
| keras_model = tf.keras.Model(inputs, outputs) |
| |
| converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) |
| tflite_model_buf = converter.convert() |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) |
| |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| if shared_axes == [1, 2]: |
| alpha_const = np.full((1, 1, 3), 0.25, dtype=np.float32) |
| elif shared_axes == [1, 2, 3]: |
| alpha_const = np.full((1, 1, 1), 0.25, dtype=np.float32) |
| else: |
| alpha_const = np.full((4, 4, 3), 0.25, dtype=np.float32) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor( |
| (1, 4, 4, 3), dtype="float32" |
| ): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.broadcast_to( |
| R.const(alpha_const), R.shape([1, 4, 4, 3]) |
| ) |
| lv1: R.Tensor((48,), dtype="float32") = R.reshape(x, R.shape([48])) |
| lv2: R.Tensor((48,), dtype="float32") = R.reshape(lv, R.shape([48])) |
| lv3: R.Tensor((48,), dtype="float32") = R.nn.prelu(lv1, lv2, axis=0) |
| gv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.reshape(lv3, R.shape([1, 4, 4, 3])) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_matrix_diag(): |
| """Test TFLite MATRIX_DIAG operator.""" |
| |
| class MatrixDiag(tf.Module): |
| @tf.function(input_signature=[tf.TensorSpec(shape=(3,), dtype=tf.float32)]) |
| def func(self, diagonal): |
| return tf.raw_ops.MatrixDiag(diagonal=diagonal) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(diagonal: R.Tensor((3,), dtype="float32")) -> R.Tensor((3, 3), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 3]), dtype="float32") |
| gv = R.call_dps_packed( |
| "topi.matrix_set_diag", |
| ( |
| lv, |
| diagonal, |
| R.const(0, "int32"), |
| R.const(0, "int32"), |
| R.const(False, "bool"), |
| R.const(False, "bool"), |
| ), |
| out_sinfo=R.Tensor((3, 3), dtype="float32"), |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(MatrixDiag, Expected) |
| |
| |
| def test_matrix_set_diag(): |
| """Test TFLite MATRIX_SET_DIAG operator.""" |
| |
| class MatrixSetDiag(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(3, 3), dtype=tf.float32), |
| tf.TensorSpec(shape=(3,), dtype=tf.float32), |
| ] |
| ) |
| def func(self, input, diagonal): |
| return tf.raw_ops.MatrixSetDiag(input=input, diagonal=diagonal) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| input: R.Tensor((3, 3), dtype="float32"), |
| diagonal: R.Tensor((3,), dtype="float32"), |
| ) -> R.Tensor((3, 3), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv = R.call_dps_packed( |
| "topi.matrix_set_diag", |
| ( |
| input, |
| diagonal, |
| R.const(0, "int32"), |
| R.const(0, "int32"), |
| R.const(False, "bool"), |
| R.const(False, "bool"), |
| ), |
| out_sinfo=R.Tensor((3, 3), dtype="float32"), |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(MatrixSetDiag, Expected) |
| |
| |
| def test_sparse_to_dense(): |
| """Test TFLite SPARSE_TO_DENSE operator.""" |
| |
| class SparseToDense(tf.Module): |
| @tf.function( |
| input_signature=[ |
| tf.TensorSpec(shape=(2,), dtype=tf.int32), |
| tf.TensorSpec(shape=(2,), dtype=tf.float32), |
| tf.TensorSpec(shape=(), dtype=tf.float32), |
| ] |
| ) |
| def func(self, indices, values, default_value): |
| # output_shape is provided as a constant, not an input |
| return tf.raw_ops.SparseToDense( |
| sparse_indices=indices, |
| output_shape=tf.constant([3], dtype=tf.int32), |
| sparse_values=values, |
| default_value=default_value, |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| indices: R.Tensor((2,), dtype="int32"), |
| values: R.Tensor((2,), dtype="float32"), |
| default_value: R.Tensor((), dtype="float32"), |
| ) -> R.Tensor((3,), dtype="float32"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| gv = R.call_dps_packed( |
| "topi.sparse_to_dense", |
| (indices, R.const([3], "int32"), values, default_value), |
| out_sinfo=R.Tensor((3,), dtype="float32"), |
| ) |
| R.output(gv) |
| return gv |
| |
| verify(SparseToDense, Expected) |
| |
| |
| # DENSIFY operator tests |
| # DENSIFY converts sparse weight tensors to dense at conversion time (not runtime). |
| # Since TensorFlow does not provide an API to create sparse TFLite models, |
| # we manually build them using the flatbuffers API. |
| |
| |
| # Import schema helpers explicitly. CI's generated tflite package does not |
| # reliably re-export these builder helpers and enums at the package top-level. |
| def _get_tflite_schema_module(module_name): |
| return __import__(f"tflite.{module_name}", fromlist=[module_name]) |
| |
| |
| def _get_tflite_schema_enum(enum_name): |
| return getattr(_get_tflite_schema_module(enum_name), enum_name) |
| |
| |
| _tfl_add_options = _get_tflite_schema_module("AddOptions") |
| _tfl_buffer = _get_tflite_schema_module("Buffer") |
| _tfl_concatenation_options = _get_tflite_schema_module("ConcatenationOptions") |
| _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions") |
| _tfl_depthwise_conv2d_options = _get_tflite_schema_module("DepthwiseConv2DOptions") |
| _tfl_dilate_options = _get_tflite_schema_module("DilateOptions") |
| _tfl_reshape_options = _get_tflite_schema_module("ReshapeOptions") |
| _tfl_transpose_conv_options = _get_tflite_schema_module("TransposeConvOptions") |
| |
| # ── StableHLO BuiltinOptions2 schema modules ──────────────────────────── |
| _tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions") |
| _tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions") |
| _tfl_stablehlo_composite_opts = _get_tflite_schema_module("StableHLOCompositeOptions") |
| _tfl_stablehlo_conv_opts = _get_tflite_schema_module("StablehloConvolutionOptions") |
| _tfl_stablehlo_dot_opts = _get_tflite_schema_module("StablehloDotGeneralOptions") |
| _tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions") |
| _tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions") |
| _tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection") |
| _tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType") |
| _tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions") |
| _tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions") |
| _tfl_stablehlo_gather_opts = _get_tflite_schema_module("StablehloGatherOptions") |
| _tfl_stablehlo_reduce_opts = _get_tflite_schema_module("StablehloReduceOptions") |
| _tfl_stablehlo_reduce_window_opts = _get_tflite_schema_module("StablehloReduceWindowOptions") |
| _tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions") |
| _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions") |
| _tfl_call_options = _get_tflite_schema_module("CallOptions") |
| _tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions") |
| _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") |
| _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") |
| _tfl_if_options = _get_tflite_schema_module("IfOptions") |
| _tfl_int32_vector = _get_tflite_schema_module("Int32Vector") |
| _tfl_model = _get_tflite_schema_module("Model") |
| _tfl_operator = _get_tflite_schema_module("Operator") |
| _tfl_operator_code = _get_tflite_schema_module("OperatorCode") |
| _tfl_quantization_parameters = _get_tflite_schema_module("QuantizationParameters") |
| _tfl_sparsity_parameters = _get_tflite_schema_module("SparsityParameters") |
| _tfl_subgraph = _get_tflite_schema_module("SubGraph") |
| _tfl_tensor = _get_tflite_schema_module("Tensor") |
| _tfl_while_options = _get_tflite_schema_module("WhileOptions") |
| |
| _tfl_builtin_operator = _get_tflite_schema_enum("BuiltinOperator") |
| _tfl_builtin_options = _get_tflite_schema_enum("BuiltinOptions") |
| _tfl_builtin_options2 = _get_tflite_schema_enum("BuiltinOptions2") |
| _tfl_activation_fn = _get_tflite_schema_enum("ActivationFunctionType") |
| _tfl_dimension_type = _get_tflite_schema_enum("DimensionType") |
| _tfl_fc_weights_format = _get_tflite_schema_enum("FullyConnectedOptionsWeightsFormat") |
| _tfl_padding = _get_tflite_schema_enum("Padding") |
| _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector") |
| _tfl_tensor_type = _get_tflite_schema_enum("TensorType") |
| |
| _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions") |
| |
| _DENSIFY_TEST_VALUES = np.array([1.0, 2.0], dtype=np.float32) |
| _DENSIFY_TEST_DENSE = np.array([[1.0, 0.0], [0.0, 2.0]], dtype=np.float32) |
| _DENSIFY_ROW_PTRS = [0, 1, 2] |
| _DENSIFY_COL_INDICES = [0, 1] |
| _DENSIFY_CONV_KERNEL_DENSE_HWIO = _DENSIFY_TEST_DENSE.reshape(2, 2, 1, 1) |
| _DENSIFY_FC_WEIGHT_VALUES = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32) |
| _DENSIFY_FC_WEIGHT_DENSE_OI = np.diag(_DENSIFY_FC_WEIGHT_VALUES).astype(np.float32) |
| _DENSIFY_FC_ROW_PTRS = [0, 1, 2, 3, 4] |
| _DENSIFY_FC_COL_INDICES = [0, 1, 2, 3] |
| |
| |
| def _tflite_int32_vector(builder, start_vector_fn, values): |
| start_vector_fn(builder, len(values)) |
| for value in reversed(values): |
| builder.PrependInt32(value) |
| return builder.EndVector() |
| |
| |
| def _tflite_int64_vector(builder, start_vector_fn, values): |
| start_vector_fn(builder, len(values)) |
| for value in reversed(values): |
| builder.PrependInt64(value) |
| return builder.EndVector() |
| |
| |
| def _tflite_bool_vector(builder, start_vector_fn, values): |
| start_vector_fn(builder, len(values)) |
| for value in reversed(values): |
| builder.PrependBool(value) |
| return builder.EndVector() |
| |
| |
| def _tflite_float32_vector(builder, start_vector_fn, values): |
| start_vector_fn(builder, len(values)) |
| for value in reversed(values): |
| builder.PrependFloat32(value) |
| return builder.EndVector() |
| |
| |
| def _tflite_offset_vector(builder, start_vector_fn, offsets): |
| start_vector_fn(builder, len(offsets)) |
| for offset in reversed(offsets): |
| builder.PrependUOffsetTRelative(offset) |
| return builder.EndVector() |
| |
| |
| def _tflite_byte_vector(builder, data): |
| _tfl_buffer.BufferStartDataVector(builder, len(data)) |
| for byte in reversed(data): |
| builder.PrependByte(byte) |
| return builder.EndVector() |
| |
| |
| def _tflite_int32_table(builder, values): |
| # Build the values vector directly without relying on version-specific |
| # helper Int32VectorStartValuesVector, which is absent in older |
| # tflite package versions used in CI. |
| builder.StartVector(4, len(values), 4) |
| for value in reversed(values): |
| builder.PrependInt32(value) |
| values_vec = builder.EndVector() |
| _tfl_int32_vector.Int32VectorStart(builder) |
| _tfl_int32_vector.Int32VectorAddValues(builder, values_vec) |
| return _tfl_int32_vector.Int32VectorEnd(builder) |
| |
| |
| def _tflite_shape(builder, shape): |
| return _tflite_int32_vector(builder, _tfl_tensor.TensorStartShapeVector, shape) |
| |
| |
| def _build_tensor(builder, buffer_idx, shape, sparsity=None, tensor_type=None, quantization=None): |
| """Helper to build a TFLite tensor.""" |
| if tensor_type is None: |
| tensor_type = _tfl_tensor_type.FLOAT32 |
| shape_vec = _tflite_shape(builder, shape) |
| _tfl_tensor.TensorStart(builder) |
| _tfl_tensor.TensorAddBuffer(builder, buffer_idx) |
| _tfl_tensor.TensorAddHasRank(builder, True) |
| _tfl_tensor.TensorAddIsVariable(builder, False) |
| _tfl_tensor.TensorAddShape(builder, shape_vec) |
| if sparsity is not None: |
| _tfl_tensor.TensorAddSparsity(builder, sparsity) |
| if quantization is not None: |
| _tfl_tensor.TensorAddQuantization(builder, quantization) |
| _tfl_tensor.TensorAddType(builder, tensor_type) |
| return _tfl_tensor.TensorEnd(builder) |
| |
| |
| def _build_buffer(builder, data=None): |
| # Build the data vector before starting the Buffer table to avoid |
| # flatbuffers IsNestedError (vectors cannot be created inside tables). |
| data_offset = None |
| if data is not None: |
| data_offset = _tflite_byte_vector(builder, data) |
| _tfl_buffer.BufferStart(builder) |
| if data_offset is not None: |
| _tfl_buffer.BufferAddData(builder, data_offset) |
| return _tfl_buffer.BufferEnd(builder) |
| |
| |
| def _build_quantization_parameters(builder, *, scale, zero_point, quantized_dimension): |
| scale_vec = _tflite_float32_vector( |
| builder, _tfl_quantization_parameters.QuantizationParametersStartScaleVector, scale |
| ) |
| zero_point_vec = _tflite_int64_vector( |
| builder, |
| _tfl_quantization_parameters.QuantizationParametersStartZeroPointVector, |
| zero_point, |
| ) |
| _tfl_quantization_parameters.QuantizationParametersStart(builder) |
| _tfl_quantization_parameters.QuantizationParametersAddScale(builder, scale_vec) |
| _tfl_quantization_parameters.QuantizationParametersAddZeroPoint(builder, zero_point_vec) |
| _tfl_quantization_parameters.QuantizationParametersAddQuantizedDimension( |
| builder, quantized_dimension |
| ) |
| return _tfl_quantization_parameters.QuantizationParametersEnd(builder) |
| |
| |
| def _build_operator( |
| builder, |
| opcode_index, |
| inputs, |
| outputs, |
| builtin_options_type=None, |
| builtin_options=None, |
| builtin_options2_type=None, |
| builtin_options2=None, |
| ): |
| inputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartInputsVector, inputs) |
| outputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartOutputsVector, outputs) |
| _tfl_operator.OperatorStart(builder) |
| _tfl_operator.OperatorAddOpcodeIndex(builder, opcode_index) |
| _tfl_operator.OperatorAddInputs(builder, inputs_vec) |
| _tfl_operator.OperatorAddOutputs(builder, outputs_vec) |
| if builtin_options_type is not None: |
| _tfl_operator.OperatorAddBuiltinOptionsType(builder, builtin_options_type) |
| if builtin_options is not None: |
| _tfl_operator.OperatorAddBuiltinOptions(builder, builtin_options) |
| if builtin_options2_type is not None: |
| _tfl_operator.OperatorAddBuiltinOptions2Type(builder, builtin_options2_type) |
| if builtin_options2 is not None: |
| _tfl_operator.OperatorAddBuiltinOptions2(builder, builtin_options2) |
| return _tfl_operator.OperatorEnd(builder) |
| |
| |
| def _build_operator_code(builder, builtin_op): |
| # deprecated_builtin_code is int8 (max 127). Ops past that write 127 as a |
| # placeholder and use the full builtin_code field. |
| deprecated_code = builtin_op if builtin_op < 127 else 127 |
| _tfl_operator_code.OperatorCodeStart(builder) |
| _tfl_operator_code.OperatorCodeAddDeprecatedBuiltinCode(builder, deprecated_code) |
| _tfl_operator_code.OperatorCodeAddBuiltinCode(builder, builtin_op) |
| _tfl_operator_code.OperatorCodeAddVersion(builder, 1) |
| return _tfl_operator_code.OperatorCodeEnd(builder) |
| |
| |
| def _build_subgraph(builder, *, tensors, operators, inputs, outputs): |
| tensors_vec = _tflite_offset_vector(builder, _tfl_subgraph.SubGraphStartTensorsVector, tensors) |
| operators_vec = _tflite_offset_vector( |
| builder, _tfl_subgraph.SubGraphStartOperatorsVector, operators |
| ) |
| inputs_vec = _tflite_int32_vector(builder, _tfl_subgraph.SubGraphStartInputsVector, inputs) |
| outputs_vec = _tflite_int32_vector(builder, _tfl_subgraph.SubGraphStartOutputsVector, outputs) |
| |
| _tfl_subgraph.SubGraphStart(builder) |
| _tfl_subgraph.SubGraphAddTensors(builder, tensors_vec) |
| _tfl_subgraph.SubGraphAddOperators(builder, operators_vec) |
| _tfl_subgraph.SubGraphAddInputs(builder, inputs_vec) |
| _tfl_subgraph.SubGraphAddOutputs(builder, outputs_vec) |
| return _tfl_subgraph.SubGraphEnd(builder) |
| |
| |
| def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers, extra_subgraphs=None): |
| all_subgraphs = [subgraph] + (extra_subgraphs or []) |
| buffers_vec = _tflite_offset_vector(builder, _tfl_model.ModelStartBuffersVector, buffers) |
| opcodes_vec = _tflite_offset_vector( |
| builder, _tfl_model.ModelStartOperatorCodesVector, operator_codes |
| ) |
| subgraphs_vec = _tflite_offset_vector( |
| builder, _tfl_model.ModelStartSubgraphsVector, all_subgraphs |
| ) |
| |
| _tfl_model.ModelStart(builder) |
| _tfl_model.ModelAddBuffers(builder, buffers_vec) |
| _tfl_model.ModelAddSubgraphs(builder, subgraphs_vec) |
| _tfl_model.ModelAddOperatorCodes(builder, opcodes_vec) |
| _tfl_model.ModelAddVersion(builder, 3) |
| model = _tfl_model.ModelEnd(builder) |
| |
| builder.Finish(model, b"TFL3") |
| return bytes(builder.Output()) |
| |
| |
| def _build_call_options(builder, subgraph_index): |
| _tfl_call_options.CallOptionsStart(builder) |
| _tfl_call_options.CallOptionsAddSubgraph(builder, subgraph_index) |
| return _tfl_call_options.CallOptionsEnd(builder) |
| |
| |
| def _build_if_options(builder, then_subgraph_index, else_subgraph_index): |
| _tfl_if_options.IfOptionsStart(builder) |
| _tfl_if_options.IfOptionsAddThenSubgraphIndex(builder, then_subgraph_index) |
| _tfl_if_options.IfOptionsAddElseSubgraphIndex(builder, else_subgraph_index) |
| return _tfl_if_options.IfOptionsEnd(builder) |
| |
| |
| def _build_while_options(builder, cond_subgraph_index, body_subgraph_index): |
| _tfl_while_options.WhileOptionsStart(builder) |
| _tfl_while_options.WhileOptionsAddCondSubgraphIndex(builder, cond_subgraph_index) |
| _tfl_while_options.WhileOptionsAddBodySubgraphIndex(builder, body_subgraph_index) |
| return _tfl_while_options.WhileOptionsEnd(builder) |
| |
| |
| def _build_call_once_options(builder, init_subgraph_index): |
| _tfl_call_once_options.CallOnceOptionsStart(builder) |
| _tfl_call_once_options.CallOnceOptionsAddInitSubgraphIndex(builder, init_subgraph_index) |
| return _tfl_call_once_options.CallOnceOptionsEnd(builder) |
| |
| |
| def _load_model_from_buffer(model_bytes): |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(model_bytes, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| return mod |
| |
| |
| def _get_builtin_operator(builtin_name): |
| if not hasattr(_tfl_builtin_operator, builtin_name): |
| pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") |
| return getattr(_tfl_builtin_operator, builtin_name) |
| |
| |
| def _build_tflite_call_model( |
| call_subgraph_index=1, |
| callee_inputs=None, |
| callee_outputs=None, |
| callee_output_shape=None, |
| callee_output_type=None, |
| ): |
| """Build a TFLite model where main CALLs a subgraph computing x + 1.""" |
| builder = flatbuffers.Builder(1024) |
| |
| callee_inputs = [0] if callee_inputs is None else callee_inputs |
| callee_outputs = [2] if callee_outputs is None else callee_outputs |
| callee_output_shape = [2, 2] if callee_output_shape is None else callee_output_shape |
| callee_output_type = ( |
| _tfl_tensor_type.FLOAT32 if callee_output_type is None else callee_output_type |
| ) |
| call_options = _build_call_options(builder, call_subgraph_index) |
| one = np.array(1.0, dtype=np.float32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 2, [2, 2]), |
| ] |
| main_call = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options_type=_tfl_builtin_options.CallOptions, |
| builtin_options=call_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_call], |
| inputs=[0], |
| outputs=[1], |
| ) |
| |
| callee_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 1, []), |
| _build_tensor(builder, 2, callee_output_shape, tensor_type=callee_output_type), |
| ] |
| callee_add = _build_operator(builder, 1, [0, 1], [2]) |
| callee_subgraph = _build_subgraph( |
| builder, |
| tensors=callee_tensors, |
| operators=[callee_add], |
| inputs=callee_inputs, |
| outputs=callee_outputs, |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("CALL")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[callee_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_call_subgraph(): |
| """Test TFLite CALL conversion to a private Relax function.""" |
| mod = _load_model_from_buffer(_build_tflite_call_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_call_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| R.output(gv) |
| return gv |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| cls = Expected |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_1(tvmgen_tensor_0) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_tflite_multi_output_call_model(): |
| """Build a TFLite model where CALL returns x + 1 and x - 1.""" |
| builder = flatbuffers.Builder(1024) |
| |
| call_options = _build_call_options(builder, 1) |
| one = np.array(1.0, dtype=np.float32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 2, [2, 2]), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| main_call = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1, 2], |
| builtin_options_type=_tfl_builtin_options.CallOptions, |
| builtin_options=call_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_call], |
| inputs=[0], |
| outputs=[1, 2], |
| ) |
| |
| callee_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 1, []), |
| _build_tensor(builder, 2, [2, 2]), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| callee_add = _build_operator(builder, 1, [0, 1], [2]) |
| callee_sub = _build_operator(builder, 2, [0, 1], [3]) |
| callee_subgraph = _build_subgraph( |
| builder, |
| tensors=callee_tensors, |
| operators=[callee_add, callee_sub], |
| inputs=[0], |
| outputs=[2, 3], |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("CALL")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| _build_operator_code(builder, _get_builtin_operator("SUB")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[callee_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_call_subgraph_multi_output(): |
| """Test CALL tuple returns are split and rebound to TFLite output tensors.""" |
| mod = _load_model_from_buffer(_build_tflite_multi_output_call_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_call_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| gv1: R.Tensor((2, 2), dtype="float32") = R.subtract( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| gv2: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = (gv, gv1) |
| R.output(gv2) |
| return gv2 |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| R.func_attr({"num_input": 1}) |
| cls = Expected |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = cls.tflite_call_subgraph_1(tvmgen_tensor_0) |
| lv1: R.Tensor((2, 2), dtype="float32") = lv[0] |
| lv2: R.Tensor((2, 2), dtype="float32") = lv[1] |
| gv: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = (lv1, lv2) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_tflite_nested_call_model(): |
| """Build a TFLite model where main CALLs subgraph A, which CALLs subgraph B.""" |
| builder = flatbuffers.Builder(1024) |
| |
| main_call_options = _build_call_options(builder, 1) |
| nested_call_options = _build_call_options(builder, 2) |
| one = np.array(1.0, dtype=np.float32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| main_call = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options_type=_tfl_builtin_options.CallOptions, |
| builtin_options=main_call_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_call], |
| inputs=[0], |
| outputs=[1], |
| ) |
| |
| caller_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| nested_call = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options_type=_tfl_builtin_options.CallOptions, |
| builtin_options=nested_call_options, |
| ) |
| caller_subgraph = _build_subgraph( |
| builder, |
| tensors=caller_tensors, |
| operators=[nested_call], |
| inputs=[0], |
| outputs=[1], |
| ) |
| |
| callee_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 1, []), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| callee_add = _build_operator(builder, 1, [0, 1], [2]) |
| callee_subgraph = _build_subgraph( |
| builder, |
| tensors=callee_tensors, |
| operators=[callee_add], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("CALL")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[caller_subgraph, callee_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_call_subgraph_nested_call(): |
| """Test nested CALL subgraphs register all generated private functions.""" |
| mod = _load_model_from_buffer(_build_tflite_nested_call_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_call_subgraph_2( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| R.output(gv) |
| return gv |
| |
| @R.function(private=True) |
| def tflite_call_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| cls = Expected |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_2(tvmgen_tensor_0) |
| R.output(gv) |
| return gv |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| cls = Expected |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_call_subgraph_1(tvmgen_tensor_0) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_call_subgraph_invalid_index_unsupported(): |
| """Test CALL rejects invalid subgraph indices before lowering.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="CALL requires a valid subgraph index"): |
| _load_model_from_buffer(_build_tflite_call_model(call_subgraph_index=2)) |
| |
| |
| def test_call_subgraph_io_mismatch_unsupported(): |
| """Test CALL rejects callees whose input arity does not match the call site.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="CALL subgraph input count mismatch"): |
| _load_model_from_buffer(_build_tflite_call_model(callee_inputs=[])) |
| |
| |
| def test_call_subgraph_output_metadata_mismatch_unsupported(): |
| """Test CALL rejects callees whose output metadata does not match the call site.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="CALL subgraph output tensor metadata mismatch" |
| ): |
| _load_model_from_buffer(_build_tflite_call_model(callee_output_shape=[2])) |
| |
| |
| def _build_tflite_if_model( |
| condition_type=_tfl_tensor_type.BOOL, |
| then_subgraph_index=1, |
| else_subgraph_index=2, |
| then_outputs=None, |
| else_outputs=None, |
| else_input_shape=None, |
| else_input_type=None, |
| else_output_shape=None, |
| else_output_type=None, |
| ): |
| """Build a TFLite model where IF selects x + 1 or x - 1.""" |
| builder = flatbuffers.Builder(1024) |
| |
| then_outputs = [2] if then_outputs is None else then_outputs |
| else_outputs = [2] if else_outputs is None else else_outputs |
| else_input_shape = [2, 2] if else_input_shape is None else else_input_shape |
| else_input_type = _tfl_tensor_type.FLOAT32 if else_input_type is None else else_input_type |
| else_output_shape = [2, 2] if else_output_shape is None else else_output_shape |
| else_output_type = _tfl_tensor_type.FLOAT32 if else_output_type is None else else_output_type |
| if_options = _build_if_options(builder, then_subgraph_index, else_subgraph_index) |
| one = np.array(1.0, dtype=np.float32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=condition_type), |
| _build_tensor(builder, 1, [2, 2]), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| main_if = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.IfOptions, |
| builtin_options=if_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_if], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| then_tensors = [ |
| _build_tensor(builder, 1, [2, 2]), |
| _build_tensor(builder, 2, []), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| then_add = _build_operator(builder, 1, [0, 1], [2]) |
| then_subgraph = _build_subgraph( |
| builder, |
| tensors=then_tensors, |
| operators=[then_add], |
| inputs=[0], |
| outputs=then_outputs, |
| ) |
| |
| else_tensors = [ |
| _build_tensor(builder, 1, else_input_shape, tensor_type=else_input_type), |
| _build_tensor(builder, 2, []), |
| _build_tensor(builder, 3, else_output_shape, tensor_type=else_output_type), |
| ] |
| else_sub = _build_operator(builder, 2, [0, 1], [2]) |
| else_subgraph = _build_subgraph( |
| builder, |
| tensors=else_tensors, |
| operators=[else_sub], |
| inputs=[0], |
| outputs=else_outputs, |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("IF")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| _build_operator_code(builder, _get_builtin_operator("SUB")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[then_subgraph, else_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_if_subgraphs(): |
| """Test TFLite IF conversion to Relax If.""" |
| mod = _load_model_from_buffer(_build_tflite_if_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_if_then_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| R.output(gv) |
| return gv |
| |
| @R.function(private=True) |
| def tflite_if_else_subgraph_2( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.subtract( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| R.output(gv) |
| return gv |
| |
| @R.function(private=True) |
| def tflite_if_subgraph_1_2( |
| tvmgen_tensor_0: R.Tensor((), dtype="bool"), |
| tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| cls = Expected |
| if tvmgen_tensor_0: |
| gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_then_subgraph_1( |
| tvmgen_tensor_1 |
| ) |
| cond_result: R.Tensor((2, 2), dtype="float32") = gv |
| else: |
| gv1: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_else_subgraph_2( |
| tvmgen_tensor_1 |
| ) |
| cond_result: R.Tensor((2, 2), dtype="float32") = gv1 |
| return cond_result |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((), dtype="bool"), |
| tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| cls = Expected |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = cls.tflite_if_subgraph_1_2( |
| tvmgen_tensor_0, tvmgen_tensor_1 |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_tflite_multi_output_if_model(): |
| """Build a TFLite model where IF returns two tensor outputs.""" |
| builder = flatbuffers.Builder(1024) |
| |
| if_options = _build_if_options(builder, 1, 2) |
| one = np.array(1.0, dtype=np.float32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.BOOL), |
| _build_tensor(builder, 1, [2, 2]), |
| _build_tensor(builder, 4, [2, 2]), |
| _build_tensor(builder, 5, [2, 2]), |
| ] |
| main_if = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2, 3], |
| builtin_options_type=_tfl_builtin_options.IfOptions, |
| builtin_options=if_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_if], |
| inputs=[0, 1], |
| outputs=[2, 3], |
| ) |
| |
| then_tensors = [ |
| _build_tensor(builder, 1, [2, 2]), |
| _build_tensor(builder, 2, []), |
| _build_tensor(builder, 3, [2, 2]), |
| _build_tensor(builder, 4, [2, 2]), |
| ] |
| then_add = _build_operator(builder, 1, [0, 1], [2]) |
| then_sub = _build_operator(builder, 2, [0, 1], [3]) |
| then_subgraph = _build_subgraph( |
| builder, |
| tensors=then_tensors, |
| operators=[then_add, then_sub], |
| inputs=[0], |
| outputs=[2, 3], |
| ) |
| |
| else_tensors = [ |
| _build_tensor(builder, 1, [2, 2]), |
| _build_tensor(builder, 2, []), |
| _build_tensor(builder, 3, [2, 2]), |
| _build_tensor(builder, 4, [2, 2]), |
| ] |
| else_sub = _build_operator(builder, 2, [0, 1], [2]) |
| else_add = _build_operator(builder, 1, [0, 1], [3]) |
| else_subgraph = _build_subgraph( |
| builder, |
| tensors=else_tensors, |
| operators=[else_sub, else_add], |
| inputs=[0], |
| outputs=[2, 3], |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("IF")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| _build_operator_code(builder, _get_builtin_operator("SUB")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[then_subgraph, else_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_if_subgraphs_multi_output(): |
| """Test IF tuple returns are preserved through the private wrapper function.""" |
| mod = _load_model_from_buffer(_build_tflite_multi_output_if_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_if_then_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| gv1: R.Tensor((2, 2), dtype="float32") = R.subtract( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| gv2: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = (gv, gv1) |
| R.output(gv2) |
| return gv2 |
| |
| @R.function(private=True) |
| def tflite_if_else_subgraph_2( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.subtract( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| gv1: R.Tensor((2, 2), dtype="float32") = R.add( |
| tvmgen_tensor_0, R.const(1.0, "float32") |
| ) |
| gv2: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = (gv, gv1) |
| R.output(gv2) |
| return gv2 |
| |
| @R.function(private=True) |
| def tflite_if_subgraph_1_2( |
| tvmgen_tensor_0: R.Tensor((), dtype="bool"), |
| tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| cls = Expected |
| if tvmgen_tensor_0: |
| gv: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = cls.tflite_if_then_subgraph_1(tvmgen_tensor_1) |
| cond_result: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = gv |
| else: |
| gv1: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = cls.tflite_if_else_subgraph_2(tvmgen_tensor_1) |
| cond_result: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = gv1 |
| return cond_result |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((), dtype="bool"), |
| tvmgen_tensor_1: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tuple(R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32")): |
| R.func_attr({"num_input": 2}) |
| cls = Expected |
| with R.dataflow(): |
| lv: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = cls.tflite_if_subgraph_1_2(tvmgen_tensor_0, tvmgen_tensor_1) |
| lv1: R.Tensor((2, 2), dtype="float32") = lv[0] |
| lv2: R.Tensor((2, 2), dtype="float32") = lv[1] |
| gv: R.Tuple( |
| R.Tensor((2, 2), dtype="float32"), R.Tensor((2, 2), dtype="float32") |
| ) = (lv1, lv2) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_if_subgraphs_non_bool_condition_unsupported(): |
| """Test IF rejects non-bool condition tensors.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="IF requires a scalar bool condition"): |
| _load_model_from_buffer(_build_tflite_if_model(condition_type=_tfl_tensor_type.INT32)) |
| |
| |
| def test_if_subgraphs_invalid_index_unsupported(): |
| """Test IF rejects invalid branch subgraph indices before lowering.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="IF requires a valid subgraph index"): |
| _load_model_from_buffer(_build_tflite_if_model(then_subgraph_index=3)) |
| |
| |
| def test_if_subgraphs_output_count_mismatch_unsupported(): |
| """Test IF rejects branches whose output arity does not match the call site.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="IF subgraph output count mismatch"): |
| _load_model_from_buffer(_build_tflite_if_model(else_outputs=[])) |
| |
| |
| def test_if_subgraphs_input_metadata_mismatch_unsupported(): |
| """Test IF rejects branches whose input metadata does not match the call site.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="IF subgraph input tensor metadata mismatch" |
| ): |
| _load_model_from_buffer(_build_tflite_if_model(else_input_shape=[2])) |
| |
| |
| def test_if_subgraphs_output_metadata_mismatch_unsupported(): |
| """Test IF rejects branches whose output metadata does not match the call site.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="IF subgraph output tensor metadata mismatch" |
| ): |
| _load_model_from_buffer(_build_tflite_if_model(else_output_shape=[2])) |
| |
| |
| def _build_tflite_while_model( |
| cond_subgraph_index=1, |
| body_subgraph_index=2, |
| cond_output_type=_tfl_tensor_type.BOOL, |
| cond_input_type=_tfl_tensor_type.INT32, |
| body_outputs=None, |
| body_input_type=_tfl_tensor_type.INT32, |
| body_output_type=_tfl_tensor_type.INT32, |
| main_output_type=_tfl_tensor_type.INT32, |
| ): |
| """Build a TFLite WHILE model incrementing an int32 scalar until i < 3 is false.""" |
| builder = flatbuffers.Builder(1024) |
| |
| body_outputs = [2] if body_outputs is None else body_outputs |
| while_options = _build_while_options(builder, cond_subgraph_index, body_subgraph_index) |
| one = np.array(1, dtype=np.int32) |
| three = np.array(3, dtype=np.int32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=main_output_type), |
| ] |
| main_while = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options_type=_tfl_builtin_options.WhileOptions, |
| builtin_options=while_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_while], |
| inputs=[0], |
| outputs=[1], |
| ) |
| |
| cond_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=cond_input_type), |
| _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=cond_output_type), |
| ] |
| cond_less = _build_operator(builder, 1, [0, 1], [2]) |
| cond_subgraph = _build_subgraph( |
| builder, |
| tensors=cond_tensors, |
| operators=[cond_less], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| body_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=body_input_type), |
| _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=body_output_type), |
| ] |
| body_add = _build_operator(builder, 2, [0, 1], [2]) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[body_add], |
| inputs=[0], |
| outputs=body_outputs, |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("WHILE")), |
| _build_operator_code(builder, _get_builtin_operator("LESS")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, three.tobytes()), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[cond_subgraph, body_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def _build_tflite_repeated_while_model(): |
| """Build a TFLite model where two WHILE ops share the same cond/body subgraphs.""" |
| builder = flatbuffers.Builder(1024) |
| |
| while_options = _build_while_options(builder, 1, 2) |
| one = np.array(1, dtype=np.int32) |
| three = np.array(3, dtype=np.int32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32), |
| ] |
| main_while_0 = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options_type=_tfl_builtin_options.WhileOptions, |
| builtin_options=while_options, |
| ) |
| main_while_1 = _build_operator( |
| builder, |
| 0, |
| [1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.WhileOptions, |
| builtin_options=while_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_while_0, main_while_1], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| cond_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.BOOL), |
| ] |
| cond_less = _build_operator(builder, 1, [0, 1], [2]) |
| cond_subgraph = _build_subgraph( |
| builder, |
| tensors=cond_tensors, |
| operators=[cond_less], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| body_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32), |
| ] |
| body_add = _build_operator(builder, 2, [0, 1], [2]) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[body_add], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("WHILE")), |
| _build_operator_code(builder, _get_builtin_operator("LESS")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, three.tobytes()), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[cond_subgraph, body_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def _build_tflite_zero_var_while_model(): |
| """Build a TFLite WHILE model with no loop-carried tensors.""" |
| builder = flatbuffers.Builder(1024) |
| |
| while_options = _build_while_options(builder, 1, 2) |
| main_while = _build_operator( |
| builder, |
| 0, |
| [], |
| [], |
| builtin_options_type=_tfl_builtin_options.WhileOptions, |
| builtin_options=while_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=[], |
| operators=[main_while], |
| inputs=[], |
| outputs=[], |
| ) |
| cond_subgraph = _build_subgraph(builder, tensors=[], operators=[], inputs=[], outputs=[]) |
| body_subgraph = _build_subgraph(builder, tensors=[], operators=[], inputs=[], outputs=[]) |
| |
| operator_codes = [_build_operator_code(builder, _get_builtin_operator("WHILE"))] |
| buffers = [_build_buffer(builder)] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[cond_subgraph, body_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_while_subgraphs(): |
| """Test TFLite WHILE conversion to a recursive Relax private function.""" |
| mod = _load_model_from_buffer(_build_tflite_while_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_while_cond_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| ) -> R.Tensor((), dtype="bool"): |
| with R.dataflow(): |
| gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32")) |
| R.output(gv) |
| return gv |
| |
| @R.function(private=True) |
| def tflite_while_body_subgraph_2( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| ) -> R.Tensor((), dtype="int32"): |
| with R.dataflow(): |
| gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32")) |
| R.output(gv) |
| return gv |
| |
| @R.function(private=True) |
| def tflite_while_subgraph_1_2( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| ) -> R.Tensor((), dtype="int32"): |
| cls = Expected |
| while_cond: R.Tensor((), dtype="bool") = cls.tflite_while_cond_subgraph_1( |
| tvmgen_tensor_0 |
| ) |
| if while_cond: |
| gv: R.Tensor((), dtype="int32") = cls.tflite_while_body_subgraph_2(tvmgen_tensor_0) |
| gv1: R.Tensor((), dtype="int32") = cls.tflite_while_subgraph_1_2(gv) |
| cond_result: R.Tensor((), dtype="int32") = gv1 |
| else: |
| cond_result: R.Tensor((), dtype="int32") = tvmgen_tensor_0 |
| return cond_result |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| ) -> R.Tensor((), dtype="int32"): |
| R.func_attr({"num_input": 1}) |
| cls = Expected |
| with R.dataflow(): |
| gv: R.Tensor((), dtype="int32") = cls.tflite_while_subgraph_1_2(tvmgen_tensor_0) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_while_subgraphs_repeated_cond_body_pair(): |
| """Test repeated WHILE ops reuse the same recursive private function.""" |
| mod = _load_model_from_buffer(_build_tflite_repeated_while_model()) |
| names = [gv.name_hint for gv in mod.get_global_vars()] |
| assert names.count("tflite_while_subgraph_1_2") == 1 |
| |
| |
| def _build_tflite_two_var_while_model(): |
| """Build a TFLite WHILE model with two int32 loop-carried scalar tensors.""" |
| builder = flatbuffers.Builder(1024) |
| |
| while_options = _build_while_options(builder, 1, 2) |
| one = np.array(1, dtype=np.int32) |
| three = np.array(3, dtype=np.int32) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 5, [], tensor_type=_tfl_tensor_type.INT32), |
| ] |
| main_while = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2, 3], |
| builtin_options_type=_tfl_builtin_options.WhileOptions, |
| builtin_options=while_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_while], |
| inputs=[0, 1], |
| outputs=[2, 3], |
| ) |
| |
| cond_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 2, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL), |
| ] |
| cond_less = _build_operator(builder, 1, [0, 2], [3]) |
| cond_subgraph = _build_subgraph( |
| builder, |
| tensors=cond_tensors, |
| operators=[cond_less], |
| inputs=[0, 1], |
| outputs=[3], |
| ) |
| |
| body_tensors = [ |
| _build_tensor(builder, 0, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 1, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 3, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 5, [], tensor_type=_tfl_tensor_type.INT32), |
| ] |
| body_add_i = _build_operator(builder, 2, [0, 2], [3]) |
| body_add_acc = _build_operator(builder, 2, [1, 0], [4]) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[body_add_i, body_add_acc], |
| inputs=[0, 1], |
| outputs=[3, 4], |
| ) |
| |
| operator_codes = [ |
| _build_operator_code(builder, _get_builtin_operator("WHILE")), |
| _build_operator_code(builder, _get_builtin_operator("LESS")), |
| _build_operator_code(builder, _get_builtin_operator("ADD")), |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder, three.tobytes()), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[cond_subgraph, body_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_while_subgraphs_two_loop_vars(): |
| """Test WHILE tuple loop state with two loop-carried variables.""" |
| mod = _load_model_from_buffer(_build_tflite_two_var_while_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function(private=True) |
| def tflite_while_cond_subgraph_1( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| tvmgen_tensor_1: R.Tensor((), dtype="int32"), |
| ) -> R.Tensor((), dtype="bool"): |
| with R.dataflow(): |
| gv: R.Tensor((), dtype="bool") = R.less(tvmgen_tensor_0, R.const(3, "int32")) |
| R.output(gv) |
| return gv |
| |
| @R.function(private=True) |
| def tflite_while_body_subgraph_2( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| tvmgen_tensor_1: R.Tensor((), dtype="int32"), |
| ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")): |
| with R.dataflow(): |
| gv: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_0, R.const(1, "int32")) |
| gv1: R.Tensor((), dtype="int32") = R.add(tvmgen_tensor_1, tvmgen_tensor_0) |
| gv2: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( |
| gv, |
| gv1, |
| ) |
| R.output(gv2) |
| return gv2 |
| |
| @R.function(private=True) |
| def tflite_while_subgraph_1_2( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| tvmgen_tensor_1: R.Tensor((), dtype="int32"), |
| ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")): |
| cls = Expected |
| while_cond: R.Tensor((), dtype="bool") = cls.tflite_while_cond_subgraph_1( |
| tvmgen_tensor_0, tvmgen_tensor_1 |
| ) |
| if while_cond: |
| gv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( |
| cls.tflite_while_body_subgraph_2(tvmgen_tensor_0, tvmgen_tensor_1) |
| ) |
| gv1: R.Tensor((), dtype="int32") = gv[0] |
| gv2: R.Tensor((), dtype="int32") = gv[1] |
| gv3: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( |
| cls.tflite_while_subgraph_1_2(gv1, gv2) |
| ) |
| cond_result: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = gv3 |
| else: |
| cond_result: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( |
| tvmgen_tensor_0, |
| tvmgen_tensor_1, |
| ) |
| return cond_result |
| |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((), dtype="int32"), |
| tvmgen_tensor_1: R.Tensor((), dtype="int32"), |
| ) -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")): |
| R.func_attr({"num_input": 2}) |
| cls = Expected |
| with R.dataflow(): |
| lv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( |
| cls.tflite_while_subgraph_1_2(tvmgen_tensor_0, tvmgen_tensor_1) |
| ) |
| lv1: R.Tensor((), dtype="int32") = lv[0] |
| lv2: R.Tensor((), dtype="int32") = lv[1] |
| gv: R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((), dtype="int32")) = ( |
| lv1, |
| lv2, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_while_subgraphs_non_bool_condition_unsupported(): |
| """Test WHILE rejects cond subgraphs that do not return scalar bool.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires a scalar bool condition"): |
| _load_model_from_buffer(_build_tflite_while_model(cond_output_type=_tfl_tensor_type.INT32)) |
| |
| |
| def test_while_subgraphs_invalid_index_unsupported(): |
| """Test WHILE rejects invalid cond/body subgraph indices before lowering.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires a valid subgraph index"): |
| _load_model_from_buffer(_build_tflite_while_model(cond_subgraph_index=3)) |
| |
| |
| def test_while_subgraphs_zero_loop_vars_unsupported(): |
| """Test WHILE rejects operators without loop-carried tensors.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="WHILE requires loop-carried inputs"): |
| _load_model_from_buffer(_build_tflite_zero_var_while_model()) |
| |
| |
| def test_while_subgraphs_loop_state_metadata_mismatch_unsupported(): |
| """Test WHILE rejects loop outputs whose metadata does not match loop inputs.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="WHILE loop state tensor metadata mismatch" |
| ): |
| _load_model_from_buffer( |
| _build_tflite_while_model(main_output_type=_tfl_tensor_type.FLOAT32) |
| ) |
| |
| |
| def test_while_subgraphs_output_count_mismatch_unsupported(): |
| """Test WHILE rejects body subgraphs whose output arity does not match loop vars.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="WHILE subgraph output count mismatch"): |
| _load_model_from_buffer(_build_tflite_while_model(body_outputs=[])) |
| |
| |
| def test_while_subgraphs_input_metadata_mismatch_unsupported(): |
| """Test WHILE rejects cond subgraph inputs whose metadata does not match loop vars.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="WHILE subgraph input tensor metadata mismatch" |
| ): |
| _load_model_from_buffer(_build_tflite_while_model(cond_input_type=_tfl_tensor_type.FLOAT32)) |
| |
| |
| def test_while_subgraphs_output_metadata_mismatch_unsupported(): |
| """Test WHILE rejects body outputs whose metadata does not match loop vars.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="WHILE subgraph output tensor metadata mismatch" |
| ): |
| _load_model_from_buffer( |
| _build_tflite_while_model(body_output_type=_tfl_tensor_type.FLOAT32) |
| ) |
| |
| |
| def _build_tflite_call_once_model( |
| init_has_op=False, |
| init_subgraph_index=1, |
| call_once_inputs=None, |
| call_once_outputs=None, |
| init_inputs=None, |
| init_outputs=None, |
| ): |
| """Build a TFLite model with CALL_ONCE and one pass-through output.""" |
| builder = flatbuffers.Builder(1024) |
| |
| call_once_inputs = [] if call_once_inputs is None else call_once_inputs |
| call_once_outputs = [] if call_once_outputs is None else call_once_outputs |
| init_inputs = [] if init_inputs is None else init_inputs |
| init_outputs = [] if init_outputs is None else init_outputs |
| |
| call_once_options = _build_call_once_options(builder, init_subgraph_index) |
| main_tensors = [_build_tensor(builder, 0, [2, 2])] |
| main_call_once = _build_operator( |
| builder, |
| 0, |
| call_once_inputs, |
| call_once_outputs, |
| builtin_options_type=_tfl_builtin_options.CallOnceOptions, |
| builtin_options=call_once_options, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_call_once], |
| inputs=[0], |
| outputs=[0], |
| ) |
| |
| if init_has_op: |
| one = np.array(1.0, dtype=np.float32) |
| init_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 1, []), |
| _build_tensor(builder, 2, [2, 2]), |
| ] |
| init_op = _build_operator(builder, 1, [0, 1], [2]) |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, one.tobytes()), |
| _build_buffer(builder), |
| ] |
| else: |
| init_tensors = ( |
| [_build_tensor(builder, 0, [2, 2])] |
| if len(init_inputs) != 0 or len(init_outputs) != 0 |
| else [] |
| ) |
| init_op = None |
| buffers = [_build_buffer(builder)] |
| |
| init_subgraph = _build_subgraph( |
| builder, |
| tensors=init_tensors, |
| operators=[] if init_op is None else [init_op], |
| inputs=init_inputs, |
| outputs=init_outputs, |
| ) |
| |
| operator_codes = [_build_operator_code(builder, _get_builtin_operator("CALL_ONCE"))] |
| if init_has_op: |
| operator_codes.append(_build_operator_code(builder, _get_builtin_operator("ADD"))) |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[init_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def test_call_once_empty_init_subgraph(): |
| """Test the no-op CALL_ONCE subset.""" |
| mod = _load_model_from_buffer(_build_tflite_call_once_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = tvmgen_tensor_0 |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_call_once_non_empty_init_subgraph_unsupported(): |
| """Test CALL_ONCE rejects init subgraphs with side-effect-like bodies.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="CALL_ONCE"): |
| _load_model_from_buffer(_build_tflite_call_once_model(init_has_op=True)) |
| |
| |
| def test_call_once_inputs_outputs_unsupported(): |
| """Test CALL_ONCE rejects operator inputs and outputs.""" |
| with pytest.raises(tvm.error.OpNotImplemented, match="CALL_ONCE with inputs or outputs"): |
| _load_model_from_buffer( |
| _build_tflite_call_once_model(call_once_inputs=[0], call_once_outputs=[0]) |
| ) |
| |
| |
| def test_call_once_init_subgraph_io_unsupported(): |
| """Test CALL_ONCE rejects init subgraphs with inputs or outputs.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="CALL_ONCE with non-empty init subgraph I/O" |
| ): |
| _load_model_from_buffer(_build_tflite_call_once_model(init_inputs=[0], init_outputs=[0])) |
| |
| |
| def test_call_once_invalid_index_unsupported(): |
| """Test CALL_ONCE rejects invalid init subgraph indices before lowering.""" |
| with pytest.raises( |
| tvm.error.OpNotImplemented, match="CALL_ONCE requires a valid subgraph index" |
| ): |
| _load_model_from_buffer(_build_tflite_call_once_model(init_subgraph_index=2)) |
| |
| |
| def _get_stablehlo_builtin_operator(builtin_name): |
| if not hasattr(_tfl_builtin_operator, builtin_name): |
| pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") |
| return getattr(_tfl_builtin_operator, builtin_name) |
| |
| |
| def _build_stablehlo_model(*, builtin_name, input_count): |
| """Build a minimal TFLite model containing one StableHLO builtin operator.""" |
| builder = flatbuffers.Builder(1024) |
| shape = [2, 2] |
| output_tensor_idx = input_count |
| builtin_op = _get_stablehlo_builtin_operator(builtin_name) |
| |
| tensors = [_build_tensor(builder, buffer_idx, shape) for buffer_idx in range(input_count + 1)] |
| stablehlo_op = _build_operator( |
| builder, |
| 0, |
| list(range(input_count)), |
| [output_tensor_idx], |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[stablehlo_op], |
| inputs=list(range(input_count)), |
| outputs=[output_tensor_idx], |
| ) |
| operator_codes = [_build_operator_code(builder, builtin_op)] |
| buffers = [_build_buffer(builder) for _ in range(input_count + 1)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers |
| ) |
| |
| |
| def _build_stablehlo_model_with_unused_subgraph(): |
| """Build a StableHLO model with an unused extra subgraph.""" |
| builder = flatbuffers.Builder(1024) |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_ADD") |
| |
| main_tensors = [_build_tensor(builder, buffer_idx, [2, 2]) for buffer_idx in range(3)] |
| main_op = _build_operator(builder, 0, [0, 1], [2]) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[main_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| # Give the unused subgraph a conflicting input tensor name and different |
| # shape. from_tflite should infer the main function input shape only from |
| # Subgraphs(0). |
| extra_tensors = [_build_tensor(builder, buffer_idx, [4, 4]) for buffer_idx in range(3, 6)] |
| extra_op = _build_operator(builder, 0, [0, 1], [2]) |
| extra_subgraph = _build_subgraph( |
| builder, |
| tensors=extra_tensors, |
| operators=[extra_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| operator_codes = [_build_operator_code(builder, builtin_op)] |
| buffers = [_build_buffer(builder) for _ in range(6)] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[extra_subgraph], |
| operator_codes=operator_codes, |
| buffers=buffers, |
| ) |
| |
| |
| def _build_stablehlo_reduce_model(reducer_name, init_value): |
| """Build a single-input STABLEHLO_REDUCE model with a binary reducer body.""" |
| builder = flatbuffers.Builder(1024) |
| |
| dimensions_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStartDimensionsVector, |
| [1], |
| ) |
| _tfl_stablehlo_reduce_opts.StablehloReduceOptionsStart(builder) |
| _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddDimensions(builder, dimensions_vec) |
| _tfl_stablehlo_reduce_opts.StablehloReduceOptionsAddBodySubgraphIndex(builder, 1) |
| reduce_opts = _tfl_stablehlo_reduce_opts.StablehloReduceOptionsEnd(builder) |
| |
| reduce_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE") |
| reducer_builtin = _get_stablehlo_builtin_operator(reducer_name) |
| reduce_code = _build_operator_code(builder, reduce_builtin) |
| reducer_code = _build_operator_code(builder, reducer_builtin) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [2, 3]), |
| _build_tensor(builder, 1, []), |
| _build_tensor(builder, 2, [2]), |
| ] |
| reduce_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloReduceOptions, |
| builtin_options2=reduce_opts, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[reduce_op], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)] |
| reducer_op = _build_operator(builder, 1, [0, 1], [2]) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[reducer_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[body_subgraph], |
| operator_codes=[reduce_code, reducer_code], |
| buffers=buffers, |
| ) |
| |
| |
| def _build_stablehlo_sort_model(comparison_direction, is_stable=False): |
| """Build a single-input STABLEHLO_SORT model with a compare body.""" |
| builder = flatbuffers.Builder(1024) |
| |
| _tfl_stablehlo_sort_opts.StablehloSortOptionsStart(builder) |
| _tfl_stablehlo_sort_opts.StablehloSortOptionsAddDimension(builder, 1) |
| _tfl_stablehlo_sort_opts.StablehloSortOptionsAddIsStable(builder, is_stable) |
| _tfl_stablehlo_sort_opts.StablehloSortOptionsAddComparatorSubgraphIndex(builder, 1) |
| sort_opts = _tfl_stablehlo_sort_opts.StablehloSortOptionsEnd(builder) |
| |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection( |
| builder, comparison_direction |
| ) |
| compare_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) |
| |
| sort_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SORT") |
| compare_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE") |
| sort_code = _build_operator_code(builder, sort_builtin) |
| compare_code = _build_operator_code(builder, compare_builtin) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [2, 3]), |
| _build_tensor(builder, 1, [2, 3]), |
| ] |
| sort_op = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options2_type=_tfl_builtin_options2.StablehloSortOptions, |
| builtin_options2=sort_opts, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[sort_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| |
| body_tensors = [ |
| _build_tensor(builder, 2, []), |
| _build_tensor(builder, 3, []), |
| _build_tensor(builder, 4, [], tensor_type=_tfl_tensor_type.BOOL), |
| ] |
| compare_op = _build_operator( |
| builder, |
| 1, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, |
| builtin_options2=compare_opts, |
| ) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[compare_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| buffers = [_build_buffer(builder) for _ in range(5)] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[body_subgraph], |
| operator_codes=[sort_code, compare_code], |
| buffers=buffers, |
| ) |
| |
| |
| def _build_stablehlo_reduce_window_model( |
| reducer_name="STABLEHLO_MAXIMUM", |
| init_value=-np.inf, |
| base_dilations=None, |
| ): |
| """Build an NHWC 2D STABLEHLO_REDUCE_WINDOW model.""" |
| builder = flatbuffers.Builder(1024) |
| if base_dilations is None: |
| base_dilations = [1, 1, 1, 1] |
| |
| window_dimensions_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDimensionsVector, |
| [1, 2, 2, 1], |
| ) |
| window_strides_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowStridesVector, |
| [1, 2, 2, 1], |
| ) |
| base_dilations_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartBaseDilationsVector, |
| base_dilations, |
| ) |
| window_dilations_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartWindowDilationsVector, |
| [1, 1, 1, 1], |
| ) |
| padding_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStartPaddingVector, |
| [0, 0, 0, 0, 0, 0, 0, 0], |
| ) |
| |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsStart(builder) |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDimensions( |
| builder, window_dimensions_vec |
| ) |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowStrides( |
| builder, window_strides_vec |
| ) |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBaseDilations( |
| builder, base_dilations_vec |
| ) |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddWindowDilations( |
| builder, window_dilations_vec |
| ) |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddPadding(builder, padding_vec) |
| _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsAddBodySubgraphIndex(builder, 1) |
| reduce_window_opts = _tfl_stablehlo_reduce_window_opts.StablehloReduceWindowOptionsEnd(builder) |
| |
| reduce_window_builtin = _get_stablehlo_builtin_operator("STABLEHLO_REDUCE_WINDOW") |
| reducer_builtin = _get_stablehlo_builtin_operator(reducer_name) |
| reduce_window_code = _build_operator_code(builder, reduce_window_builtin) |
| reducer_code = _build_operator_code(builder, reducer_builtin) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [1, 4, 4, 1]), |
| _build_tensor(builder, 1, []), |
| _build_tensor(builder, 2, [1, 2, 2, 1]), |
| ] |
| reduce_window_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloReduceWindowOptions, |
| builtin_options2=reduce_window_opts, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[reduce_window_op], |
| inputs=[0], |
| outputs=[2], |
| ) |
| |
| body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(3, 6)] |
| reducer_op = _build_operator(builder, 1, [0, 1], [2]) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[reducer_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, np.array(init_value, dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[body_subgraph], |
| operator_codes=[reduce_window_code, reducer_code], |
| buffers=buffers, |
| ) |
| |
| |
| def _build_stablehlo_scatter_model(reducer_name="STABLEHLO_ADD", update_window_dims=None): |
| """Build a canonical point-update STABLEHLO_SCATTER model.""" |
| builder = flatbuffers.Builder(1024) |
| if update_window_dims is None: |
| update_window_dims = [] |
| |
| update_window_dims_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartUpdateWindowDimsVector, |
| update_window_dims, |
| ) |
| inserted_window_dims_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartInsertedWindowDimsVector, |
| [0], |
| ) |
| scatter_dims_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStartScatterDimsToOperandDimsVector, |
| [0], |
| ) |
| |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsStart(builder) |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateWindowDims( |
| builder, update_window_dims_vec |
| ) |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddInsertedWindowDims( |
| builder, inserted_window_dims_vec |
| ) |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddScatterDimsToOperandDims( |
| builder, scatter_dims_vec |
| ) |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddIndexVectorDim(builder, 1) |
| _tfl_stablehlo_scatter_opts.StablehloScatterOptionsAddUpdateComputationSubgraphIndex(builder, 1) |
| scatter_opts = _tfl_stablehlo_scatter_opts.StablehloScatterOptionsEnd(builder) |
| |
| scatter_builtin = _get_stablehlo_builtin_operator("STABLEHLO_SCATTER") |
| reducer_builtin = _get_stablehlo_builtin_operator(reducer_name) |
| scatter_code = _build_operator_code(builder, scatter_builtin) |
| reducer_code = _build_operator_code(builder, reducer_builtin) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [4]), |
| _build_tensor(builder, 1, [2, 1], tensor_type=_tfl_tensor_type.INT32), |
| _build_tensor(builder, 2, [2]), |
| _build_tensor(builder, 3, [4]), |
| ] |
| scatter_op = _build_operator( |
| builder, |
| 0, |
| [0, 1, 2], |
| [3], |
| builtin_options2_type=_tfl_builtin_options2.StablehloScatterOptions, |
| builtin_options2=scatter_opts, |
| ) |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=[scatter_op], |
| inputs=[0, 1, 2], |
| outputs=[3], |
| ) |
| |
| body_tensors = [_build_tensor(builder, buffer_idx, []) for buffer_idx in range(4, 7)] |
| reducer_op = _build_operator(builder, 1, [0, 1], [2]) |
| body_subgraph = _build_subgraph( |
| builder, |
| tensors=body_tensors, |
| operators=[reducer_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| |
| buffers = [_build_buffer(builder) for _ in range(7)] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[body_subgraph], |
| operator_codes=[scatter_code, reducer_code], |
| buffers=buffers, |
| ) |
| |
| |
| def _build_stablehlo_composite_model(with_attributes=False, use_main_input_after_composite=False): |
| """Build a STABLEHLO_COMPOSITE model that decomposes to STABLEHLO_NEGATE.""" |
| builder = flatbuffers.Builder(1024) |
| |
| name = builder.CreateString("test.negate") |
| attributes = None |
| if with_attributes: |
| _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStartCompositeAttributesVector( |
| builder, 1 |
| ) |
| builder.PrependUint8(1) |
| attributes = builder.EndVector() |
| |
| _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsStart(builder) |
| _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddName(builder, name) |
| _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddVersion(builder, 1) |
| _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddDecompositionSubgraphIndex(builder, 1) |
| if attributes is not None: |
| _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsAddCompositeAttributes( |
| builder, attributes |
| ) |
| composite_opts = _tfl_stablehlo_composite_opts.StableHLOCompositeOptionsEnd(builder) |
| |
| composite_builtin = _get_stablehlo_builtin_operator("STABLEHLO_COMPOSITE") |
| negate_builtin = _get_stablehlo_builtin_operator("STABLEHLO_NEGATE") |
| add_builtin = _get_stablehlo_builtin_operator("STABLEHLO_ADD") |
| composite_code = _build_operator_code(builder, composite_builtin) |
| negate_code = _build_operator_code(builder, negate_builtin) |
| add_code = _build_operator_code(builder, add_builtin) |
| |
| main_tensors = [ |
| _build_tensor(builder, 0, [2, 2]), |
| _build_tensor(builder, 1, [2, 2]), |
| _build_tensor(builder, 2, [2, 2]), |
| ] |
| composite_op = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options2_type=_tfl_builtin_options2.StableHLOCompositeOptions, |
| builtin_options2=composite_opts, |
| ) |
| main_ops = [composite_op] |
| main_outputs = [1] |
| if use_main_input_after_composite: |
| main_ops.append(_build_operator(builder, 2, [0, 1], [2])) |
| main_outputs = [2] |
| |
| main_subgraph = _build_subgraph( |
| builder, |
| tensors=main_tensors, |
| operators=main_ops, |
| inputs=[0], |
| outputs=main_outputs, |
| ) |
| |
| decomposition_tensors = [ |
| _build_tensor(builder, 2, [2, 2]), |
| _build_tensor(builder, 3, [2, 2]), |
| ] |
| negate_op = _build_operator(builder, 1, [0], [1]) |
| decomposition_subgraph = _build_subgraph( |
| builder, |
| tensors=decomposition_tensors, |
| operators=[negate_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| |
| buffers = [_build_buffer(builder) for _ in range(4)] |
| return _finish_tflite_model( |
| builder, |
| subgraph=main_subgraph, |
| extra_subgraphs=[decomposition_subgraph], |
| operator_codes=[composite_code, negate_code, add_code], |
| buffers=buffers, |
| ) |
| |
| |
| def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type): |
| """Build a minimal TFLite StableHLO binary model with the requested tensor type.""" |
| builder = flatbuffers.Builder(1024) |
| shape = [2, 2] |
| output_tensor_idx = 2 |
| builtin_op = _get_stablehlo_builtin_operator(builtin_name) |
| |
| tensors = [ |
| _build_tensor(builder, buffer_idx, shape, tensor_type=tensor_type) |
| for buffer_idx in range(3) |
| ] |
| stablehlo_op = _build_operator(builder, 0, [0, 1], [output_tensor_idx]) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[stablehlo_op], |
| inputs=[0, 1], |
| outputs=[output_tensor_idx], |
| ) |
| operator_codes = [_build_operator_code(builder, builtin_op)] |
| buffers = [_build_buffer(builder) for _ in range(3)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers |
| ) |
| |
| |
| @pytest.mark.parametrize( |
| "builtin_name, relax_op", |
| [ |
| ("STABLEHLO_ABS", R.abs), |
| ("STABLEHLO_COSINE", R.cos), |
| ("STABLEHLO_EXPONENTIAL", R.exp), |
| ("STABLEHLO_FLOOR", R.floor), |
| ("STABLEHLO_LOG", R.log), |
| ("STABLEHLO_LOGISTIC", R.sigmoid), |
| ("STABLEHLO_NEGATE", R.negative), |
| ("STABLEHLO_RSQRT", R.rsqrt), |
| ("STABLEHLO_TANH", R.tanh), |
| ], |
| ) |
| def test_stablehlo_unary(builtin_name, relax_op): |
| """TFLite StableHLO unary elementwise operators.""" |
| mod = _load_model_from_buffer(_build_stablehlo_model(builtin_name=builtin_name, input_count=1)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = relax_op(x) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "builtin_name, relax_op", |
| [ |
| ("STABLEHLO_ADD", R.add), |
| ("STABLEHLO_DIVIDE", R.divide), |
| ("STABLEHLO_MAXIMUM", R.maximum), |
| ("STABLEHLO_MINIMUM", R.minimum), |
| ("STABLEHLO_MULTIPLY", R.multiply), |
| ("STABLEHLO_POWER", R.power), |
| ("STABLEHLO_SUBTRACT", R.subtract), |
| ], |
| ) |
| def test_stablehlo_binary(builtin_name, relax_op): |
| """TFLite StableHLO binary elementwise operators.""" |
| mod = _load_model_from_buffer(_build_stablehlo_model(builtin_name=builtin_name, input_count=2)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_model_with_unused_subgraph(): |
| """TFLite StableHLO import ignores unused non-main subgraphs.""" |
| mod = _load_model_from_buffer(_build_stablehlo_model_with_unused_subgraph()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add(x, y) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "reducer_name, init_value, relax_op", |
| [ |
| ("STABLEHLO_ADD", 0.0, R.sum), |
| ("STABLEHLO_MAXIMUM", -np.inf, R.max), |
| ("STABLEHLO_MINIMUM", np.inf, R.min), |
| ("STABLEHLO_MULTIPLY", 1.0, R.prod), |
| ], |
| ) |
| def test_stablehlo_reduce(reducer_name, init_value, relax_op): |
| """TFLite StableHLO REDUCE with simple binary reducer body subgraphs.""" |
| mod = _load_model_from_buffer(_build_stablehlo_reduce_model(reducer_name, init_value)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2,), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2,), dtype="float32") = relax_op(x, axis=[1], keepdims=False) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_reduce_unsupported_reducer(): |
| """TFLite StableHLO REDUCE rejects unsupported body reducer ops.""" |
| buf = _build_stablehlo_reduce_model("STABLEHLO_SUBTRACT", 0.0) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="reducer"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_reduce_non_identity_init_unsupported(): |
| """TFLite StableHLO REDUCE rejects init values that Relax reductions cannot express.""" |
| buf = _build_stablehlo_reduce_model("STABLEHLO_ADD", 1.0) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="init value"): |
| from_tflite(tflite_model) |
| |
| |
| @pytest.mark.parametrize( |
| "comparison_direction, descending", |
| [ |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, |
| False, |
| ), |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, |
| True, |
| ), |
| ], |
| ) |
| def test_stablehlo_sort(comparison_direction, descending): |
| """TFLite StableHLO SORT with LT/GT scalar compare body subgraphs.""" |
| mod = _load_model_from_buffer(_build_stablehlo_sort_model(comparison_direction)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 3), dtype="float32") = R.sort(x, axis=1, descending=descending) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_sort_unsupported_comparator(): |
| """TFLite StableHLO SORT rejects non-ordering comparators.""" |
| _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection |
| buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_EQ) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="LT or GT"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_sort_stable_unsupported(): |
| """TFLite StableHLO SORT rejects stable sort until Relax exposes that contract.""" |
| _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection |
| buf = _build_stablehlo_sort_model(_DIR.STABLEHLO_COMPARISON_DIRECTION_LT, is_stable=True) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="stable sort"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_reduce_window_max_pool2d(): |
| """TFLite StableHLO REDUCE_WINDOW max reducer lowers to NHWC max_pool2d.""" |
| mod = _load_model_from_buffer(_build_stablehlo_reduce_window_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((1, 4, 4, 1), dtype="float32"), |
| ) -> R.Tensor((1, 2, 2, 1), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 2, 2, 1), dtype="float32") = R.nn.max_pool2d( |
| x, |
| pool_size=[2, 2], |
| strides=[2, 2], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| ceil_mode=False, |
| layout="NHWC", |
| out_layout="NHWC", |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_reduce_window_unsupported_reducer(): |
| """TFLite StableHLO REDUCE_WINDOW rejects non-max reducers in the pool subset.""" |
| buf = _build_stablehlo_reduce_window_model(reducer_name="STABLEHLO_ADD", init_value=0.0) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="MAXIMUM"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_reduce_window_base_dilation_unsupported(): |
| """TFLite StableHLO REDUCE_WINDOW rejects base dilation in the pool subset.""" |
| buf = _build_stablehlo_reduce_window_model(base_dilations=[1, 2, 1, 1]) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="base dilation"): |
| from_tflite(tflite_model) |
| |
| |
| @pytest.mark.parametrize( |
| "reducer_name, reduction", |
| [ |
| ("STABLEHLO_ADD", "add"), |
| ("STABLEHLO_MAXIMUM", "max"), |
| ("STABLEHLO_MINIMUM", "min"), |
| ("STABLEHLO_MULTIPLY", "mul"), |
| ], |
| ) |
| def test_stablehlo_scatter(reducer_name, reduction): |
| """TFLite StableHLO SCATTER point updates lower to Relax scatter_nd.""" |
| mod = _load_model_from_buffer(_build_stablehlo_scatter_model(reducer_name)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| operand: R.Tensor((4,), dtype="float32"), |
| indices: R.Tensor((2, 1), dtype="int32"), |
| updates: R.Tensor((2,), dtype="float32"), |
| ) -> R.Tensor((4,), dtype="float32"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| gv: R.Tensor((4,), dtype="float32") = R.scatter_nd( |
| operand, indices, updates, reduction=reduction |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_scatter_unsupported_reducer(): |
| """TFLite StableHLO SCATTER rejects unsupported update computation ops.""" |
| buf = _build_stablehlo_scatter_model(reducer_name="STABLEHLO_SUBTRACT") |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="reducer"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_scatter_update_window_unsupported(): |
| """TFLite StableHLO SCATTER rejects slice update windows in the point subset.""" |
| buf = _build_stablehlo_scatter_model(update_window_dims=[0]) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="point updates"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_composite(): |
| """TFLite StableHLO COMPOSITE inlines a simple decomposition subgraph.""" |
| mod = _load_model_from_buffer(_build_stablehlo_composite_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.negative(x) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_composite_does_not_overwrite_main_bindings(): |
| """TFLite StableHLO COMPOSITE decomposition tensor names are scoped locally.""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_composite_model(use_main_input_after_composite=True) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.negative(x) |
| gv: R.Tensor((2, 2), dtype="float32") = R.add(x, lv) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_composite_attributes_unsupported(): |
| """TFLite StableHLO COMPOSITE rejects attributes until they are parsed.""" |
| buf = _build_stablehlo_composite_model(with_attributes=True) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="composite attributes"): |
| from_tflite(tflite_model) |
| |
| |
| @pytest.mark.parametrize( |
| "builtin_name, relax_op, dtype, tensor_type", |
| [ |
| ("STABLEHLO_AND", R.logical_and, "bool", _tfl_tensor_type.BOOL), |
| ("STABLEHLO_OR", R.logical_or, "bool", _tfl_tensor_type.BOOL), |
| ("STABLEHLO_AND", R.bitwise_and, "int32", _tfl_tensor_type.INT32), |
| ("STABLEHLO_OR", R.bitwise_or, "int32", _tfl_tensor_type.INT32), |
| ("STABLEHLO_SHIFT_LEFT", R.left_shift, "int32", _tfl_tensor_type.INT32), |
| ], |
| ) |
| def test_stablehlo_typed_binary(builtin_name, relax_op, dtype, tensor_type): |
| """TFLite StableHLO binary elementwise operators with non-float dtype requirements.""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_typed_binary_model(builtin_name=builtin_name, tensor_type=tensor_type) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype=dtype), |
| y: R.Tensor((2, 2), dtype=dtype), |
| ) -> R.Tensor((2, 2), dtype=dtype): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype=dtype) = relax_op(x, y) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| @pytest.mark.parametrize( |
| "builtin_name, relax_op", |
| [ |
| ("STABLEHLO_SELECT", R.where), |
| ], |
| ) |
| def test_stablehlo_ternary(builtin_name, relax_op): |
| """TFLite StableHLO ternary elementwise operators.""" |
| builder = flatbuffers.Builder(1024) |
| shape = [2, 2] |
| builtin_op = _get_stablehlo_builtin_operator(builtin_name) |
| |
| # First input (condition) must be bool for R.where |
| tensor_0 = _build_tensor(builder, 0, shape, tensor_type=_tfl_tensor_type.BOOL) |
| tensor_1 = _build_tensor(builder, 1, shape) |
| tensor_2 = _build_tensor(builder, 2, shape) |
| tensor_out = _build_tensor(builder, 3, shape) |
| tensors = [tensor_0, tensor_1, tensor_2, tensor_out] |
| |
| stablehlo_op = _build_operator( |
| builder, |
| 0, |
| [0, 1, 2], |
| [3], |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[stablehlo_op], |
| inputs=[0, 1, 2], |
| outputs=[3], |
| ) |
| operator_codes = [_build_operator_code(builder, builtin_op)] |
| buffers = [_build_buffer(builder) for _ in range(4)] |
| |
| mod = _load_model_from_buffer( |
| _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers |
| ) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| c: R.Tensor((2, 2), dtype="bool"), |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = relax_op(c, x, y) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_stablehlo_convert_model(): |
| """STABLEHLO_CONVERT: float32 input -> int32 output.""" |
| builder = flatbuffers.Builder(1024) |
| shape = [2, 2] |
| |
| t_in = _build_tensor(builder, 0, shape, tensor_type=_tfl_tensor_type.FLOAT32) |
| t_out = _build_tensor(builder, 1, shape, tensor_type=_tfl_tensor_type.INT32) |
| tensors = [t_in, t_out] |
| |
| op_code = _build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_CONVERT")) |
| op = _build_operator(builder, 0, [0], [1]) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(2)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_convert(): |
| """TFLite StableHLO CONVERT (astype float32 -> int32).""" |
| mod = _load_model_from_buffer(_build_stablehlo_convert_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="int32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="int32") = R.astype(x, dtype="int32") |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_clamp(): |
| """TFLite StableHLO CLAMP (clip with min/operand/max order).""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_model(builtin_name="STABLEHLO_CLAMP", input_count=3) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| m: R.Tensor((2, 2), dtype="float32"), |
| x: R.Tensor((2, 2), dtype="float32"), |
| M: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.minimum(R.maximum(x, m), M) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_stablehlo_concat_model(dimension, num_inputs): |
| """STABLEHLO_CONCATENATE with given dimension and number of inputs.""" |
| builder = flatbuffers.Builder(1024) |
| shape = [2, 2] |
| |
| # Build concat options |
| _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsStart(builder) |
| _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsAddDimension(builder, dimension) |
| concat_opts = _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONCATENATE") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| if dimension == 0: |
| out_shape = [num_inputs * shape[0], shape[1]] |
| else: |
| out_shape = [shape[0], num_inputs * shape[1]] |
| tensors = [_build_tensor(builder, i, shape) for i in range(num_inputs)] + [ |
| _build_tensor(builder, num_inputs, out_shape) |
| ] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| list(range(num_inputs)), |
| [num_inputs], |
| builtin_options2_type=_tfl_builtin_options2.StablehloConcatenateOptions, |
| builtin_options2=concat_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=list(range(num_inputs)), |
| outputs=[num_inputs], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(num_inputs + 1)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| @pytest.mark.parametrize("dimension", [0, 1]) |
| def test_stablehlo_concatenate(dimension): |
| """TFLite StableHLO CONCATENATE with 2 inputs along given axis.""" |
| num_inputs = 2 |
| mod = _load_model_from_buffer( |
| _build_stablehlo_concat_model(dimension=dimension, num_inputs=num_inputs) |
| ) |
| |
| out_dim = (4, 2) if dimension == 0 else (2, 4) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor(out_dim, dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor(out_dim, dtype="float32") = R.concat((x, y), axis=dimension) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_stablehlo_broadcast_in_dim_model(input_shape, broadcast_dims, output_shape): |
| """STABLEHLO_BROADCAST_IN_DIM with given broadcast dimensions.""" |
| builder = flatbuffers.Builder(1024) |
| |
| # Build broadcast dimensions vector |
| _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsStartBroadcastDimensionsVector( |
| builder, len(broadcast_dims) |
| ) |
| for d in reversed(broadcast_dims): |
| builder.PrependInt64(d) |
| dims_vec = builder.EndVector() |
| |
| _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsStart(builder) |
| _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsAddBroadcastDimensions( |
| builder, dims_vec |
| ) |
| bcast_opts = _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_BROADCAST_IN_DIM") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_in = _build_tensor(builder, 0, input_shape) |
| t_out = _build_tensor(builder, 1, output_shape) |
| tensors = [t_in, t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options2_type=_tfl_builtin_options2.StablehloBroadcastInDimOptions, |
| builtin_options2=bcast_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(2)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_broadcast_in_dim(): |
| """TFLite StableHLO BROADCAST_IN_DIM: (3,) -> (2, 3) with dims=[1].""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_broadcast_in_dim_model( |
| input_shape=[3], broadcast_dims=[1], output_shape=[2, 3] |
| ) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(R.reshape(x, (1, 3)), (2, 3)) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_stablehlo_iota_model(iota_dimension, output_shape): |
| """STABLEHLO_IOTA with given iota dimension and output shape.""" |
| builder = flatbuffers.Builder(1024) |
| |
| _tfl_stablehlo_iota_opts.StablehloIotaOptionsStart(builder) |
| _tfl_stablehlo_iota_opts.StablehloIotaOptionsAddIotaDimension(builder, iota_dimension) |
| iota_opts = _tfl_stablehlo_iota_opts.StablehloIotaOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_IOTA") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_out = _build_tensor(builder, 0, output_shape, tensor_type=_tfl_tensor_type.INT32) |
| tensors = [t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [], |
| [0], |
| builtin_options2_type=_tfl_builtin_options2.StablehloIotaOptions, |
| builtin_options2=iota_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[], |
| outputs=[0], |
| ) |
| buffers = [_build_buffer(builder)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_iota(): |
| """TFLite StableHLO IOTA: iota_dim=1, shape=(2, 3), dtype=int32.""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_iota_model(iota_dimension=1, output_shape=[2, 3]) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main() -> R.Tensor((2, 3), dtype="int32"): |
| R.func_attr({"num_input": 0}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 3), dtype="int32") = R.broadcast_to( |
| R.reshape(R.arange(0, 3, 1, dtype="int32"), (1, 3)), (2, 3) |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_stablehlo_compare_model(direction): |
| """STABLEHLO_COMPARE with given comparison direction.""" |
| builder = flatbuffers.Builder(1024) |
| |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(builder, direction) |
| cmp_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| shape = [2, 2] |
| t_lhs = _build_tensor(builder, 0, shape) |
| t_rhs = _build_tensor(builder, 1, shape) |
| t_out = _build_tensor(builder, 2, shape, tensor_type=_tfl_tensor_type.BOOL) |
| tensors = [t_lhs, t_rhs, t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, |
| builtin_options2=cmp_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(3)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| @pytest.mark.parametrize( |
| "direction_enum, relax_op", |
| [ |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_EQ, |
| R.equal, |
| ), |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_NE, |
| R.not_equal, |
| ), |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GE, |
| R.greater_equal, |
| ), |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, |
| R.greater, |
| ), |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LE, |
| R.less_equal, |
| ), |
| ( |
| _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, |
| R.less, |
| ), |
| ], |
| ) |
| def test_stablehlo_compare(direction_enum, relax_op): |
| """TFLite StableHLO COMPARE with various comparison directions.""" |
| mod = _load_model_from_buffer(_build_stablehlo_compare_model(direction_enum)) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="bool"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="bool") = relax_op(x, y) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_compare_totalorder_unsupported(): |
| """STABLEHLO_COMPARE with TOTALORDER type raises OpNotImplemented.""" |
| builder = flatbuffers.Builder(1024) |
| |
| _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection |
| _TYPE = _tfl_stablehlo_comp_type.StablehloComparisonType |
| |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection( |
| builder, _DIR.STABLEHLO_COMPARISON_DIRECTION_EQ |
| ) |
| _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddCompareType( |
| builder, _TYPE.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER |
| ) |
| cmp_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| shape = [2, 2] |
| t_lhs = _build_tensor(builder, 0, shape) |
| t_rhs = _build_tensor(builder, 1, shape) |
| t_out = _build_tensor(builder, 2, shape, tensor_type=_tfl_tensor_type.BOOL) |
| tensors = [t_lhs, t_rhs, t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, |
| builtin_options2=cmp_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(3)] |
| buf = _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="TOTALORDER"): |
| from_tflite(tflite_model) |
| |
| |
| def _stablehlo_gather_i64_vector(builder, start_vector_fn, values): |
| start_vector_fn(builder, len(values)) |
| for value in reversed(values): |
| builder.PrependInt64(value) |
| return builder.EndVector() |
| |
| |
| def _build_stablehlo_gather_model( |
| *, |
| data_shape, |
| indices_shape, |
| output_shape, |
| offset_dims, |
| collapsed_slice_dims, |
| start_index_map, |
| index_vector_dim, |
| slice_sizes, |
| ): |
| """Build a minimal STABLEHLO_GATHER TFLite model.""" |
| builder = flatbuffers.Builder(1024) |
| |
| offset_dims_vec = _stablehlo_gather_i64_vector( |
| builder, |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartOffsetDimsVector, |
| offset_dims, |
| ) |
| collapsed_slice_dims_vec = _stablehlo_gather_i64_vector( |
| builder, |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartCollapsedSliceDimsVector, |
| collapsed_slice_dims, |
| ) |
| start_index_map_vec = _stablehlo_gather_i64_vector( |
| builder, |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartStartIndexMapVector, |
| start_index_map, |
| ) |
| slice_sizes_vec = _stablehlo_gather_i64_vector( |
| builder, |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartSliceSizesVector, |
| slice_sizes, |
| ) |
| |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsStart(builder) |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddOffsetDims(builder, offset_dims_vec) |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddCollapsedSliceDims( |
| builder, collapsed_slice_dims_vec |
| ) |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddStartIndexMap(builder, start_index_map_vec) |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddIndexVectorDim(builder, index_vector_dim) |
| _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddSliceSizes(builder, slice_sizes_vec) |
| gather_opts = _tfl_stablehlo_gather_opts.StablehloGatherOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_GATHER") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_data = _build_tensor(builder, 0, data_shape) |
| t_indices = _build_tensor(builder, 1, indices_shape, tensor_type=_tfl_tensor_type.INT32) |
| t_out = _build_tensor(builder, 2, output_shape) |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloGatherOptions, |
| builtin_options2=gather_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_data, t_indices, t_out], |
| operators=[op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(3)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| @pytest.mark.parametrize( |
| "axis, offset_dims, slice_sizes, output_shape", |
| [ |
| (0, [1], [1, 4], [2, 4]), |
| (1, [0], [3, 1], [3, 2]), |
| ], |
| ) |
| def test_stablehlo_gather_take_equivalent(axis, offset_dims, slice_sizes, output_shape): |
| """TFLite StableHLO GATHER take-equivalent subset.""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_gather_model( |
| data_shape=[3, 4], |
| indices_shape=[2, 1], |
| output_shape=output_shape, |
| offset_dims=offset_dims, |
| collapsed_slice_dims=[axis], |
| start_index_map=[axis], |
| index_vector_dim=1, |
| slice_sizes=slice_sizes, |
| ) |
| ) |
| |
| out_shape = tuple(output_shape) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((3, 4), dtype="float32"), |
| indices: R.Tensor((2, 1), dtype="int32"), |
| ) -> R.Tensor(out_shape, dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| reshaped: R.Tensor((2,), dtype="int32") = R.reshape(indices, (2,)) |
| gv: R.Tensor(out_shape, dtype="float32") = R.take( |
| data, reshaped, axis=axis, mode="fast" |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_gather_complex_unsupported(): |
| """TFLite StableHLO GATHER with multi-dimensional start_index_map is unsupported.""" |
| buf = _build_stablehlo_gather_model( |
| data_shape=[3, 4], |
| indices_shape=[2, 2], |
| output_shape=[2], |
| offset_dims=[], |
| collapsed_slice_dims=[0, 1], |
| start_index_map=[0, 1], |
| index_vector_dim=1, |
| slice_sizes=[1, 1], |
| ) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="start_index_map"): |
| from_tflite(tflite_model) |
| |
| |
| def _pad_vector(builder, start_vector_fn, values): |
| """Build a FlatBuffers int64 vector for pad options.""" |
| start_vector_fn(builder, len(values)) |
| for v in reversed(values): |
| builder.PrependInt64(v) |
| return builder.EndVector() |
| |
| |
| def _build_stablehlo_pad_model(edge_low, edge_high, interior): |
| """STABLEHLO_PAD with given padding vectors.""" |
| builder = flatbuffers.Builder(1024) |
| |
| lo_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector, |
| edge_low, |
| ) |
| hi_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector, |
| edge_high, |
| ) |
| int_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector, |
| interior, |
| ) |
| |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec) |
| pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_in = _build_tensor(builder, 0, [3, 3]) |
| # pad_value is a scalar tensor |
| t_pad_val = _build_tensor(builder, 1, []) |
| t_out = _build_tensor(builder, 2, [4, 4]) |
| tensors = [t_in, t_pad_val, t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, |
| builtin_options2=pad_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0], |
| outputs=[2], |
| ) |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| ] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_pad(): |
| """TFLite StableHLO PAD: edge_low=[1,0], edge_high=[0,1], interior=[0,0].""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_pad_model(edge_low=[1, 0], edge_high=[0, 1], interior=[0, 0]) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((3, 3), dtype="float32"), |
| ) -> R.Tensor((4, 4), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((4, 4), dtype="float32") = R.nn.pad( |
| x, pad_width=[1, 0, 0, 1], pad_value=0.0 |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_pad_interior_unsupported(): |
| """STABLEHLO_PAD with interior padding raises OpNotImplemented.""" |
| builder = flatbuffers.Builder(1024) |
| |
| lo_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector, |
| [0, 0], |
| ) |
| hi_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector, |
| [0, 0], |
| ) |
| int_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector, |
| [1, 0], |
| ) |
| |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec) |
| pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_in = _build_tensor(builder, 0, [3, 3]) |
| t_pv = _build_tensor(builder, 1, []) |
| t_out = _build_tensor(builder, 2, [3, 3]) |
| tensors = [t_in, t_pv, t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, |
| builtin_options2=pad_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0], |
| outputs=[2], |
| ) |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| ] |
| buf = _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| with pytest.raises(tvm.error.OpNotImplemented, match="interior"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_pad_negative_unsupported(): |
| """STABLEHLO_PAD with negative edge padding raises OpNotImplemented.""" |
| builder = flatbuffers.Builder(1024) |
| |
| lo_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector, |
| [-1, 0], |
| ) |
| hi_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector, |
| [0, 0], |
| ) |
| int_vec = _pad_vector( |
| builder, |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector, |
| [0, 0], |
| ) |
| |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec) |
| _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec) |
| pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_in = _build_tensor(builder, 0, [3, 3]) |
| t_pv = _build_tensor(builder, 1, []) |
| t_out = _build_tensor(builder, 2, [2, 3]) |
| tensors = [t_in, t_pv, t_out] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, |
| builtin_options2=pad_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0], |
| outputs=[2], |
| ) |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| ] |
| buf = _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| with pytest.raises(tvm.error.OpNotImplemented, match="negative"): |
| from_tflite(tflite_model) |
| |
| |
| def _build_stablehlo_dynamic_slice_model(slice_sizes, start_vals): |
| """STABLEHLO_DYNAMIC_SLICE with given slice sizes and start indices.""" |
| builder = flatbuffers.Builder(1024) |
| ndim = len(slice_sizes) |
| |
| # Build SliceSizes vector |
| _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector(builder, ndim) |
| for v in reversed(slice_sizes): |
| builder.PrependInt64(v) |
| sizes_vec = builder.EndVector() |
| |
| _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder) |
| _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes(builder, sizes_vec) |
| dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| # operand + start indices + output |
| t_in = _build_tensor(builder, 0, [3, 3]) |
| start_tensors = [] |
| start_inputs = [] |
| start_buffers = [] |
| for i, sv in enumerate(start_vals): |
| bidx = 1 + i |
| start_tensors.append(_build_tensor(builder, bidx, [], tensor_type=_tfl_tensor_type.INT32)) |
| start_inputs.append(bidx) |
| start_buffers.append(_build_buffer(builder, np.array([sv], dtype=np.int32).tobytes())) |
| out_idx = 1 + ndim |
| t_out = _build_tensor(builder, out_idx, slice_sizes) |
| tensors = [t_in, *start_tensors, t_out] |
| op_inputs = [0, *start_inputs] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| op_inputs, |
| [out_idx], |
| builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions, |
| builtin_options2=dyn_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=[0], |
| outputs=[out_idx], |
| ) |
| buffers = [_build_buffer(builder), *start_buffers, _build_buffer(builder)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes): |
| """STABLEHLO_DYNAMIC_SLICE with runtime start-index inputs.""" |
| builder = flatbuffers.Builder(1024) |
| ndim = len(slice_sizes) |
| |
| _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector(builder, ndim) |
| for v in reversed(slice_sizes): |
| builder.PrependInt64(v) |
| sizes_vec = builder.EndVector() |
| |
| _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder) |
| _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes(builder, sizes_vec) |
| dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_in = _build_tensor(builder, 0, [3, 3]) |
| start_tensors = [ |
| _build_tensor(builder, 1 + i, [], tensor_type=_tfl_tensor_type.INT32) for i in range(ndim) |
| ] |
| out_idx = 1 + ndim |
| t_out = _build_tensor(builder, out_idx, slice_sizes) |
| start_inputs = list(range(1, 1 + ndim)) |
| tensors = [t_in, *start_tensors, t_out] |
| op_inputs = [0, *start_inputs] |
| |
| op = _build_operator( |
| builder, |
| 0, |
| op_inputs, |
| [out_idx], |
| builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions, |
| builtin_options2=dyn_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=op_inputs, |
| outputs=[out_idx], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_dynamic_slice(): |
| """TFLite StableHLO DYNAMIC_SLICE: start=[0,1], sizes=[2,2] from (3,3).""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_dynamic_slice_model(slice_sizes=[2, 2], start_vals=[0, 1]) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((3, 3), dtype="float32"), |
| ) -> R.Tensor(dtype="float32", ndim=2): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor(dtype="float32", ndim=2) = R.dynamic_strided_slice( |
| x, |
| R.const([0, 1], dtype="int64"), |
| R.const([2, 3], dtype="int64"), |
| R.const([1, 1], dtype="int64"), |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_dynamic_slice_dynamic_starts_unsupported(): |
| """TFLite StableHLO DYNAMIC_SLICE with runtime starts is not supported yet.""" |
| buf = _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes=[2, 2]) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_dynamic_slice_out_of_bounds_unsupported(): |
| """TFLite StableHLO DYNAMIC_SLICE with out-of-bounds starts is not supported.""" |
| buf = _build_stablehlo_dynamic_slice_model(slice_sizes=[2, 2], start_vals=[0, 2]) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_cbrt(): |
| """TFLite StableHLO CBRT uses a sign-preserving composite expression.""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_model(builtin_name="STABLEHLO_CBRT", input_count=1) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.negative(x) |
| lv1: R.Tensor((2, 2), dtype="float32") = R.power(lv, R.const(1.0 / 3.0, "float32")) |
| lv2: R.Tensor((2, 2), dtype="bool") = R.less(x, R.const(0, "float32")) |
| lv3: R.Tensor((2, 2), dtype="float32") = R.negative(lv1) |
| lv4: R.Tensor((2, 2), dtype="float32") = R.power(x, R.const(1.0 / 3.0, "float32")) |
| gv: R.Tensor((2, 2), dtype="float32") = R.where(lv2, lv3, lv4) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_remainder(): |
| """TFLite StableHLO REMAINDER uses truncating remainder semantics.""" |
| mod = _load_model_from_buffer( |
| _build_stablehlo_model(builtin_name="STABLEHLO_REMAINDER", input_count=2) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| x: R.Tensor((2, 2), dtype="float32"), |
| y: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.divide(x, y) |
| lv1: R.Tensor((2, 2), dtype="float32") = R.trunc(lv) |
| lv2: R.Tensor((2, 2), dtype="float32") = R.multiply(y, lv1) |
| gv: R.Tensor((2, 2), dtype="float32") = R.subtract(x, lv2) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_stablehlo_dynamic_update_slice_model(start_vals, dynamic_starts=False): |
| """Build a minimal STABLEHLO_DYNAMIC_UPDATE_SLICE model.""" |
| builder = flatbuffers.Builder(1024) |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_UPDATE_SLICE") |
| op_code = _build_operator_code(builder, builtin_op) |
| |
| t_operand = _build_tensor(builder, 0, [3, 4]) |
| t_update = _build_tensor(builder, 1, [2, 2]) |
| start_tensors = [ |
| _build_tensor(builder, 2 + i, [], tensor_type=_tfl_tensor_type.INT32) |
| for i in range(len(start_vals)) |
| ] |
| out_idx = 2 + len(start_vals) |
| t_out = _build_tensor(builder, out_idx, [3, 4]) |
| tensors = [t_operand, t_update, *start_tensors, t_out] |
| |
| op_inputs = [0, 1, *range(2, out_idx)] |
| op = _build_operator(builder, 0, op_inputs, [out_idx]) |
| subgraph_inputs = op_inputs if dynamic_starts else [0, 1] |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[op], |
| inputs=subgraph_inputs, |
| outputs=[out_idx], |
| ) |
| if dynamic_starts: |
| buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] |
| else: |
| start_buffers = [ |
| _build_buffer(builder, np.array([start], dtype=np.int32).tobytes()) |
| for start in start_vals |
| ] |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| *start_buffers, |
| _build_buffer(builder), |
| ] |
| |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_dynamic_update_slice(): |
| """TFLite StableHLO DYNAMIC_UPDATE_SLICE with static starts.""" |
| mod = _load_model_from_buffer(_build_stablehlo_dynamic_update_slice_model([1, 1])) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| operand: R.Tensor((3, 4), dtype="float32"), |
| update: R.Tensor((2, 2), dtype="float32"), |
| ) -> R.Tensor((3, 4), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((3, 4), dtype="float32") = R.scatter_nd( |
| operand, |
| R.const([[[1, 1], [1, 2]], [[2, 1], [2, 2]]], dtype="int64"), |
| update, |
| reduction="update", |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_dynamic_update_slice_dynamic_starts_unsupported(): |
| """TFLite StableHLO DYNAMIC_UPDATE_SLICE with runtime starts is unsupported.""" |
| buf = _build_stablehlo_dynamic_update_slice_model([0, 0], dynamic_starts=True) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_dynamic_update_slice_out_of_bounds_unsupported(): |
| """TFLite StableHLO DYNAMIC_UPDATE_SLICE rejects out-of-bounds updates.""" |
| buf = _build_stablehlo_dynamic_update_slice_model([2, 3]) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"): |
| from_tflite(tflite_model) |
| |
| |
| def _build_stablehlo_dot_general_model(lhs_contract, rhs_contract, lhs_batch=None, rhs_batch=None): |
| """Build a minimal STABLEHLO_DOT_GENERAL model.""" |
| builder = flatbuffers.Builder(1024) |
| lhs_batch = [] if lhs_batch is None else lhs_batch |
| rhs_batch = [] if rhs_batch is None else rhs_batch |
| |
| lhs_batch_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsBatchingDimensionsVector, |
| lhs_batch, |
| ) |
| rhs_batch_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsBatchingDimensionsVector, |
| rhs_batch, |
| ) |
| lhs_contract_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartLhsContractingDimensionsVector, |
| lhs_contract, |
| ) |
| rhs_contract_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStartRhsContractingDimensionsVector, |
| rhs_contract, |
| ) |
| |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsStart(builder) |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsBatchingDimensions( |
| builder, lhs_batch_vec |
| ) |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsBatchingDimensions( |
| builder, rhs_batch_vec |
| ) |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddLhsContractingDimensions( |
| builder, lhs_contract_vec |
| ) |
| _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsAddRhsContractingDimensions( |
| builder, rhs_contract_vec |
| ) |
| dot_opts = _tfl_stablehlo_dot_opts.StablehloDotGeneralOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DOT_GENERAL") |
| op_code = _build_operator_code(builder, builtin_op) |
| t_lhs = _build_tensor(builder, 0, [2, 3]) |
| t_rhs = _build_tensor(builder, 1, [3, 4]) |
| t_out = _build_tensor(builder, 2, [2, 4]) |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloDotGeneralOptions, |
| builtin_options2=dot_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_lhs, t_rhs, t_out], |
| operators=[op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(3)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_dot_general(): |
| """TFLite StableHLO DOT_GENERAL canonical 2D matmul.""" |
| mod = _load_model_from_buffer(_build_stablehlo_dot_general_model([1], [0])) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| lhs: R.Tensor((2, 3), dtype="float32"), |
| rhs: R.Tensor((3, 4), dtype="float32"), |
| ) -> R.Tensor((2, 4), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 4), dtype="float32") = R.matmul(lhs, rhs, out_dtype="void") |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_dot_general_noncanonical_unsupported(): |
| """TFLite StableHLO DOT_GENERAL rejects non-canonical contracting dims.""" |
| buf = _build_stablehlo_dot_general_model([0], [0]) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="contracting"): |
| from_tflite(tflite_model) |
| |
| |
| def _build_stablehlo_convolution_model(feature_group_count=1, input_batch_dimension=0): |
| """Build a minimal STABLEHLO_CONVOLUTION model.""" |
| builder = flatbuffers.Builder(1024) |
| |
| window_strides_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowStridesVector, |
| [1, 1], |
| ) |
| padding_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartPaddingVector, |
| [0, 0, 0, 0], |
| ) |
| lhs_dilation_vec = _tflite_int64_vector( |
| builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartLhsDilationVector, [1, 1] |
| ) |
| rhs_dilation_vec = _tflite_int64_vector( |
| builder, _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartRhsDilationVector, [1, 1] |
| ) |
| window_reversal_vec = _tflite_bool_vector( |
| builder, |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartWindowReversalVector, |
| [False, False], |
| ) |
| input_spatial_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartInputSpatialDimensionsVector, |
| [1, 2], |
| ) |
| kernel_spatial_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartKernelSpatialDimensionsVector, |
| [0, 1], |
| ) |
| output_spatial_vec = _tflite_int64_vector( |
| builder, |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStartOutputSpatialDimensionsVector, |
| [1, 2], |
| ) |
| |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsStart(builder) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowStrides( |
| builder, window_strides_vec |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddPadding(builder, padding_vec) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddLhsDilation(builder, lhs_dilation_vec) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddRhsDilation(builder, rhs_dilation_vec) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddWindowReversal( |
| builder, window_reversal_vec |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputBatchDimension( |
| builder, input_batch_dimension |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputFeatureDimension(builder, 3) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddInputSpatialDimensions( |
| builder, input_spatial_vec |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelInputFeatureDimension(builder, 2) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelOutputFeatureDimension(builder, 3) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddKernelSpatialDimensions( |
| builder, kernel_spatial_vec |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputBatchDimension(builder, 0) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputFeatureDimension(builder, 3) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddOutputSpatialDimensions( |
| builder, output_spatial_vec |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddFeatureGroupCount( |
| builder, feature_group_count |
| ) |
| _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsAddBatchGroupCount(builder, 1) |
| conv_opts = _tfl_stablehlo_conv_opts.StablehloConvolutionOptionsEnd(builder) |
| |
| builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONVOLUTION") |
| op_code = _build_operator_code(builder, builtin_op) |
| t_data = _build_tensor(builder, 0, [1, 5, 5, 2]) |
| t_kernel = _build_tensor(builder, 1, [3, 3, 2, 4]) |
| t_out = _build_tensor(builder, 2, [1, 3, 3, 4]) |
| op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options2_type=_tfl_builtin_options2.StablehloConvolutionOptions, |
| builtin_options2=conv_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_data, t_kernel, t_out], |
| operators=[op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| buffers = [_build_buffer(builder) for _ in range(3)] |
| return _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers |
| ) |
| |
| |
| def test_stablehlo_convolution(): |
| """TFLite StableHLO CONVOLUTION canonical NHWC/HWIO 2D convolution.""" |
| mod = _load_model_from_buffer(_build_stablehlo_convolution_model()) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| data: R.Tensor((1, 5, 5, 2), dtype="float32"), |
| kernel: R.Tensor((3, 3, 2, 4), dtype="float32"), |
| ) -> R.Tensor((1, 3, 3, 4), dtype="float32"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 3, 3, 4), dtype="float32") = R.nn.conv2d( |
| data, |
| kernel, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_stablehlo_convolution_feature_group_unsupported(): |
| """TFLite StableHLO CONVOLUTION rejects grouped convolution in the first subset.""" |
| buf = _build_stablehlo_convolution_model(feature_group_count=2) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="feature_group_count"): |
| from_tflite(tflite_model) |
| |
| |
| def test_stablehlo_convolution_dimension_numbers_unsupported(): |
| """TFLite StableHLO CONVOLUTION rejects non-canonical dimension numbers.""" |
| buf = _build_stablehlo_convolution_model(input_batch_dimension=1) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="dimension numbers"): |
| from_tflite(tflite_model) |
| |
| |
| # Quantized TFLite QDQ tests |
| |
| |
| def test_tensor_quantization_parameters_are_parsed(): |
| """Tensor quantization metadata is kept without requiring quantized op support.""" |
| builder = flatbuffers.Builder(1024) |
| |
| per_tensor_quantization = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| per_axis_quantization = _build_quantization_parameters( |
| builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 |
| ) |
| per_tensor = _build_tensor( |
| builder, |
| 0, |
| [1, 4], |
| tensor_type=_tfl_tensor_type.UINT8, |
| quantization=per_tensor_quantization, |
| ) |
| per_axis = _build_tensor( |
| builder, |
| 1, |
| [1, 2, 3, 2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=per_axis_quantization, |
| ) |
| subgraph = _build_subgraph( |
| builder, tensors=[per_tensor, per_axis], operators=[], inputs=[0, 1], outputs=[0, 1] |
| ) |
| buffers = [_build_buffer(builder), _build_buffer(builder)] |
| buf = _finish_tflite_model(builder, subgraph=subgraph, operator_codes=[], buffers=buffers) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| converter = tflite_frontend.OperatorConverter( |
| tflite_model, tflite_model.Subgraphs(0), tflite_frontend.ExprTable(), None |
| ) |
| per_tensor_wrapper, per_axis_wrapper = converter.get_tensors([0, 1]) |
| |
| np.testing.assert_allclose(per_tensor_wrapper.qnn_params["scale"].data.numpy(), 0.5) |
| np.testing.assert_equal(per_tensor_wrapper.qnn_params["zero_point"].data.numpy(), 3) |
| assert per_tensor_wrapper.qnn_params["axis"] == 0 |
| |
| np.testing.assert_allclose( |
| per_axis_wrapper.qnn_params["scale"].data.numpy(), np.array([0.25, 0.75]) |
| ) |
| np.testing.assert_equal(per_axis_wrapper.qnn_params["zero_point"].data.numpy(), 0) |
| assert per_axis_wrapper.qnn_params["axis"] == 3 |
| |
| mod = from_tflite(tflite_model) |
| assert len(mod["main"].params) == 2 |
| |
| |
| def test_quantize_op_uses_relax_quantize(): |
| """TFLite QUANTIZE float32 -> int8 uses R.quantize.""" |
| builder = flatbuffers.Builder(1024) |
| |
| input_data = np.array([1.0, 2.0], dtype=np.float32) |
| output_qparams = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| |
| input_tensor = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.FLOAT32) |
| output_tensor = _build_tensor( |
| builder, |
| 1, |
| [2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=output_qparams, |
| ) |
| |
| quantize_op = _build_operator(builder, 0, [0], [1]) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, output_tensor], |
| operators=[quantize_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.QUANTIZE)] |
| input_buffer = _build_buffer(builder, input_data.tobytes()) |
| output_buffer = _build_buffer(builder) |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[input_buffer, output_buffer], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="int8"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2,), dtype="int8") = R.quantize( |
| x, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| axis=0, |
| out_dtype="int8", |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantize_op_requantize_uses_dq_q(): |
| """TFLite QUANTIZE with quantized input uses DQ→Q (requantize).""" |
| builder = flatbuffers.Builder(1024) |
| |
| input_data = np.array([10, 20], dtype=np.int8) |
| input_qparams = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[1], quantized_dimension=0 |
| ) |
| output_qparams = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| |
| input_tensor = _build_tensor( |
| builder, |
| 0, |
| [2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=input_qparams, |
| ) |
| output_tensor = _build_tensor( |
| builder, |
| 1, |
| [2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=output_qparams, |
| ) |
| |
| quantize_op = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, output_tensor], |
| operators=[quantize_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| operator_codes = [ |
| _build_operator_code(builder, _tfl_builtin_operator.QUANTIZE), |
| ] |
| input_buffer = _build_buffer(builder, input_data.tobytes()) |
| output_buffer = _build_buffer(builder) |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[input_buffer, output_buffer], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), |
| ) -> R.Tensor((2,), dtype="int8"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.25, "float32"), |
| R.const(1, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| gv: R.Tensor((2,), dtype="int8") = R.quantize( |
| lv, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_dequantize_op_uses_relax_dequantize(): |
| """TFLite DEQUANTIZE int8 -> float32 uses R.dequantize.""" |
| builder = flatbuffers.Builder(1024) |
| |
| input_data = np.array([10, 20], dtype=np.int8) |
| input_qparams = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| |
| input_tensor = _build_tensor( |
| builder, |
| 0, |
| [2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=input_qparams, |
| ) |
| output_tensor = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.FLOAT32) |
| |
| dequantize_op = _build_operator(builder, 0, [0], [1]) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, output_tensor], |
| operators=[dequantize_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEQUANTIZE)] |
| input_buffer = _build_buffer(builder, input_data.tobytes()) |
| output_buffer = _build_buffer(builder) |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[input_buffer, output_buffer], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2,), dtype="int8")) -> R.Tensor((2,), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2,), dtype="float32") = R.dequantize( |
| x, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_conv2d_per_tensor_uses_qdq(): |
| """Quantized Conv2D with per-tensor quantization uses DQ -> conv2d -> Q.""" |
| builder = flatbuffers.Builder(2048) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[0], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| input_tensor = _build_tensor( |
| builder, |
| 0, |
| [1, 4, 4, 1], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=in_q, |
| ) |
| weight_tensor = _build_tensor( |
| builder, |
| 1, |
| [2, 3, 3, 1], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=wt_q, |
| ) |
| output_tensor = _build_tensor( |
| builder, |
| 2, |
| [1, 2, 2, 2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=out_q, |
| ) |
| |
| _tfl_conv2d_options.Conv2DOptionsStart(builder) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) |
| _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) |
| conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) |
| |
| conv_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.Conv2DOptions, |
| builtin_options=conv_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, weight_tensor, output_tensor], |
| operators=[conv_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), |
| ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( |
| tvmgen_tensor_1, |
| axes=[1, 2, 3, 0], |
| ) |
| lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( |
| lv1, |
| R.const(0.25, "float32"), |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=3, |
| ) |
| lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( |
| lv, |
| lv2, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( |
| lv3, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_conv2d_per_channel_weight_uses_remapped_axis(): |
| """Quantized Conv2D remaps per-channel weight axis after OHWI -> HWIO.""" |
| builder = flatbuffers.Builder(2048) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| input_tensor = _build_tensor( |
| builder, |
| 0, |
| [1, 4, 4, 1], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=in_q, |
| ) |
| weight_tensor = _build_tensor( |
| builder, |
| 1, |
| [2, 3, 3, 1], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=wt_q, |
| ) |
| output_tensor = _build_tensor( |
| builder, |
| 2, |
| [1, 2, 2, 2], |
| tensor_type=_tfl_tensor_type.INT8, |
| quantization=out_q, |
| ) |
| |
| _tfl_conv2d_options.Conv2DOptionsStart(builder) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.VALID) |
| _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) |
| conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) |
| |
| conv_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.Conv2DOptions, |
| builtin_options=conv_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, weight_tensor, output_tensor], |
| operators=[conv_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder), _build_buffer(builder), _build_buffer(builder)], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), |
| ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( |
| tvmgen_tensor_1, |
| axes=[1, 2, 3, 0], |
| ) |
| lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( |
| lv1, |
| R.const([0.25, 0.75], "float32"), |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=3, |
| ) |
| lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( |
| lv, |
| lv2, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( |
| lv3, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_concat_uses_qdq(): |
| """Quantized CONCATENATION uses DQ each input → concat → Q.""" |
| import flatbuffers |
| import tflite.Model |
| |
| builder = flatbuffers.Builder(1024) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| |
| t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) |
| |
| _tfl_concatenation_options.ConcatenationOptionsStart(builder) |
| _tfl_concatenation_options.ConcatenationOptionsAddAxis(builder, 1) |
| _tfl_concatenation_options.ConcatenationOptionsAddFusedActivationFunction(builder, 0) |
| concat_opts = _tfl_concatenation_options.ConcatenationOptionsEnd(builder) |
| |
| concat_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.ConcatenationOptions, |
| builtin_options=concat_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t0, t1, t2], |
| operators=[concat_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 3, |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), |
| ) -> R.Tensor((1, 4), dtype="int8"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( |
| tvmgen_tensor_1, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) |
| gv: R.Tensor((1, 4), dtype="int8") = R.quantize( |
| lv2, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_concat_fused_relu_uses_quantized_clip(): |
| """Quantized CONCATENATION fused RELU clips in the quantized domain.""" |
| builder = flatbuffers.Builder(1024) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| |
| t0 = _build_tensor(builder, 0, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t1 = _build_tensor(builder, 1, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t2 = _build_tensor(builder, 2, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) |
| |
| _tfl_concatenation_options.ConcatenationOptionsStart(builder) |
| _tfl_concatenation_options.ConcatenationOptionsAddAxis(builder, 1) |
| _tfl_concatenation_options.ConcatenationOptionsAddFusedActivationFunction( |
| builder, _tfl_activation_fn.RELU |
| ) |
| concat_opts = _tfl_concatenation_options.ConcatenationOptionsEnd(builder) |
| |
| concat_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.ConcatenationOptions, |
| builtin_options=concat_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t0, t1, t2], |
| operators=[concat_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONCATENATION)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 3, |
| ) |
| |
| mod = _load_model_from_buffer(buf) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 2), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((1, 2), dtype="int8"), |
| ) -> R.Tensor((1, 4), dtype="int8"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 2), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((1, 2), dtype="float32") = R.dequantize( |
| tvmgen_tensor_1, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv2: R.Tensor((1, 4), dtype="float32") = R.concat((lv, lv1), axis=1) |
| lv3: R.Tensor((1, 4), dtype="int8") = R.quantize( |
| lv2, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| gv: R.Tensor((1, 4), dtype="int8") = R.clip(lv3, min=3.0, max=127.0) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_add_uses_qdq(): |
| """Quantized ADD uses DQ each input -> add -> Q.""" |
| builder = flatbuffers.Builder(1024) |
| |
| lhs_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| rhs_q = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[1], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=lhs_q) |
| t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=rhs_q) |
| t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) |
| |
| _tfl_add_options.AddOptionsStart(builder) |
| _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, 0) |
| add_opts = _tfl_add_options.AddOptionsEnd(builder) |
| |
| add_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.AddOptions, |
| builtin_options=add_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_lhs, t_rhs, t_out], |
| operators=[add_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 3, |
| ) |
| |
| mod = _load_model_from_buffer(buf) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2,), dtype="int8"), |
| ) -> R.Tensor((2,), dtype="int8"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_1, |
| R.const(0.25, "float32"), |
| R.const(1, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv2: R.Tensor((2,), dtype="float32") = R.add(lv, lv1) |
| gv: R.Tensor((2,), dtype="int8") = R.quantize( |
| lv2, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_add_fused_relu6_uses_float_clip_before_quantize(): |
| """Quantized ADD fused RELU6 applies the activation before quantizing.""" |
| builder = flatbuffers.Builder(1024) |
| |
| lhs_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| rhs_q = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[1], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=lhs_q) |
| t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=rhs_q) |
| t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) |
| |
| _tfl_add_options.AddOptionsStart(builder) |
| _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, _tfl_activation_fn.RELU6) |
| add_opts = _tfl_add_options.AddOptionsEnd(builder) |
| |
| add_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.AddOptions, |
| builtin_options=add_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_lhs, t_rhs, t_out], |
| operators=[add_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 3, |
| ) |
| |
| mod = _load_model_from_buffer(buf) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((2,), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2,), dtype="int8"), |
| ) -> R.Tensor((2,), dtype="int8"): |
| R.func_attr({"num_input": 2}) |
| with R.dataflow(): |
| lv: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_1, |
| R.const(0.25, "float32"), |
| R.const(1, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv2: R.Tensor((2,), dtype="float32") = R.add(lv, lv1) |
| lv3: R.Tensor((2,), dtype="float32") = R.clip(lv2, min=0, max=6) |
| gv: R.Tensor((2,), dtype="int8") = R.quantize( |
| lv3, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_add_without_output_qparams_invalid(): |
| """Quantized ADD with missing output qparams raises OpAttributeInvalid.""" |
| builder = flatbuffers.Builder(1024) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| |
| t_lhs = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t_rhs = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t_out = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT8) |
| |
| _tfl_add_options.AddOptionsStart(builder) |
| _tfl_add_options.AddOptionsAddFusedActivationFunction(builder, _tfl_activation_fn.NONE) |
| add_opts = _tfl_add_options.AddOptionsEnd(builder) |
| |
| add_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.AddOptions, |
| builtin_options=add_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_lhs, t_rhs, t_out], |
| operators=[add_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.ADD)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 3, |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpAttributeInvalid, match="output must have quantization"): |
| from_tflite(tflite_model) |
| |
| |
| def test_quantized_square_unsupported(): |
| """Quantized SQUARE is rejected instead of applying integer power directly.""" |
| builder = flatbuffers.Builder(1024) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t_out = _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) |
| |
| square_op = _build_operator(builder, 0, [0], [1]) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_out], |
| operators=[square_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.SQUARE)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 2, |
| ) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="SQUARE"): |
| _load_model_from_buffer(buf) |
| |
| |
| def test_quantized_conv2d_with_int32_bias_dequantizes_bias(): |
| """Conv2D with INT32 bias dequantizes bias with in_scale x wt_scale.""" |
| import flatbuffers |
| import tflite.Model |
| |
| builder = flatbuffers.Builder(2048) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[0], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor( |
| builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q |
| ) |
| t_wt = _build_tensor( |
| builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q |
| ) |
| t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) |
| t_ou = _build_tensor( |
| builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q |
| ) |
| |
| _tfl_conv2d_options.Conv2DOptionsStart(builder) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) |
| conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) |
| |
| conv_op = _build_operator( |
| builder, |
| 0, |
| [0, 1, 2], |
| [3], |
| builtin_options_type=_tfl_builtin_options.Conv2DOptions, |
| builtin_options=conv_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_wt, t_bi, t_ou], |
| operators=[conv_op], |
| inputs=[0, 1, 2], |
| outputs=[3], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 4, |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), |
| tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), |
| ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( |
| tvmgen_tensor_1, |
| axes=[1, 2, 3, 0], |
| ) |
| lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( |
| lv1, |
| R.const(0.25, "float32"), |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=3, |
| ) |
| lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( |
| lv, |
| lv2, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| lv4: R.Tensor((), dtype="float32") = R.multiply( |
| R.const(0.5, "float32"), |
| R.const(0.25, "float32"), |
| ) |
| lv5: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_2, |
| lv4, |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) |
| gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( |
| lv6, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_conv2d_per_channel_weight_with_int32_bias_dequantizes_bias(): |
| """Conv2D with per-channel weight quantization uses vector bias scale.""" |
| builder = flatbuffers.Builder(2048) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor( |
| builder, 0, [1, 4, 4, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q |
| ) |
| t_wt = _build_tensor( |
| builder, 1, [2, 3, 3, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q |
| ) |
| t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) |
| t_ou = _build_tensor( |
| builder, 3, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q |
| ) |
| |
| _tfl_conv2d_options.Conv2DOptionsStart(builder) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddFusedActivationFunction(builder, 0) |
| conv_opts = _tfl_conv2d_options.Conv2DOptionsEnd(builder) |
| |
| conv_op = _build_operator( |
| builder, |
| 0, |
| [0, 1, 2], |
| [3], |
| builtin_options_type=_tfl_builtin_options.Conv2DOptions, |
| builtin_options=conv_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_wt, t_bi, t_ou], |
| operators=[conv_op], |
| inputs=[0, 1, 2], |
| outputs=[3], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.CONV_2D)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 4, |
| ) |
| |
| mod = _load_model_from_buffer(buf) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 4, 4, 1), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2, 3, 3, 1), dtype="int8"), |
| tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), |
| ) -> R.Tensor((1, 2, 2, 2), dtype="int8"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((3, 3, 1, 2), dtype="int8") = R.permute_dims( |
| tvmgen_tensor_1, |
| axes=[1, 2, 3, 0], |
| ) |
| lv2: R.Tensor((3, 3, 1, 2), dtype="float32") = R.dequantize( |
| lv1, |
| R.const([0.25, 0.75], "float32"), |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=3, |
| ) |
| lv3: R.Tensor((1, 2, 2, 2), dtype="float32") = R.nn.conv2d( |
| lv, |
| lv2, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| lv4: R.Tensor((2,), dtype="float32") = R.multiply( |
| R.const(0.5, "float32"), |
| R.const([0.25, 0.75], "float32"), |
| ) |
| lv5: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_2, |
| lv4, |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv6: R.Tensor((1, 2, 2, 2), dtype="float32") = R.add(lv3, lv5) |
| gv: R.Tensor((1, 2, 2, 2), dtype="int8") = R.quantize( |
| lv6, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_per_channel_depthwise_conv_unsupported(): |
| """Per-channel quantized depthwise Conv2D raises OpNotImplemented.""" |
| import flatbuffers |
| import tflite.Model |
| |
| builder = flatbuffers.Builder(1024) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[0], quantized_dimension=0 |
| ) |
| # Per-channel weight: 2 channels, scale vector length 2 |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25, 0.75], zero_point=[0, 0], quantized_dimension=3 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor( |
| builder, 0, [1, 4, 4, 2], tensor_type=_tfl_tensor_type.INT8, quantization=in_q |
| ) |
| t_wt = _build_tensor( |
| builder, 1, [1, 3, 3, 2], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q |
| ) |
| t_ou = _build_tensor( |
| builder, 2, [1, 2, 2, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q |
| ) |
| |
| _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsStart(builder) |
| _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideH(builder, 1) |
| _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddStrideW(builder, 1) |
| _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddDepthMultiplier(builder, 1) |
| _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddPadding(builder, 1) |
| _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsAddFusedActivationFunction(builder, 0) |
| dw_opts = _tfl_depthwise_conv2d_options.DepthwiseConv2DOptionsEnd(builder) |
| |
| dw_op = _build_operator( |
| builder, |
| 0, |
| [0, 1], |
| [2], |
| builtin_options_type=_tfl_builtin_options.DepthwiseConv2DOptions, |
| builtin_options=dw_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_wt, t_ou], |
| operators=[dw_op], |
| inputs=[0, 1], |
| outputs=[2], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DEPTHWISE_CONV_2D)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 3, |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| |
| with pytest.raises(tvm.error.OpNotImplemented, match="Per-channel"): |
| from_tflite(tflite_model) |
| |
| |
| def test_uint8_reshape_requantize_uses_dq_reshape_q(): |
| """uint8 RESHAPE with different qparams uses DQ→reshape→Q.""" |
| import flatbuffers |
| import numpy as np |
| import tflite.Model |
| |
| builder = flatbuffers.Builder(1024) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[128], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[100], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.UINT8, quantization=in_q) |
| t_ou = _build_tensor(builder, 1, [2, 2], tensor_type=_tfl_tensor_type.UINT8, quantization=out_q) |
| |
| # Use ReshapeOptions with static new_shape [2, 2] |
| new_shape_np = np.array([2, 2], dtype=np.int32) |
| new_shape_vec = _tflite_int32_vector( |
| builder, _tfl_reshape_options.ReshapeOptionsStartNewShapeVector, new_shape_np |
| ) |
| _tfl_reshape_options.ReshapeOptionsStart(builder) |
| _tfl_reshape_options.ReshapeOptionsAddNewShape(builder, new_shape_vec) |
| reshape_opts = _tfl_reshape_options.ReshapeOptionsEnd(builder) |
| |
| reshape_op = _build_operator( |
| builder, |
| 0, |
| [0], |
| [1], |
| builtin_options_type=_tfl_builtin_options.ReshapeOptions, |
| builtin_options=reshape_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_ou], |
| operators=[reshape_op], |
| inputs=[0], |
| outputs=[1], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.RESHAPE)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder), _build_buffer(builder)], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 4), dtype="uint8"), |
| ) -> R.Tensor((2, 2), dtype="uint8"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(128, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((2, 2), dtype="float32") = R.reshape( |
| lv, |
| R.shape([2, 2]), |
| ) |
| gv: R.Tensor((2, 2), dtype="uint8") = R.quantize( |
| lv1, |
| R.const(1.0, "float32"), |
| R.const(100, "int32"), |
| out_dtype="uint8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_transpose_conv_with_int32_bias_dequantizes_bias(): |
| """TRANSPOSE_CONV with INT32 bias dequantizes bias before adding.""" |
| import struct |
| |
| import flatbuffers |
| import tflite.Model |
| |
| builder = flatbuffers.Builder(2048) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[0], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor( |
| builder, 0, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=in_q |
| ) |
| t_wt = _build_tensor( |
| builder, 1, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q |
| ) |
| t_bi = _build_tensor(builder, 2, [1], tensor_type=_tfl_tensor_type.INT32) |
| t_ou = _build_tensor( |
| builder, 3, [1, 1, 1, 1], tensor_type=_tfl_tensor_type.INT8, quantization=out_q |
| ) |
| oshape_data = struct.pack("<iiii", 1, 1, 1, 1) |
| t_oshape = _build_tensor(builder, 4, [4], tensor_type=_tfl_tensor_type.INT32) |
| |
| _tfl_transpose_conv_options.TransposeConvOptionsStart(builder) |
| _tfl_transpose_conv_options.TransposeConvOptionsAddStrideH(builder, 1) |
| _tfl_transpose_conv_options.TransposeConvOptionsAddStrideW(builder, 1) |
| _tfl_transpose_conv_options.TransposeConvOptionsAddPadding(builder, 1) # VALID |
| _tfl_transpose_conv_options.TransposeConvOptionsAddFusedActivationFunction(builder, 0) |
| tc_opts = _tfl_transpose_conv_options.TransposeConvOptionsEnd(builder) |
| |
| tc_op = _build_operator( |
| builder, |
| 0, |
| [4, 1, 0, 2], |
| [3], |
| builtin_options_type=_tfl_builtin_options.TransposeConvOptions, |
| builtin_options=tc_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_wt, t_bi, t_ou, t_oshape], |
| operators=[tc_op], |
| inputs=[0, 1, 2], |
| outputs=[3], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.TRANSPOSE_CONV)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder, oshape_data), |
| ], |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 1, 1, 1), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((1, 1, 1, 1), dtype="int8"), |
| tvmgen_tensor_2: R.Tensor((1,), dtype="int32"), |
| ) -> R.Tensor((1, 1, 1, 1), dtype="int8"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((1, 1, 1, 1), dtype="int8") = R.permute_dims( |
| tvmgen_tensor_1, |
| axes=[3, 0, 1, 2], |
| ) |
| lv2: R.Tensor((1, 1, 1, 1), dtype="float32") = R.dequantize( |
| lv1, |
| R.const(0.25, "float32"), |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=1, |
| ) |
| lv3: R.Tensor((1, 1, 1, 1), dtype="float32") = R.nn.conv2d_transpose( |
| lv, |
| lv2, |
| strides=[1, 1], |
| padding=[0, 0, 0, 0], |
| data_layout="NHWC", |
| kernel_layout="IOHW", |
| out_dtype="float32", |
| ) |
| lv4: R.Tensor((), dtype="float32") = R.multiply( |
| R.const(0.5, "float32"), |
| R.const(0.25, "float32"), |
| ) |
| lv5: R.Tensor((1,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_2, |
| lv4, |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv6: R.Tensor((1, 1, 1, 1), dtype="float32") = R.add(lv3, lv5) |
| gv: R.Tensor((1, 1, 1, 1), dtype="int8") = R.quantize( |
| lv6, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_quantized_fully_connected_with_int32_bias_dequantizes_bias(): |
| """Quantized FullyConnected with INT32 bias dequantizes bias with in_scale x wt_scale.""" |
| import flatbuffers |
| import tflite.Model |
| |
| builder = flatbuffers.Builder(2048) |
| |
| in_q = _build_quantization_parameters( |
| builder, scale=[0.5], zero_point=[3], quantized_dimension=0 |
| ) |
| wt_q = _build_quantization_parameters( |
| builder, scale=[0.25], zero_point=[0], quantized_dimension=0 |
| ) |
| out_q = _build_quantization_parameters( |
| builder, scale=[1.0], zero_point=[0], quantized_dimension=0 |
| ) |
| |
| t_in = _build_tensor(builder, 0, [1, 4], tensor_type=_tfl_tensor_type.INT8, quantization=in_q) |
| t_wt = _build_tensor(builder, 1, [2, 4], tensor_type=_tfl_tensor_type.INT8, quantization=wt_q) |
| t_bi = _build_tensor(builder, 2, [2], tensor_type=_tfl_tensor_type.INT32) |
| t_ou = _build_tensor(builder, 3, [1, 2], tensor_type=_tfl_tensor_type.INT8, quantization=out_q) |
| |
| _tfl_fully_connected_options.FullyConnectedOptionsStart(builder) |
| _tfl_fully_connected_options.FullyConnectedOptionsAddFusedActivationFunction(builder, 0) |
| _tfl_fully_connected_options.FullyConnectedOptionsAddWeightsFormat( |
| builder, _tfl_fc_weights_format.DEFAULT |
| ) |
| _tfl_fully_connected_options.FullyConnectedOptionsAddKeepNumDims(builder, 0) |
| fc_opts = _tfl_fully_connected_options.FullyConnectedOptionsEnd(builder) |
| |
| fc_op = _build_operator( |
| builder, |
| 0, |
| [0, 1, 2], |
| [3], |
| builtin_options_type=_tfl_builtin_options.FullyConnectedOptions, |
| builtin_options=fc_opts, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[t_in, t_wt, t_bi, t_ou], |
| operators=[fc_op], |
| inputs=[0, 1, 2], |
| outputs=[3], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.FULLY_CONNECTED)] |
| buf = _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[_build_buffer(builder)] * 4, |
| ) |
| |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((1, 4), dtype="int8"), |
| tvmgen_tensor_1: R.Tensor((2, 4), dtype="int8"), |
| tvmgen_tensor_2: R.Tensor((2,), dtype="int32"), |
| ) -> R.Tensor((1, 2), dtype="int8"): |
| R.func_attr({"num_input": 3}) |
| with R.dataflow(): |
| lv: R.Tensor((1, 4), dtype="float32") = R.dequantize( |
| tvmgen_tensor_0, |
| R.const(0.5, "float32"), |
| R.const(3, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv1: R.Tensor((4, 2), dtype="int8") = R.permute_dims( |
| tvmgen_tensor_1, |
| axes=[1, 0], |
| ) |
| lv2: R.Tensor((4, 2), dtype="float32") = R.dequantize( |
| lv1, |
| R.const(0.25, "float32"), |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=1, |
| ) |
| lv3: R.Tensor((1, 2), dtype="float32") = R.matmul(lv, lv2, out_dtype="void") |
| lv4: R.Tensor((), dtype="float32") = R.multiply( |
| R.const(0.5, "float32"), |
| R.const(0.25, "float32"), |
| ) |
| lv5: R.Tensor((2,), dtype="float32") = R.dequantize( |
| tvmgen_tensor_2, |
| lv4, |
| R.const(0, "int32"), |
| out_dtype="float32", |
| axis=0, |
| ) |
| lv6: R.Tensor((1, 2), dtype="float32") = R.add(lv3, lv5) |
| gv: R.Tensor((1, 2), dtype="int8") = R.quantize( |
| lv6, |
| R.const(1.0, "float32"), |
| R.const(0, "int32"), |
| out_dtype="int8", |
| axis=0, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_csr_sparsity( |
| builder, |
| *, |
| dense_sizes, |
| row_ptrs, |
| col_indices, |
| sparse_axis, |
| traversal_order=None, |
| ): |
| row_ptrs_vec = _tflite_int32_table(builder, row_ptrs) |
| col_indices_vec = _tflite_int32_table(builder, col_indices) |
| dim_metadata = [] |
| |
| for axis, dense_size in enumerate(dense_sizes): |
| _tfl_dimension_metadata.DimensionMetadataStart(builder) |
| if axis == sparse_axis: |
| _tfl_dimension_metadata.DimensionMetadataAddFormat( |
| builder, _tfl_dimension_type.SPARSE_CSR |
| ) |
| _tfl_dimension_metadata.DimensionMetadataAddArraySegmentsType( |
| builder, _tfl_sparse_index_vector.Int32Vector |
| ) |
| _tfl_dimension_metadata.DimensionMetadataAddArraySegments(builder, row_ptrs_vec) |
| _tfl_dimension_metadata.DimensionMetadataAddArrayIndicesType( |
| builder, _tfl_sparse_index_vector.Int32Vector |
| ) |
| _tfl_dimension_metadata.DimensionMetadataAddArrayIndices(builder, col_indices_vec) |
| else: |
| _tfl_dimension_metadata.DimensionMetadataAddFormat(builder, _tfl_dimension_type.DENSE) |
| _tfl_dimension_metadata.DimensionMetadataAddDenseSize(builder, dense_size) |
| dim_metadata.append(_tfl_dimension_metadata.DimensionMetadataEnd(builder)) |
| |
| if traversal_order is None: |
| traversal_order = list(range(len(dense_sizes))) |
| |
| traversal_order_vec = _tflite_int32_vector( |
| builder, |
| _tfl_sparsity_parameters.SparsityParametersStartTraversalOrderVector, |
| traversal_order, |
| ) |
| dim_metadata_vec = _tflite_offset_vector( |
| builder, _tfl_sparsity_parameters.SparsityParametersStartDimMetadataVector, dim_metadata |
| ) |
| |
| _tfl_sparsity_parameters.SparsityParametersStart(builder) |
| _tfl_sparsity_parameters.SparsityParametersAddTraversalOrder(builder, traversal_order_vec) |
| _tfl_sparsity_parameters.SparsityParametersAddDimMetadata(builder, dim_metadata_vec) |
| return _tfl_sparsity_parameters.SparsityParametersEnd(builder) |
| |
| |
| def _build_densify_only_case(builder): |
| sparse_tensor_idx = 0 |
| dense_tensor_idx = 1 |
| shape = [2, 2] |
| sparsity = _build_csr_sparsity( |
| builder, |
| dense_sizes=shape, |
| row_ptrs=_DENSIFY_ROW_PTRS, |
| col_indices=_DENSIFY_COL_INDICES, |
| sparse_axis=1, |
| ) |
| |
| sparse_tensor = _build_tensor(builder, 0, shape, sparsity) |
| dense_tensor = _build_tensor(builder, 1, shape) |
| densify_op = _build_operator( |
| builder, |
| 0, |
| [sparse_tensor_idx], |
| [dense_tensor_idx], |
| _tfl_builtin_options.DensifyOptions, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[sparse_tensor, dense_tensor], |
| operators=[densify_op], |
| inputs=[], |
| outputs=[dense_tensor_idx], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DENSIFY)] |
| return _DENSIFY_TEST_VALUES, subgraph, operator_codes |
| |
| |
| def _build_densify_add_case(builder): |
| input_tensor_idx = 0 |
| sparse_tensor_idx = 1 |
| dense_tensor_idx = 2 |
| output_tensor_idx = 3 |
| shape = [2, 2] |
| sparsity = _build_csr_sparsity( |
| builder, |
| dense_sizes=shape, |
| row_ptrs=_DENSIFY_ROW_PTRS, |
| col_indices=_DENSIFY_COL_INDICES, |
| sparse_axis=1, |
| ) |
| |
| input_tensor = _build_tensor(builder, 1, shape) |
| sparse_tensor = _build_tensor(builder, 0, shape, sparsity) |
| dense_tensor = _build_tensor(builder, 1, shape) |
| output_tensor = _build_tensor(builder, 1, shape) |
| |
| densify_op = _build_operator( |
| builder, |
| 1, |
| [sparse_tensor_idx], |
| [dense_tensor_idx], |
| _tfl_builtin_options.DensifyOptions, |
| ) |
| _tfl_add_options.AddOptionsStart(builder) |
| add_options = _tfl_add_options.AddOptionsEnd(builder) |
| add_op = _build_operator( |
| builder, |
| 0, |
| [input_tensor_idx, dense_tensor_idx], |
| [output_tensor_idx], |
| _tfl_builtin_options.AddOptions, |
| add_options, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, sparse_tensor, dense_tensor, output_tensor], |
| operators=[densify_op, add_op], |
| inputs=[input_tensor_idx], |
| outputs=[output_tensor_idx], |
| ) |
| operator_codes = [ |
| _build_operator_code(builder, _tfl_builtin_operator.ADD), |
| _build_operator_code(builder, _tfl_builtin_operator.DENSIFY), |
| ] |
| return _DENSIFY_TEST_VALUES, subgraph, operator_codes |
| |
| |
| def _build_densify_conv2d_case(builder): |
| input_tensor_idx = 0 |
| sparse_kernel_idx = 1 |
| dense_kernel_idx = 2 |
| output_tensor_idx = 3 |
| |
| sparsity = _build_csr_sparsity( |
| builder, |
| dense_sizes=[1, 2, 2, 1], |
| row_ptrs=_DENSIFY_ROW_PTRS, |
| col_indices=_DENSIFY_COL_INDICES, |
| sparse_axis=2, |
| ) |
| |
| input_tensor = _build_tensor(builder, 1, [1, 4, 4, 1]) |
| sparse_kernel = _build_tensor(builder, 0, [1, 2, 2, 1], sparsity) |
| dense_kernel = _build_tensor(builder, 1, [1, 2, 2, 1]) |
| output_tensor = _build_tensor(builder, 1, [1, 4, 4, 1]) |
| |
| _tfl_conv2d_options.Conv2DOptionsStart(builder) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideH(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddStrideW(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddPadding(builder, _tfl_padding.SAME) |
| _tfl_conv2d_options.Conv2DOptionsAddDilationHFactor(builder, 1) |
| _tfl_conv2d_options.Conv2DOptionsAddDilationWFactor(builder, 1) |
| conv2d_options = _tfl_conv2d_options.Conv2DOptionsEnd(builder) |
| |
| densify_op = _build_operator( |
| builder, |
| 1, |
| [sparse_kernel_idx], |
| [dense_kernel_idx], |
| _tfl_builtin_options.DensifyOptions, |
| ) |
| conv2d_op = _build_operator( |
| builder, |
| 0, |
| [input_tensor_idx, dense_kernel_idx], |
| [output_tensor_idx], |
| _tfl_builtin_options.Conv2DOptions, |
| conv2d_options, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, sparse_kernel, dense_kernel, output_tensor], |
| operators=[densify_op, conv2d_op], |
| inputs=[input_tensor_idx], |
| outputs=[output_tensor_idx], |
| ) |
| operator_codes = [ |
| _build_operator_code(builder, _tfl_builtin_operator.CONV_2D), |
| _build_operator_code(builder, _tfl_builtin_operator.DENSIFY), |
| ] |
| return _DENSIFY_TEST_VALUES, subgraph, operator_codes |
| |
| |
| def _build_densify_fully_connected_case(builder): |
| input_tensor_idx = 0 |
| sparse_weight_idx = 1 |
| dense_weight_idx = 2 |
| output_tensor_idx = 3 |
| weight_shape = [4, 4] |
| |
| sparsity = _build_csr_sparsity( |
| builder, |
| dense_sizes=weight_shape, |
| row_ptrs=_DENSIFY_FC_ROW_PTRS, |
| col_indices=_DENSIFY_FC_COL_INDICES, |
| sparse_axis=1, |
| ) |
| |
| input_tensor = _build_tensor(builder, 1, [1, 4]) |
| sparse_weight = _build_tensor(builder, 0, weight_shape, sparsity) |
| dense_weight = _build_tensor(builder, 1, weight_shape) |
| output_tensor = _build_tensor(builder, 1, [1, 4]) |
| |
| _tfl_fully_connected_options.FullyConnectedOptionsStart(builder) |
| _tfl_fully_connected_options.FullyConnectedOptionsAddWeightsFormat( |
| builder, _tfl_fc_weights_format.DEFAULT |
| ) |
| fc_options = _tfl_fully_connected_options.FullyConnectedOptionsEnd(builder) |
| |
| densify_op = _build_operator( |
| builder, |
| 1, |
| [sparse_weight_idx], |
| [dense_weight_idx], |
| _tfl_builtin_options.DensifyOptions, |
| ) |
| fc_op = _build_operator( |
| builder, |
| 0, |
| [input_tensor_idx, dense_weight_idx], |
| [output_tensor_idx], |
| _tfl_builtin_options.FullyConnectedOptions, |
| fc_options, |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, sparse_weight, dense_weight, output_tensor], |
| operators=[densify_op, fc_op], |
| inputs=[input_tensor_idx], |
| outputs=[output_tensor_idx], |
| ) |
| operator_codes = [ |
| _build_operator_code(builder, _tfl_builtin_operator.FULLY_CONNECTED), |
| _build_operator_code(builder, _tfl_builtin_operator.DENSIFY), |
| ] |
| return _DENSIFY_FC_WEIGHT_VALUES, subgraph, operator_codes |
| |
| |
| def _build_densify_model(*, downstream_op=None): |
| """Build a sparse TFLite model with DENSIFY operator for testing.""" |
| scenario_builders = { |
| None: _build_densify_only_case, |
| "add": _build_densify_add_case, |
| "conv2d": _build_densify_conv2d_case, |
| "fully_connected": _build_densify_fully_connected_case, |
| } |
| if downstream_op not in scenario_builders: |
| raise ValueError(f"Unsupported DENSIFY downstream op: {downstream_op}") |
| |
| builder = flatbuffers.Builder(4096) |
| sparse_values, subgraph, operator_codes = scenario_builders[downstream_op](builder) |
| sparse_buffer = _build_buffer(builder, sparse_values.tobytes()) |
| empty_buffer = _build_buffer(builder) |
| return _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=operator_codes, |
| buffers=[sparse_buffer, empty_buffer], |
| ) |
| |
| |
| def _load_densify_module(downstream_op=None): |
| """Load a DENSIFY test model and return the converted Relax module.""" |
| model_bytes = _build_densify_model(downstream_op=downstream_op) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(model_bytes, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| return mod |
| |
| |
| def test_densify(): |
| """Test TFLite DENSIFY operator conversion.""" |
| mod = _load_densify_module() |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main() -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 0}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.const(_DENSIFY_TEST_DENSE) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_densify_with_add(): |
| """Test DENSIFY followed by a downstream ADD operator.""" |
| mod = _load_densify_module(downstream_op="add") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((2, 2), dtype="float32") = R.add(x, R.const(_DENSIFY_TEST_DENSE)) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_densify_with_conv2d(): |
| """Test DENSIFY followed by CONV2D - a real-world scenario. |
| |
| This simulates a sparse convolution where DENSIFY converts sparse weights |
| before CONV2D uses them for inference. |
| """ |
| mod = _load_densify_module(downstream_op="conv2d") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 4, 4, 1), dtype="float32")) -> R.Tensor( |
| (1, 4, 4, 1), dtype="float32" |
| ): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| gv: R.Tensor((1, 4, 4, 1), dtype="float32") = R.nn.conv2d( |
| x, |
| R.const(_DENSIFY_CONV_KERNEL_DENSE_HWIO), |
| strides=[1, 1], |
| padding=[0, 0, 1, 1], |
| dilation=[1, 1], |
| groups=1, |
| data_layout="NHWC", |
| kernel_layout="HWIO", |
| out_layout="NHWC", |
| out_dtype="void", |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_densify_with_fully_connected(): |
| """Test DENSIFY followed by FULLY_CONNECTED - a real-world scenario. |
| |
| This simulates a sparse fully connected layer where DENSIFY converts |
| sparse weights before matrix multiplication for inference. |
| """ |
| mod = _load_densify_module(downstream_op="fully_connected") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((1, 4), dtype="float32")) -> R.Tensor((1, 4), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| weight_t: R.Tensor((4, 4), dtype="float32") = R.permute_dims( |
| R.const(_DENSIFY_FC_WEIGHT_DENSE_OI), axes=[1, 0] |
| ) |
| gv: R.Tensor((1, 4), dtype="float32") = R.matmul(x, weight_t, out_dtype="void") |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def _build_dilate_only_case( |
| builder, *, input_shape, dilations, dilation_value, dynamic_dilations=False |
| ): |
| input_tensor_idx = 0 |
| dilations_tensor_idx = 1 |
| padding_value_tensor_idx = 2 |
| output_tensor_idx = 3 |
| |
| output_shape = tuple((input_shape[i] - 1) * dilations[i] + 1 for i in range(len(input_shape))) |
| |
| input_tensor = _build_tensor(builder, 1, input_shape) |
| dilations_tensor = _build_tensor( |
| builder, 2, [len(dilations)], tensor_type=_tfl_tensor_type.INT32 |
| ) |
| padding_value_tensor = _build_tensor(builder, 3, []) |
| output_tensor = _build_tensor(builder, 4, output_shape) |
| |
| _tfl_dilate_options.DilateOptionsStart(builder) |
| dilate_opts = _tfl_dilate_options.DilateOptionsEnd(builder) |
| |
| dilate_op = _build_operator( |
| builder, |
| 0, |
| [input_tensor_idx, dilations_tensor_idx, padding_value_tensor_idx], |
| [output_tensor_idx], |
| builtin_options2_type=_tfl_builtin_options2.DilateOptions, |
| builtin_options2=dilate_opts, |
| ) |
| sg_inputs = ( |
| [input_tensor_idx, dilations_tensor_idx] if dynamic_dilations else [input_tensor_idx] |
| ) |
| subgraph = _build_subgraph( |
| builder, |
| tensors=[input_tensor, dilations_tensor, padding_value_tensor, output_tensor], |
| operators=[dilate_op], |
| inputs=sg_inputs, |
| outputs=[output_tensor_idx], |
| ) |
| operator_codes = [_build_operator_code(builder, _tfl_builtin_operator.DILATE)] |
| return subgraph, operator_codes |
| |
| |
| def test_dilate(): |
| """TFLite DILATE with constant dilations""" |
| builder = flatbuffers.Builder(1024) |
| input_shape = (3, 4) |
| dilations = [2, 2] |
| dilation_value = 0.5 |
| |
| subgraph, operator_codes = _build_dilate_only_case( |
| builder, |
| input_shape=input_shape, |
| dilations=dilations, |
| dilation_value=dilation_value, |
| ) |
| |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder, np.asarray(dilations, dtype=np.int32).tobytes()), |
| _build_buffer(builder, np.asarray([dilation_value], dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| ] |
| |
| buf = _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers |
| ) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"), |
| ) -> R.Tensor((5, 7), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((3, 1, 4), dtype="float32") = R.reshape( |
| tvmgen_tensor_0, R.shape([3, 1, 4]) |
| ) |
| lv1: R.Tensor((3, 1, 4), dtype="float32") = R.full( |
| R.shape([3, 1, 4]), R.const(0.5, "float32"), dtype="float32" |
| ) |
| lv2: R.Tensor((3, 2, 4), dtype="float32") = R.concat((lv, lv1), axis=1) |
| lv3: R.Tensor((6, 4), dtype="float32") = R.reshape(lv2, R.shape([6, 4])) |
| lv4: R.Tensor((5, 4), dtype="float32") = R.strided_slice( |
| lv3, [0, 1], [0, 0], [5, 4], [1, 1], assume_inbound=False |
| ) |
| lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape(lv4, R.shape([5, 4, 1])) |
| lv6: R.Tensor((5, 4, 1), dtype="float32") = R.full( |
| R.shape([5, 4, 1]), R.const(0.5, "float32"), dtype="float32" |
| ) |
| lv7: R.Tensor((5, 4, 2), dtype="float32") = R.concat((lv5, lv6), axis=2) |
| lv8: R.Tensor((5, 8), dtype="float32") = R.reshape(lv7, R.shape([5, 8])) |
| gv: R.Tensor((5, 7), dtype="float32") = R.strided_slice( |
| lv8, [0, 1], [0, 0], [5, 7], [1, 1], assume_inbound=False |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_dilate_dynamic_dilations(): |
| """DILATE with runtime dilations""" |
| builder = flatbuffers.Builder(1024) |
| input_shape = (3, 4) |
| dilations_for_shape = [2, 2] |
| dilation_value = 0.5 |
| |
| subgraph, operator_codes = _build_dilate_only_case( |
| builder, |
| input_shape=input_shape, |
| dilations=dilations_for_shape, |
| dilation_value=dilation_value, |
| dynamic_dilations=True, |
| ) |
| |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder), |
| _build_buffer(builder), # dilations is a runtime input so empty buffer |
| _build_buffer(builder, np.asarray([dilation_value], dtype=np.float32).tobytes()), |
| _build_buffer(builder), |
| ] |
| |
| buf = _finish_tflite_model( |
| builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers |
| ) |
| if hasattr(tflite.Model, "Model"): |
| tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) |
| else: |
| tflite_model = tflite.Model.GetRootAsModel(buf, 0) |
| mod = from_tflite(tflite_model) |
| mod["main"] = mod["main"].without_attr("params") |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main( |
| tvmgen_tensor_0: R.Tensor((3, 4), dtype="float32"), |
| tvmgen_tensor_1: R.Tensor((2,), dtype="int32"), |
| ) -> R.Tensor(dtype="float32", ndim=2): |
| R.func_attr({"num_input": 2}) |
| dilate_stride_0 = T.int64() |
| dilate_stride_1 = T.int64() |
| with R.dataflow(): |
| lv: R.Tensor((2,), dtype="int32") = R.match_cast( |
| tvmgen_tensor_1, R.Tensor((2,), dtype="int32") |
| ) |
| lv1: R.Tensor((2,), dtype="int64") = R.astype(lv, dtype="int64") |
| lv2: R.Shape(ndim=2) = R.tensor_to_shape(lv1) |
| _lv3: R.Shape([dilate_stride_0, dilate_stride_1]) = R.match_cast( |
| lv2, R.Shape([dilate_stride_0, dilate_stride_1]) |
| ) |
| lv4: R.Tensor((3, 1, 4), dtype="float32") = R.reshape( |
| tvmgen_tensor_0, R.shape([3, 1, 4]) |
| ) |
| lv5: R.Tensor((3, dilate_stride_0 - 1, 4), dtype="float32") = R.full( |
| R.shape([3, dilate_stride_0 - 1, 4]), |
| R.const(0.5, "float32"), |
| dtype="float32", |
| ) |
| lv6: R.Tensor((3, 1 + (dilate_stride_0 - 1), 4), dtype="float32") = R.concat( |
| (lv4, lv5), axis=1 |
| ) |
| lv7: R.Tensor((3 * dilate_stride_0, 4), dtype="float32") = R.reshape( |
| lv6, R.shape([3 * dilate_stride_0, 4]) |
| ) |
| lv8: R.Tensor( |
| (T.min(dilate_stride_0 * 2 + 1, dilate_stride_0 * 3), 4), |
| dtype="float32", |
| ) = R.strided_slice( |
| lv7, |
| [0, 1], |
| [0, 0], |
| [2 * dilate_stride_0 + 1, 4], |
| [1, 1], |
| assume_inbound=False, |
| ) |
| lv9: R.Tensor((2 * dilate_stride_0 + 1, 4, 1), dtype="float32") = R.reshape( |
| lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1]) |
| ) |
| lv10: R.Tensor( |
| (2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1), dtype="float32" |
| ) = R.full( |
| R.shape([2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1]), |
| R.const(0.5, "float32"), |
| dtype="float32", |
| ) |
| lv11: R.Tensor( |
| (2 * dilate_stride_0 + 1, 4, 1 + (dilate_stride_1 - 1)), |
| dtype="float32", |
| ) = R.concat((lv9, lv10), axis=2) |
| lv12: R.Tensor((2 * dilate_stride_0 + 1, 4 * dilate_stride_1), dtype="float32") = ( |
| R.reshape(lv11, R.shape([2 * dilate_stride_0 + 1, 4 * dilate_stride_1])) |
| ) |
| gv: R.Tensor( |
| ( |
| dilate_stride_0 * 2 + 1, |
| T.min(dilate_stride_1 * 3 + 1, dilate_stride_1 * 4), |
| ), |
| dtype="float32", |
| ) = R.strided_slice( |
| lv12, |
| [0, 1], |
| [0, 0], |
| [2 * dilate_stride_0 + 1, 3 * dilate_stride_1 + 1], |
| [1, 1], |
| assume_inbound=False, |
| ) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| # ── UNIDIRECTIONAL_SEQUENCE_RNN ─────────────────────────────────────────────── |
| |
| |
| def _build_unidirectional_sequence_rnn_model( |
| batch, |
| time, |
| input_size, |
| num_units, |
| weights, |
| recurrent_weights, |
| bias, |
| activation, |
| *, |
| time_major=False, |
| ): |
| """Build a minimal TFLite flatbuffer model containing one UNIDIRECTIONAL_SEQUENCE_RNN op. |
| |
| Tensor layout (indices 0-5): |
| 0 - input [batch, time, input_size] (or [time, batch, input_size] if time_major) |
| 1 - input_weights [num_units, input_size] (constant) |
| 2 - recurrent_wts [num_units, num_units] (constant) |
| 3 - bias [num_units] (constant) |
| 4 - hidden_state [batch, num_units] (variable, zero-initialised) |
| 5 - output [batch, time, num_units] |
| """ |
| builder = flatbuffers.Builder(4096) |
| |
| _tfl_sequence_rnn_options.SequenceRNNOptionsStart(builder) |
| _tfl_sequence_rnn_options.SequenceRNNOptionsAddTimeMajor(builder, time_major) |
| _tfl_sequence_rnn_options.SequenceRNNOptionsAddFusedActivationFunction(builder, activation) |
| rnn_opts = _tfl_sequence_rnn_options.SequenceRNNOptionsEnd(builder) |
| |
| rnn_op_code = _build_operator_code(builder, _tfl_builtin_operator.UNIDIRECTIONAL_SEQUENCE_RNN) |
| |
| input_shape = [time, batch, input_size] if time_major else [batch, time, input_size] |
| |
| def _t(buf_idx, shape, is_variable=False): |
| shape_vec = _tflite_shape(builder, shape) |
| _tfl_tensor.TensorStart(builder) |
| _tfl_tensor.TensorAddBuffer(builder, buf_idx) |
| _tfl_tensor.TensorAddHasRank(builder, True) |
| _tfl_tensor.TensorAddIsVariable(builder, is_variable) |
| _tfl_tensor.TensorAddShape(builder, shape_vec) |
| _tfl_tensor.TensorAddType(builder, _tfl_tensor_type.FLOAT32) |
| return _tfl_tensor.TensorEnd(builder) |
| |
| tensors = [ |
| _t(0, input_shape), |
| _t(1, [num_units, input_size]), |
| _t(2, [num_units, num_units]), |
| _t(3, [num_units]), |
| _t(4, [batch, num_units], is_variable=True), |
| _t(5, [batch, time, num_units]), |
| ] |
| |
| rnn_op = _build_operator( |
| builder, |
| 0, |
| [0, 1, 2, 3, 4], |
| [5], |
| builtin_options_type=_tfl_builtin_options.SequenceRNNOptions, |
| builtin_options=rnn_opts, |
| ) |
| |
| subgraph = _build_subgraph( |
| builder, |
| tensors=tensors, |
| operators=[rnn_op], |
| inputs=[0], |
| outputs=[5], |
| ) |
| |
| buffers = [ |
| _build_buffer(builder), |
| _build_buffer(builder, weights.tobytes()), |
| _build_buffer(builder, recurrent_weights.tobytes()), |
| _build_buffer(builder, bias.tobytes()), |
| _build_buffer(builder), |
| _build_buffer(builder), |
| ] |
| |
| return _finish_tflite_model( |
| builder, |
| subgraph=subgraph, |
| operator_codes=[rnn_op_code], |
| buffers=buffers, |
| ) |
| |
| |
| def test_unidirectional_sequence_rnn_none_activation(): |
| """UNIDIRECTIONAL_SEQUENCE_RNN with NONE activation, time=1, lowers to matmul/add/stack. |
| |
| Cell equation: h_t = x_t @ W.T + h_{t-1} @ Wr.T + b (no activation for NONE) |
| """ |
| from tflite.ActivationFunctionType import ActivationFunctionType |
| |
| batch, time, input_size, num_units = 2, 1, 2, 2 |
| weights = np.eye(num_units, input_size, dtype=np.float32) |
| recurrent_weights = np.eye(num_units, dtype=np.float32) |
| bias = np.zeros(num_units, dtype=np.float32) |
| |
| mod = _load_model_from_buffer( |
| _build_unidirectional_sequence_rnn_model( |
| batch, |
| time, |
| input_size, |
| num_units, |
| weights, |
| recurrent_weights, |
| bias, |
| ActivationFunctionType.NONE, |
| ) |
| ) |
| |
| @I.ir_module |
| class Expected: |
| @R.function |
| def main(x: R.Tensor((2, 1, 2), dtype="float32")) -> R.Tensor((2, 1, 2), dtype="float32"): |
| R.func_attr({"num_input": 1}) |
| with R.dataflow(): |
| lv: R.Tensor((2, 2), dtype="float32") = R.squeeze(x, axis=[1]) |
| lv1: R.Tensor((2, 2), dtype="float32") = R.permute_dims( |
| R.const(np.eye(2, dtype=np.float32)), axes=None |
| ) |
| lv2: R.Tensor((2, 2), dtype="float32") = R.matmul(lv, lv1, out_dtype="void") |
| lv3: R.Tensor((2, 2), dtype="float32") = R.zeros(R.shape([2, 2]), dtype="float32") |
| lv4: R.Tensor((2, 2), dtype="float32") = R.permute_dims( |
| R.const(np.eye(2, dtype=np.float32)), axes=None |
| ) |
| lv5: R.Tensor((2, 2), dtype="float32") = R.matmul(lv3, lv4, out_dtype="void") |
| lv6: R.Tensor((2, 2), dtype="float32") = R.add(lv2, lv5) |
| lv7: R.Tensor((2, 2), dtype="float32") = R.add( |
| lv6, R.const(np.zeros(2, dtype=np.float32)) |
| ) |
| gv: R.Tensor((2, 1, 2), dtype="float32") = R.stack((lv7,), axis=1) |
| R.output(gv) |
| return gv |
| |
| tvm.ir.assert_structural_equal(mod, Expected) |
| |
| |
| def test_unidirectional_sequence_rnn_relu_activation(): |
| """UNIDIRECTIONAL_SEQUENCE_RNN with RELU activation and multiple time steps.""" |
| from tflite.ActivationFunctionType import ActivationFunctionType |
| |
| batch, time, input_size, num_units = 2, 3, 4, 8 |
| np.random.seed(42) |
| weights = np.random.randn(num_units, input_size).astype(np.float32) |
| recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32) |
| bias = np.random.randn(num_units).astype(np.float32) |
| |
| mod = _load_model_from_buffer( |
| _build_unidirectional_sequence_rnn_model( |
| batch, |
| time, |
| input_size, |
| num_units, |
| weights, |
| recurrent_weights, |
| bias, |
| ActivationFunctionType.RELU, |
| ) |
| ) |
| |
| fn = mod["main"] |
| assert len(fn.params) == 1, "only the sequence input should be a graph input" |
| in_shape = fn.params[0].struct_info.shape |
| assert tuple(int(d) for d in in_shape) == (batch, time, input_size) |
| out_shape = fn.ret_struct_info.shape |
| assert tuple(int(d) for d in out_shape) == (batch, time, num_units) |
| |
| |
| def test_unidirectional_sequence_rnn_time_major(): |
| """UNIDIRECTIONAL_SEQUENCE_RNN with time_major=True transposes before unrolling.""" |
| from tflite.ActivationFunctionType import ActivationFunctionType |
| |
| batch, time, input_size, num_units = 3, 4, 2, 5 |
| np.random.seed(7) |
| weights = np.random.randn(num_units, input_size).astype(np.float32) |
| recurrent_weights = np.random.randn(num_units, num_units).astype(np.float32) |
| bias = np.zeros(num_units, dtype=np.float32) |
| |
| mod = _load_model_from_buffer( |
| _build_unidirectional_sequence_rnn_model( |
| batch, |
| time, |
| input_size, |
| num_units, |
| weights, |
| recurrent_weights, |
| bias, |
| ActivationFunctionType.NONE, |
| time_major=True, |
| ) |
| ) |
| |
| fn = mod["main"] |
| # Input to the graph is the raw time-major tensor [time, batch, input_size]. |
| in_shape = fn.params[0].struct_info.shape |
| assert tuple(int(d) for d in in_shape) == (time, batch, input_size) |
| # Output is always batch-major [batch, time, num_units]. |
| out_shape = fn.ret_struct_info.shape |
| assert tuple(int(d) for d in out_shape) == (batch, time, num_units) |
| |
| |
| if __name__ == "__main__": |
| pytest.main(["-s", __file__]) |