# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Comprehensive test cases for the Relax-to-PyFunc converter.
Tests all major features including basic operations, call_tir, call_dps_packed, and symbolic shapes.
"""
import pytest
import torch
import torch.nn.functional as F
import numpy as np
import tvm
import tvm.testing
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter
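# The tests below assume the converter workflow exercised throughout this file:
#
#     converter = RelaxToPyFuncConverter(ir_mod)            # wrap an IRModule
#     converted = converter.convert(["func_name", ...])     # convert selected Relax functions
#     out = converted.pyfuncs["func_name"](*torch_tensors)  # call the generated Python functions
#
# i.e. convert() is assumed to return an IRModule carrying a `pyfuncs` dict that maps
# function names to Python callables operating directly on torch tensors.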
@I.ir_module
class ComprehensiveTestModule:
"""Test module covering all converter features."""
@T.prim_func
def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
"""TIR function for addition."""
x = T.match_buffer(var_x, (5,), "float32")
y = T.match_buffer(var_y, (5,), "float32")
out = T.match_buffer(var_out, (5,), "float32")
for i in range(5):
out[i] = x[i] + y[i]
@T.prim_func
def mul_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
"""TIR function for multiplication."""
x = T.match_buffer(var_x, (3, 4), "float32")
y = T.match_buffer(var_y, (3, 4), "float32")
out = T.match_buffer(var_out, (3, 4), "float32")
for i in range(3):
for j in range(4):
out[i, j] = x[i, j] * y[i, j]
@R.function
def simple_add(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.add(x, y)
@R.function
def with_relu(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.nn.relu(x)
@R.function
def with_call_tir(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
cls = ComprehensiveTestModule
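        # R.call_tir invokes the TIR PrimFunc in destination-passing style: an output
        # buffer described by out_sinfo is allocated and add_tir writes its result into it.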
return R.call_tir(cls.add_tir, (x, y), out_sinfo=R.Tensor((5,), "float32"))
@R.function
def with_call_dps_packed(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
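        # R.call_dps_packed calls a globally registered packed function by name in
        # destination-passing style; "my_softmax" is the mock registered later in this
        # file by create_mock_packed_function, which simply echoes its input.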
return R.call_dps_packed(
"my_softmax", (x, R.prim_value(1)), out_sinfo=R.Tensor((5,), "float32")
)
@R.function
def complex_function(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
added = R.add(x, y)
relued = R.nn.relu(added)
cls = ComprehensiveTestModule
tir_result = R.call_tir(cls.add_tir, (relued, y), out_sinfo=R.Tensor((5,), "float32"))
return R.nn.relu(tir_result)
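    # Symbolic-shape functions: string dimensions such as "n", "batch", and "seq_len"
    # become TIR variables that are bound from the runtime shapes of the inputs.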
@R.function
def symbolic_add(
x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32")
) -> R.Tensor(("n",), "float32"):
return R.add(x, y)
@R.function
def symbolic_matmul(
x: R.Tensor(("batch", "m", "k"), "float32"), y: R.Tensor(("batch", "k", "n"), "float32")
) -> R.Tensor(("batch", "m", "n"), "float32"):
return R.matmul(x, y)
@R.function
def symbolic_expand_dims(
x: R.Tensor(("batch", "seq_len"), "float32")
) -> R.Tensor(("batch", "seq_len", 1), "float32"):
return R.expand_dims(x, axis=2)
@R.function
def multi_ops(
x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")
) -> R.Tensor((3, 4), "float32"):
added = R.add(x, y)
multiplied = R.multiply(added, y)
powered = R.power(multiplied, R.const(2.0))
maxed = R.maximum(powered, x)
return maxed
@R.function
def reduction_ops(x: R.Tensor((5,), "float32")) -> R.Tensor((), "float32"):
sum_val = R.sum(x)
mean_val = R.mean(x)
max_val = R.max(x)
return R.add(R.add(sum_val, mean_val), max_val)
@R.function
def comparison_ops(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "bool"):
eq_val = R.equal(x, y)
gt_val = R.greater(x, y)
return R.logical_and(eq_val, gt_val)
@R.function
def test_reshape(x: R.Tensor((2, 3), "float32")) -> R.Tensor((6,), "float32"):
return R.reshape(x, (6,))
@R.function
def test_permute_dims(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((4, 2, 3), "float32"):
return R.permute_dims(x, axes=[2, 0, 1])
@R.function
def test_concat(
x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
) -> R.Tensor((4, 3), "float32"):
return R.concat((x, y), axis=0)
@R.function
def test_split(x: R.Tensor((4, 3), "float32")) -> R.Tuple:
return R.split(x, indices_or_sections=2, axis=0)
@R.function
def test_stack(
x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
) -> R.Tensor((2, 2, 3), "float32"):
return R.stack((x, y), axis=1)
@R.function
def test_take(
x: R.Tensor((3, 4), "float32"), indices: R.Tensor((2,), "int64")
    ) -> R.Tensor((2, 4), "float32"):
return R.take(x, indices, axis=0)
@R.function
def test_flip(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
return R.flip(x, axis=1)
@R.function
def test_tile(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 6), "float32"):
return R.tile(x, (2, 2))
@R.function
def test_repeat(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 3), "float32"):
return R.repeat(x, repeats=2, axis=0)
@R.function
def test_expand_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3, 1), "float32"):
return R.expand_dims(x, axis=2)
@R.function
def test_squeeze(x: R.Tensor((2, 3, 1), "float32")) -> R.Tensor((2, 3), "float32"):
return R.squeeze(x, axis=2)
@R.function
def test_sum_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
return R.sum(x, axis=0)
@R.function
def test_max_with_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
return R.max(x, axis=0)
def create_mock_packed_function():
"""Create a mock packed function for testing."""
def mock_softmax(x, axis):
"""Mock softmax function that just returns the input."""
return x
# Register the function globally
tvm.register_global_func("my_softmax", mock_softmax)
class TestRelaxToPyFuncConverter:
"""Comprehensive test class for Relax to PyFunc converter."""
@classmethod
def setup_class(cls):
"""Set up test fixtures."""
cls.ir_mod = ComprehensiveTestModule
cls.converter = RelaxToPyFuncConverter(cls.ir_mod)
create_mock_packed_function()
def test_basic_operations(self):
"""Test basic arithmetic operations."""
converted_ir_mod = self.converter.convert(["simple_add", "with_relu"])
# Test simple_add
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["simple_add"](x, y)
expected = torch.add(x, y)
assert torch.allclose(result, expected)
# Test with_relu
x_neg = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["with_relu"](x_neg)
expected = torch.nn.functional.relu(x_neg)
assert torch.allclose(result, expected)
def test_call_tir(self):
"""Test call_tir functionality with DLPack conversion."""
converted_ir_mod = self.converter.convert(["with_call_tir"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["with_call_tir"](x, y)
expected = torch.add(x, y)
assert torch.allclose(result, expected)
assert result.shape == expected.shape
def test_call_dps_packed(self):
"""Test call_dps_packed functionality."""
converted_ir_mod = self.converter.convert(["with_call_dps_packed"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["with_call_dps_packed"](x)
expected = x
assert torch.allclose(result, expected)
def test_complex_function(self):
"""Test complex function with multiple operations."""
converted_ir_mod = self.converter.convert(["complex_function"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["complex_function"](x, y)
# Expected: relu(add(relu(add(x, y)), y))
step1 = torch.add(x, y)
step2 = torch.nn.functional.relu(step1)
step3 = torch.add(step2, y) # TIR call
expected = torch.nn.functional.relu(step3)
assert torch.allclose(result, expected)
def test_symbolic_shapes(self):
"""Test symbolic shape handling."""
converted_ir_mod = self.converter.convert(
["symbolic_add", "symbolic_matmul", "symbolic_expand_dims"]
)
# Test symbolic_add
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["symbolic_add"](x, y)
expected = torch.add(x, y)
assert torch.allclose(result, expected)
# Test symbolic_matmul
x = torch.randn(2, 3, 4, dtype=torch.float32) # (batch=2, m=3, k=4)
y = torch.randn(2, 4, 5, dtype=torch.float32) # (batch=2, k=4, n=5)
result = converted_ir_mod.pyfuncs["symbolic_matmul"](x, y)
expected = torch.matmul(x, y)
assert torch.allclose(result, expected)
assert result.shape == (2, 3, 5)
# Test symbolic_expand_dims
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["symbolic_expand_dims"](x)
expected = torch.unsqueeze(x, dim=2)
assert torch.allclose(result, expected)
assert result.shape == (2, 2, 1)
def test_multi_operations(self):
"""Test multiple operations in sequence."""
converted_ir_mod = self.converter.convert(["multi_ops"])
x = torch.tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
dtype=torch.float32,
)
y = torch.tensor(
[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], dtype=torch.float32
)
result = converted_ir_mod.pyfuncs["multi_ops"](x, y)
# Expected: maximum(power(multiply(add(x, y), y), 2), x)
step1 = torch.add(x, y)
step2 = torch.mul(step1, y)
step3 = torch.pow(step2, 2.0)
expected = torch.maximum(step3, x)
assert torch.allclose(result, expected)
def test_reduction_operations(self):
"""Test reduction operations."""
converted_ir_mod = self.converter.convert(["reduction_ops"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["reduction_ops"](x)
# Expected: sum(x) + mean(x) + max(x)
expected = torch.sum(x) + torch.mean(x) + torch.max(x)
assert torch.allclose(result, expected)
assert result.shape == ()
def test_comparison_operations(self):
"""Test comparison operations."""
converted_ir_mod = self.converter.convert(["comparison_ops"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
y = torch.tensor([1.0, 2.5, 3.0, 4.5, 5.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["comparison_ops"](x, y)
# Expected: logical_and(equal(x, y), greater(x, y))
eq_val = torch.eq(x, y)
gt_val = torch.gt(x, y)
expected = torch.logical_and(eq_val, gt_val)
        assert torch.equal(result, expected)
assert result.dtype == torch.bool
def test_operator_mapping_completeness(self):
"""Test that operator mapping is comprehensive."""
operator_map = RelaxToPyFuncConverter._get_op_map()
# Check that we have a good number of operators
assert len(operator_map) > 100, f"Expected >100 operators, got {len(operator_map)}"
# Check key operator categories
binary_ops = [
op
for op in operator_map.keys()
if op.startswith("relax.") and not op.startswith("relax.nn.")
]
nn_ops = [op for op in operator_map.keys() if op.startswith("relax.nn.")]
assert len(binary_ops) > 20, f"Expected >20 binary ops, got {len(binary_ops)}"
assert len(nn_ops) > 30, f"Expected >30 nn ops, got {len(nn_ops)}"
# Check specific important operators
important_ops = [
"relax.add",
"relax.multiply",
"relax.nn.relu",
"relax.nn.softmax",
"relax.matmul",
"relax.reshape",
"relax.sum",
"relax.mean",
]
for op in important_ops:
assert op in operator_map, f"Missing important operator: {op}"
def test_error_handling(self):
"""Test error handling for invalid inputs."""
converted_ir_mod = self.converter.convert(["simple_add"])
# Test with wrong number of arguments
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
with pytest.raises(ValueError, match="Expected 2 arguments"):
converted_ir_mod.pyfuncs["simple_add"](x) # Missing second argument
        # Test with incompatible shapes: broadcasting fails, so PyTorch raises a RuntimeError
        x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        y = torch.tensor([1.0, 2.0], dtype=torch.float32)  # Different shape
with pytest.raises(RuntimeError, match="The size of tensor a"):
converted_ir_mod.pyfuncs["simple_add"](x, y)
def test_conversion_metadata(self):
"""Test that conversion preserves metadata correctly."""
converted_ir_mod = self.converter.convert(["simple_add"])
# Check that pyfuncs attribute exists
assert hasattr(converted_ir_mod, "pyfuncs")
assert "simple_add" in converted_ir_mod.pyfuncs
# Check function metadata
pyfunc = converted_ir_mod.pyfuncs["simple_add"]
assert hasattr(pyfunc, "__name__")
assert hasattr(pyfunc, "__doc__")
assert pyfunc.__name__ == "simple_add"
def test_tensor_operations(self):
"""Test tensor manipulation operations."""
converted_ir_mod = self.converter.convert(
[
"test_reshape",
"test_permute_dims",
"test_concat",
"test_split",
"test_stack",
"test_take",
"test_flip",
"test_tile",
"test_repeat",
"test_expand_dims",
"test_squeeze",
"test_sum_with_axis",
"test_max_with_axis",
]
)
# Test reshape
x1 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result1 = converted_ir_mod.pyfuncs["test_reshape"](x1)
expected1 = torch.reshape(x1, (6,))
assert torch.allclose(result1, expected1), "Reshape operation failed"
# Test permute_dims
x2 = torch.randn(2, 3, 4)
result2 = converted_ir_mod.pyfuncs["test_permute_dims"](x2)
expected2 = torch.permute(x2, (2, 0, 1))
assert torch.allclose(result2, expected2), "Permute_dims operation failed"
# Test concat
x3 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
y3 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32)
result3 = converted_ir_mod.pyfuncs["test_concat"](x3, y3)
expected3 = torch.cat([x3, y3], dim=0)
assert torch.allclose(result3, expected3), "Concat operation failed"
# Test split
x4 = torch.tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]],
dtype=torch.float32,
)
result4 = converted_ir_mod.pyfuncs["test_split"](x4)
expected4 = torch.split(x4, 2, dim=0)
assert len(result4) == len(expected4), "Split operation failed - wrong number of tensors"
for r, e in zip(result4, expected4):
assert torch.allclose(r, e), "Split operation failed - tensor mismatch"
# Test stack
x5 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
y5 = torch.tensor([[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]], dtype=torch.float32)
result5 = converted_ir_mod.pyfuncs["test_stack"](x5, y5)
expected5 = torch.stack([x5, y5], dim=1)
assert torch.allclose(result5, expected5), "Stack operation failed"
# Test take
x6 = torch.tensor(
[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]],
dtype=torch.float32,
)
indices = torch.tensor([0, 2], dtype=torch.int64)
result6 = converted_ir_mod.pyfuncs["test_take"](x6, indices)
expected6 = x6[indices]
assert torch.allclose(result6, expected6), "Take operation failed"
# Test flip
x7 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result7 = converted_ir_mod.pyfuncs["test_flip"](x7)
expected7 = torch.flip(x7, dims=[1])
assert torch.allclose(result7, expected7), "Flip operation failed"
# Test tile
x8 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result8 = converted_ir_mod.pyfuncs["test_tile"](x8)
expected8 = torch.tile(x8, (2, 2))
assert torch.allclose(result8, expected8), "Tile operation failed"
# Test repeat
x9 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result9 = converted_ir_mod.pyfuncs["test_repeat"](x9)
expected9 = torch.repeat_interleave(x9, repeats=2, dim=0)
assert torch.allclose(result9, expected9), "Repeat operation failed"
# Test expand_dims
x10 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result10 = converted_ir_mod.pyfuncs["test_expand_dims"](x10)
expected10 = torch.unsqueeze(x10, dim=2)
assert torch.allclose(result10, expected10), "Expand_dims operation failed"
# Test squeeze
x11 = torch.tensor([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], dtype=torch.float32)
result11 = converted_ir_mod.pyfuncs["test_squeeze"](x11)
expected11 = torch.squeeze(x11, dim=2)
assert torch.allclose(result11, expected11), "Squeeze operation failed"
# Test sum with axis
x12 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result12 = converted_ir_mod.pyfuncs["test_sum_with_axis"](x12)
expected12 = torch.sum(x12, dim=0)
assert torch.allclose(result12, expected12), "Sum with axis operation failed"
# Test max with axis
x13 = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
result13 = converted_ir_mod.pyfuncs["test_max_with_axis"](x13)
expected13 = torch.max(x13, dim=0)[0] # torch.max returns (values, indices)
assert torch.allclose(result13, expected13), "Max with axis operation failed"
@I.ir_module
class ExtendedOperatorsModule:
"""Extended test module with additional operators not covered in ComprehensiveTestModule."""
# Unary operations not covered in ComprehensiveTestModule
@R.function
def test_abs(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.abs(x)
@R.function
def test_neg(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.negative(x)
@R.function
def test_exp(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.exp(x)
@R.function
def test_log(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.log(x)
@R.function
def test_sqrt(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.sqrt(x)
@R.function
def test_sin(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.sin(x)
@R.function
def test_cos(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.cos(x)
@R.function
def test_tanh(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.tanh(x)
@R.function
def test_sigmoid(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.sigmoid(x)
# Comparison operations not covered in ComprehensiveTestModule
@R.function
def test_less(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "bool"):
return R.less(x, y)
@R.function
def test_not_equal(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "bool"):
return R.not_equal(x, y)
# Binary operations not covered in ComprehensiveTestModule
@R.function
def test_multiply(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.multiply(x, y)
@R.function
def test_divide(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.divide(x, y)
@R.function
def test_power(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.power(x, y)
@R.function
def test_maximum(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.maximum(x, y)
@R.function
def test_minimum(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.minimum(x, y)
@R.function
def test_subtract(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.subtract(x, y)
# Additional tensor operations with different parameters
@R.function
def test_transpose_2d(x: R.Tensor((2, 4), "float32")) -> R.Tensor((4, 2), "float32"):
return R.permute_dims(x, axes=[1, 0])
@R.function
def test_mean_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
return R.mean(x, axis=0)
@R.function
def test_min_axis(x: R.Tensor((2, 3), "float32")) -> R.Tensor((3,), "float32"):
return R.min(x, axis=0)
# Neural network operations not covered in ComprehensiveTestModule
@R.function
def test_gelu_nn(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
return R.nn.gelu(x)
@R.function
def test_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"):
return R.nn.softmax(x, axis=1)
@R.function
def test_log_softmax_nn(x: R.Tensor((2, 5), "float32")) -> R.Tensor((2, 5), "float32"):
return R.nn.log_softmax(x, axis=1)
# Advanced tensor operations with different parameters
@R.function
def test_tile_dims(x: R.Tensor((2, 3), "float32")) -> R.Tensor((4, 9), "float32"):
return R.tile(x, (2, 3))
@R.function
def test_repeat_axis(x: R.Tensor((3,), "float32")) -> R.Tensor((6,), "float32"):
return R.repeat(x, repeats=2, axis=0)
class TestExtendedOperators:
"""Test class for extended operator coverage."""
@classmethod
def setup_class(cls):
"""Set up test fixtures."""
cls.ir_mod = ExtendedOperatorsModule
cls.converter = RelaxToPyFuncConverter(cls.ir_mod)
def test_unary_operations(self):
"""Test unary operations."""
converted_ir_mod = self.converter.convert(
["test_abs", "test_neg", "test_exp", "test_log", "test_sqrt"]
)
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=torch.float32)
# Test abs
result = converted_ir_mod.pyfuncs["test_abs"](x)
expected = torch.abs(x)
assert torch.allclose(result, expected)
# Test negative
result = converted_ir_mod.pyfuncs["test_neg"](x)
expected = torch.neg(x)
assert torch.allclose(result, expected)
# Test exp
result = converted_ir_mod.pyfuncs["test_exp"](x)
expected = torch.exp(x)
assert torch.allclose(result, expected)
# Test log (with positive values)
x_pos = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_log"](x_pos)
expected = torch.log(x_pos)
assert torch.allclose(result, expected)
# Test sqrt
result = converted_ir_mod.pyfuncs["test_sqrt"](x_pos)
expected = torch.sqrt(x_pos)
assert torch.allclose(result, expected)
def test_trigonometric_operations(self):
"""Test trigonometric operations."""
converted_ir_mod = self.converter.convert(
["test_sin", "test_cos", "test_tanh", "test_sigmoid"]
)
x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0], dtype=torch.float32)
# Test sin
result = converted_ir_mod.pyfuncs["test_sin"](x)
expected = torch.sin(x)
assert torch.allclose(result, expected)
# Test cos
result = converted_ir_mod.pyfuncs["test_cos"](x)
expected = torch.cos(x)
assert torch.allclose(result, expected)
# Test tanh
result = converted_ir_mod.pyfuncs["test_tanh"](x)
expected = torch.tanh(x)
assert torch.allclose(result, expected)
# Test sigmoid
result = converted_ir_mod.pyfuncs["test_sigmoid"](x)
expected = torch.sigmoid(x)
assert torch.allclose(result, expected)
def test_comparison_operations(self):
"""Test comparison operations not covered in ComprehensiveTestModule."""
converted_ir_mod = self.converter.convert(["test_less", "test_not_equal"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32)
# Test less
result = converted_ir_mod.pyfuncs["test_less"](x, y)
expected = torch.lt(x, y)
assert torch.equal(result, expected)
# Test not equal
result = converted_ir_mod.pyfuncs["test_not_equal"](x, y)
expected = torch.ne(x, y)
assert torch.equal(result, expected)
def test_binary_operations(self):
"""Test binary operations."""
converted_ir_mod = self.converter.convert(
[
"test_multiply",
"test_divide",
"test_power",
"test_maximum",
"test_minimum",
"test_subtract",
]
)
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
y = torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0], dtype=torch.float32)
# Test multiply
result = converted_ir_mod.pyfuncs["test_multiply"](x, y)
expected = torch.mul(x, y)
assert torch.allclose(result, expected)
# Test divide
result = converted_ir_mod.pyfuncs["test_divide"](x, y)
expected = torch.div(x, y)
assert torch.allclose(result, expected)
# Test power
result = converted_ir_mod.pyfuncs["test_power"](x, y)
expected = torch.pow(x, y)
assert torch.allclose(result, expected)
# Test maximum
result = converted_ir_mod.pyfuncs["test_maximum"](x, y)
expected = torch.maximum(x, y)
assert torch.allclose(result, expected)
# Test minimum
result = converted_ir_mod.pyfuncs["test_minimum"](x, y)
expected = torch.minimum(x, y)
assert torch.allclose(result, expected)
# Test subtract
result = converted_ir_mod.pyfuncs["test_subtract"](x, y)
expected = torch.sub(x, y)
assert torch.allclose(result, expected)
def test_tensor_operations(self):
"""Test tensor operations not covered in ComprehensiveTestModule."""
converted_ir_mod = self.converter.convert(["test_transpose_2d"])
x = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)
# Test transpose
result = converted_ir_mod.pyfuncs["test_transpose_2d"](x)
expected = torch.transpose(x, 0, 1)
assert torch.allclose(result, expected)
assert result.shape == (4, 2)
def test_reduction_operations(self):
"""Test reduction operations not covered in ComprehensiveTestModule."""
converted_ir_mod = self.converter.convert(["test_mean_axis", "test_min_axis"])
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
# Test mean
result = converted_ir_mod.pyfuncs["test_mean_axis"](x)
expected = torch.mean(x, dim=0)
assert torch.allclose(result, expected)
assert result.shape == (3,)
# Test min
result = converted_ir_mod.pyfuncs["test_min_axis"](x)
expected = torch.min(x, dim=0)[0]
assert torch.allclose(result, expected)
assert result.shape == (3,)
def test_neural_network_operations(self):
"""Test neural network operations not covered in ComprehensiveTestModule."""
converted_ir_mod = self.converter.convert(
["test_gelu_nn", "test_softmax_nn", "test_log_softmax_nn"]
)
x = torch.tensor(
[[-2.0, -1.0, 0.0, 1.0, 2.0], [0.5, 1.5, 2.5, 3.5, 4.5]], dtype=torch.float32
)
# Test gelu
result = converted_ir_mod.pyfuncs["test_gelu_nn"](x[0])
expected = F.gelu(x[0])
assert torch.allclose(result, expected)
# Test softmax
result = converted_ir_mod.pyfuncs["test_softmax_nn"](x)
expected = F.softmax(x, dim=1)
assert torch.allclose(result, expected)
# Test log_softmax
result = converted_ir_mod.pyfuncs["test_log_softmax_nn"](x)
expected = F.log_softmax(x, dim=1)
assert torch.allclose(result, expected)
def test_advanced_tensor_operations(self):
"""Test advanced tensor operations with different parameters."""
converted_ir_mod = self.converter.convert(["test_tile_dims", "test_repeat_axis"])
        # Test tile with different dimensions (input shape (2, 3) matches the Relax signature)
        x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float32)
        result = converted_ir_mod.pyfuncs["test_tile_dims"](x)
        expected = torch.tile(x, (2, 3))
        assert torch.allclose(result, expected)
        assert result.shape == (4, 9)
# Test repeat with different parameters
x_1d = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_repeat_axis"](x_1d)
expected = torch.repeat_interleave(x_1d, repeats=2, dim=0)
assert torch.allclose(result, expected)
assert result.shape == (6,)
class TestDLPackAndTupleSupport:
"""Test DLPack conversion, tuple handling, and API compatibility features."""
def test_dlpack_conversion_fallback(self):
"""Test DLPack conversion with numpy fallback."""
@I.ir_module
class DLPackTestModule:
@T.prim_func
def test_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
x = T.match_buffer(var_x, (4,), "float32")
y = T.match_buffer(var_y, (4,), "float32")
out = T.match_buffer(var_out, (4,), "float32")
for i in range(4):
out[i] = x[i] + y[i]
@R.function
def test_func(
x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")
) -> R.Tensor((4,), "float32"):
return R.call_tir(
DLPackTestModule.test_tir, (x, y), out_sinfo=R.Tensor((4,), "float32")
)
converter = RelaxToPyFuncConverter(DLPackTestModule)
converted_ir_mod = converter.convert(["test_func"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_func"](x, y)
expected = torch.add(x, y)
assert torch.allclose(result, expected), "DLPack conversion with numpy fallback failed"
def test_tuple_return_handling(self):
"""Test proper handling of tuple returns (e.g., split operation)."""
@I.ir_module
class TupleTestModule:
@R.function
def test_split(x: R.Tensor((6,), "float32")) -> R.Tuple:
return R.split(x, indices_or_sections=3, axis=0)
converter = RelaxToPyFuncConverter(TupleTestModule)
converted_ir_mod = converter.convert(["test_split"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_split"](x)
expected = torch.split(x, 2, dim=0)
assert isinstance(result, tuple), "Split should return tuple"
assert len(result) == len(expected), "Split should return correct number of tensors"
for r, e in zip(result, expected):
assert torch.allclose(r, e), "Split tensor values should match"
def test_tvm_runtime_api_compatibility(self):
"""Test compatibility with tvm.runtime API instead of deprecated tvm.nd."""
@I.ir_module
class RuntimeAPITestModule:
@T.prim_func
def test_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
x = T.match_buffer(var_x, (3,), "float32")
y = T.match_buffer(var_y, (3,), "float32")
out = T.match_buffer(var_out, (3,), "float32")
for i in range(3):
out[i] = x[i] * y[i]
@R.function
def test_func(
x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")
) -> R.Tensor((3,), "float32"):
return R.call_tir(
RuntimeAPITestModule.test_tir, (x, y), out_sinfo=R.Tensor((3,), "float32")
)
converter = RelaxToPyFuncConverter(RuntimeAPITestModule)
converted_ir_mod = converter.convert(["test_func"])
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
y = torch.tensor([2.0, 3.0, 4.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_func"](x, y)
expected = torch.mul(x, y)
assert torch.allclose(result, expected)
def test_packed_function_with_primvalue_args(self):
"""Test packed function calls with PrimValue arguments."""
# Register a test packed function
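        # call_dps_packed looks the function up by its global name at call time, so it
        # must be registered before the converted function is invoked.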
def test_packed_func(x, axis):
return x # Simple identity function
tvm.register_global_func("test_packed_func", test_packed_func)
@I.ir_module
class PackedFuncTestModule:
@R.function
def test_dps(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
return R.call_dps_packed(
"test_packed_func", (x, R.const(0)), out_sinfo=R.Tensor((4,), "float32")
)
converter = RelaxToPyFuncConverter(PackedFuncTestModule)
converted_ir_mod = converter.convert(["test_dps"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_dps"](x)
expected = x # Identity function
assert torch.allclose(result, expected), "Packed function with PrimValue args failed"
def test_mixed_tir_and_relax_operations(self):
"""Test mixed TIR and Relax operations in a single function."""
@I.ir_module
class MixedOpsTestModule:
@T.prim_func
def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
x = T.match_buffer(var_x, (4,), "float32")
y = T.match_buffer(var_y, (4,), "float32")
out = T.match_buffer(var_out, (4,), "float32")
for i in range(4):
out[i] = x[i] + y[i]
@R.function
def test_mixed(
x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")
) -> R.Tensor((4,), "float32"):
# TIR operation
tir_result = R.call_tir(
MixedOpsTestModule.add_tir, (x, y), out_sinfo=R.Tensor((4,), "float32")
)
# Relax operations
relued = R.nn.relu(tir_result)
powered = R.power(relued, R.const(2.0))
return R.nn.gelu(powered)
converter = RelaxToPyFuncConverter(MixedOpsTestModule)
converted_ir_mod = converter.convert(["test_mixed"])
x = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_mixed"](x, y)
# Manual computation for expected result
added = torch.add(x, y)
relued = F.relu(added)
powered = torch.pow(relued, 2.0)
expected = F.gelu(powered)
assert torch.allclose(result, expected)
def test_error_handling_improvements(self):
"""Test improved error handling with tensor fallbacks."""
@I.ir_module
class ErrorHandlingTestModule:
@R.function
def test_error_handling(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
# This should trigger fallback mechanisms
return R.nn.relu(x)
converter = RelaxToPyFuncConverter(ErrorHandlingTestModule)
converted_ir_mod = converter.convert(["test_error_handling"])
x = torch.tensor([-2.0, -1.0, 0.0, 1.0], dtype=torch.float32)
result = converted_ir_mod.pyfuncs["test_error_handling"](x)
expected = F.relu(x)
assert torch.allclose(result, expected), "Error handling with tensor fallbacks failed"
assert isinstance(result, torch.Tensor), "Result should be a tensor, not a string"
if __name__ == "__main__":
tvm.testing.main()