# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Test TVMScript @I.pyfunc decorator functionality.

This test verifies:
1. @I.pyfunc decorator works correctly
2. Python functions are properly integrated into IRModule
3. BasePyModule inheritance is handled correctly
4. ExternFunc nodes are created for Python functions
"""

import inspect

import pytest
import torch
import tvm
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T
from tvm.relax import BasePyModule


@I.ir_module
class TestPyFuncModule(BasePyModule):
"""Test module with Python functions using @I.pyfunc decorator."""

    @I.pyfunc
def pytorch_processor(x: torch.Tensor) -> torch.Tensor:
"""Python function that processes PyTorch tensors."""
return torch.nn.functional.relu(x) * 2.0

    @I.pyfunc
def pytorch_adder(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Python function that adds two PyTorch tensors."""
return x + y

    @I.pyfunc
def pytorch_complex_ops(x: torch.Tensor) -> torch.Tensor:
"""Complex PyTorch operations."""
result = torch.nn.functional.softmax(x, dim=0)
result = torch.nn.functional.dropout(result, p=0.1, training=False)
return result * 10.0

    @T.prim_func
    def simple_tir_func(var_A: T.handle, var_B: T.handle):
T.func_attr({"tir.noalias": True})
n = T.int32()
A = T.match_buffer(var_A, (n,), "float32")
B = T.match_buffer(var_B, (n,), "float32")
for i in T.grid(n):
with T.block("copy"):
vi = T.axis.remap("S", [i])
B[vi] = A[vi]


class TestTVMScriptPyFunc:
    """Tests for the @I.pyfunc decorator and BasePyModule integration."""

    def test_pyfunc_decorator_creates_pyfuncs_attribute(self):
        """Test that @I.pyfunc registers each decorated function in pyfuncs."""
module = TestPyFuncModule
assert hasattr(module, "pyfuncs"), "Module should have pyfuncs attribute"
pyfuncs = module.pyfuncs
assert isinstance(pyfuncs, dict), "pyfuncs should be a dictionary"
expected_functions = ["pytorch_processor", "pytorch_adder", "pytorch_complex_ops"]
for func_name in expected_functions:
assert func_name in pyfuncs, f"Function {func_name} should be in pyfuncs"

    def test_pyfunc_functions_are_callable(self):
"""Test that Python functions in pyfuncs are callable."""
module = TestPyFuncModule
pyfuncs = module.pyfuncs
# Test pytorch_processor
processor_func = pyfuncs["pytorch_processor"]
assert callable(processor_func), "pytorch_processor should be callable"
# Test pytorch_adder
adder_func = pyfuncs["pytorch_adder"]
assert callable(adder_func), "pytorch_adder should be callable"
# Test pytorch_complex_ops
complex_func = pyfuncs["pytorch_complex_ops"]
assert callable(complex_func), "pytorch_complex_ops should be callable"

    def test_pyfunc_functions_execute_correctly(self):
"""Test that Python functions execute correctly."""
module = TestPyFuncModule
pyfuncs = module.pyfuncs
# Create test data
x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0], dtype=torch.float32)
y = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], dtype=torch.float32)
# Test pytorch_processor
processor_func = pyfuncs["pytorch_processor"]
processor_result = processor_func(x)
assert isinstance(processor_result, torch.Tensor)
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(processor_result, expected, atol=1e-5)
# Test pytorch_adder
adder_func = pyfuncs["pytorch_adder"]
adder_result = adder_func(x, y)
assert isinstance(adder_result, torch.Tensor)
expected = x + y
assert torch.allclose(adder_result, expected, atol=1e-5)
# Test pytorch_complex_ops
complex_func = pyfuncs["pytorch_complex_ops"]
complex_result = complex_func(x)
assert isinstance(complex_result, torch.Tensor)
        # With training=False, dropout is a no-op, so the result is
        # deterministic: softmax(x) * 10. Check values as well as shape/dtype.
        expected = torch.nn.functional.softmax(x, dim=0) * 10.0
        assert torch.allclose(complex_result, expected, atol=1e-5)
        assert complex_result.shape == x.shape
        assert complex_result.dtype == x.dtype

    def test_pyfunc_module_has_functions_attribute(self):
"""Test that the module has functions attribute for IRModule operations."""
module = TestPyFuncModule
# Check if functions attribute exists
assert hasattr(module, "functions"), "Module should have functions attribute"
functions = module.functions
# TVM IRModule.functions is not a standard dict, but has dict-like behavior
assert hasattr(functions, "__getitem__"), "functions should support dict-like access"
assert hasattr(functions, "__iter__"), "functions should be iterable"

    def test_pyfunc_module_script_method(self):
"""Test that the module has script() method for TVMScript output."""
module = TestPyFuncModule
# Check if script method exists
assert hasattr(module, "script"), "Module should have script method"
# Test script method execution
script_output = module.script()
assert isinstance(script_output, str), "script() should return a string"
assert len(script_output) > 0, "script() should return non-empty string"
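        # A hedged extra check, assuming TVMScript printing includes function
        # names as it does for standard IRModules: the TIR function defined
        # above should appear in the printed script.
        assert "simple_tir_func" in script_output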

    def test_pyfunc_module_inheritance_flag(self):
"""Test that the module has BasePyModule inheritance flag."""
module = TestPyFuncModule
# Check if inheritance flag exists (this might not be set in all implementations)
if hasattr(module, "_base_py_module_inherited"):
assert module._base_py_module_inherited, "Inheritance flag should be True"
else:
# Alternative: check if the module supports Python functions
assert hasattr(module, "pyfuncs"), "Module should support Python functions"
# Check if original class is preserved (this might not be set in all implementations)
if hasattr(module, "_original_class"):
assert module._original_class is not None, "Original class should be preserved"
else:
# Alternative: check if module is callable (ModuleFactory)
assert hasattr(module, "__call__"), "Module should be callable (ModuleFactory)"

    def test_pyfunc_module_creation_and_execution(self):
        """Test instantiating the module on CPU and calling a Python function."""
module = TestPyFuncModule
assert hasattr(module, "__call__"), "Module should be callable"
device = tvm.cpu(0)
instance = module(device)
assert isinstance(instance, BasePyModule), "Instance should be BasePyModule"
assert hasattr(instance, "pyfuncs"), "Instance should have pyfuncs"
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = instance.pytorch_processor(x)
assert isinstance(result, torch.Tensor)
expected = torch.nn.functional.relu(x) * 2.0
assert torch.allclose(result, expected, atol=1e-5)
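        # Exercise a second pyfunc through the same instance. This assumes
        # every @I.pyfunc-decorated function is exposed as an instance
        # attribute, in the same way pytorch_processor is above.
        y = torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32)
        adder_result = instance.pytorch_adder(x, y)
        assert isinstance(adder_result, torch.Tensor)
        assert torch.allclose(adder_result, x + y, atol=1e-5)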

    def test_pyfunc_module_creation_and_execution_gpu(self):
        """Test instantiating the module on CUDA and calling a Python function."""
        if not tvm.cuda().exist:
            pytest.skip("CUDA not available")
        module = TestPyFuncModule
        device = tvm.cuda(0)
        instance = module(device)
        assert isinstance(instance, BasePyModule), "Instance should be BasePyModule"
        assert hasattr(instance, "pyfuncs"), "Instance should have pyfuncs"
        x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda")
        result = instance.pytorch_processor(x)
        assert isinstance(result, torch.Tensor)
        assert result.device.type == "cuda"
        expected = torch.nn.functional.relu(x) * 2.0
        assert torch.allclose(result, expected, atol=1e-5)

    def test_pyfunc_with_tir_integration(self):
"""Test that Python functions can work with TIR functions."""
module = TestPyFuncModule
# Create instance
device = tvm.cpu(0)
instance = module(device)
# Test TIR function execution
n = 5
input_tensor = torch.randn(n, dtype=torch.float32)
        # The TIR function takes two buffer arguments: input A and output B.
        # call_tir allocates the output buffer from out_sinfo, so only the
        # input tensor is passed here.
result = instance.call_tir(
instance.simple_tir_func,
[input_tensor], # Only pass input tensor, let call_tir handle the rest
R.Tensor((n,), "float32"),
)
# Verify result
assert isinstance(result, torch.Tensor)
assert result.shape == (n,)
assert torch.allclose(result, input_tensor, atol=1e-5)
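        # A hedged end-to-end sketch: feed the TIR output back into a Python
        # function on the same instance, mixing TIR and pyfunc execution.
        # Assumes call_tir returns a torch.Tensor, as asserted above.
        chained = instance.pytorch_processor(result)
        expected = torch.nn.functional.relu(result) * 2.0
        assert torch.allclose(chained, expected, atol=1e-5)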

    def test_pyfunc_decorator_preserves_function_signatures(self):
"""Test that @I.pyfunc decorator preserves function signatures."""
module = TestPyFuncModule
pyfuncs = module.pyfuncs
        # Check function signatures with inspect (imported at module top).
# pytorch_processor signature
processor_func = pyfuncs["pytorch_processor"]
sig = inspect.signature(processor_func)
params = list(sig.parameters.keys())
assert len(params) == 1, "pytorch_processor should have 1 parameter"
assert params[0] == "x", "First parameter should be 'x'"
# pytorch_adder signature
adder_func = pyfuncs["pytorch_adder"]
sig = inspect.signature(adder_func)
params = list(sig.parameters.keys())
assert len(params) == 2, "pytorch_adder should have 2 parameters"
assert params[0] == "x", "First parameter should be 'x'"
assert params[1] == "y", "Second parameter should be 'y'"


if __name__ == "__main__":
pytest.main([__file__])