# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring, invalid-name, unused-argument
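"""Tests for @I.pyfunc integration with BasePyModule: module creation, pyfunc
registration, TIR/Relax interop, and the R.call_py_func operator."""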
import pytest
import tvm
import tvm.testing
from tvm.relax.base_py_module import BasePyModule
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
@I.ir_module
class SimplePyFuncModule(BasePyModule):
"""Test simple Python functions with basic operations."""
@I.pyfunc
def add(self, x, y):
"""Simple addition function."""
x_tvm = self._convert_pytorch_to_tvm(x)
y_tvm = self._convert_pytorch_to_tvm(y)
result = self.call_tir(self.add_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32"))
return self._convert_tvm_to_pytorch(result)
@I.pyfunc
def multiply(self, x, y):
"""Simple multiplication function."""
x_tvm = self._convert_pytorch_to_tvm(x)
y_tvm = self._convert_pytorch_to_tvm(y)
result = self.call_tir(
self.multiply_tir, [x_tvm, y_tvm], out_sinfo=R.Tensor((5,), "float32")
)
return self._convert_tvm_to_pytorch(result)
@T.prim_func
def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
x = T.match_buffer(var_x, (5,), "float32")
y = T.match_buffer(var_y, (5,), "float32")
out = T.match_buffer(var_out, (5,), "float32")
for i in range(5):
out[i] = x[i] + y[i]
@T.prim_func
def multiply_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
x = T.match_buffer(var_x, (5,), "float32")
y = T.match_buffer(var_y, (5,), "float32")
out = T.match_buffer(var_out, (5,), "float32")
for i in range(5):
out[i] = x[i] * y[i]
@R.function
def main_relax(
x: R.Tensor((5,), "float32"), y: R.Tensor((5,), "float32")
) -> R.Tensor((5,), "float32"):
return R.add(x, y)
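# Minimal usage sketch (an assumption based on the pytest functions below, not a
# spec; names like `add_fn` are illustrative):
#   module = BasePyModule(SimplePyFuncModule, tvm.cpu())
#   add_fn = SimplePyFuncModule.pyfuncs["add"]             # registered by @I.pyfunc
#   result = add_fn(module, torch.ones(5), torch.ones(5))  # hypothetical direct call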
@I.ir_module
class ComplexPyFuncModule(BasePyModule):
"""Test complex Python logic with ML pipeline and error handling."""
@I.pyfunc
def ml_pipeline(self, input_data, model_params):
"""Complex ML pipeline with data validation and error handling."""
# Data validation
if input_data is None or model_params is None:
raise ValueError("Inputs cannot be None")
try:
# Convert to TVM format
tvm_data = self._convert_pytorch_to_tvm(input_data)
tvm_params = self._convert_pytorch_to_tvm(model_params)
# Run ML inference
features = self.call_tir(
self.extract_features, [tvm_data], out_sinfo=R.Tensor((10,), "float32")
)
predictions = self.call_tir(
self.ml_inference, [features, tvm_params], out_sinfo=R.Tensor((5,), "float32")
)
# Post-process results
final_result = self.call_tir(
self.post_process, [predictions], out_sinfo=R.Tensor((5,), "float32")
)
return self._convert_tvm_to_pytorch(final_result)
except Exception as e:
self._log_error(f"ML pipeline failed: {e}")
return self._get_default_value()
@I.pyfunc
def data_preprocessing(self, raw_data):
"""Data preprocessing with conditional logic."""
if hasattr(raw_data, "numpy"):
# Vectorized path for numpy-compatible data
data_np = raw_data.numpy()
processed = self._vectorized_preprocess(data_np)
else:
# Fallback path for other data types
processed = self._elementwise_preprocess(raw_data)
# Convert and return
tvm_processed = self._convert_pytorch_to_tvm(processed)
result = self.call_tir(
self.normalize_data, [tvm_processed], out_sinfo=R.Tensor((10,), "float32")
)
return self._convert_tvm_to_pytorch(result)
@T.prim_func
def extract_features(data: T.handle, features: T.handle):
T.func_attr({"tir.noalias": True})
Data = T.match_buffer(data, (10,), "float32")
Features = T.match_buffer(features, (10,), "float32")
for i in range(10):
Features[i] = T.sqrt(Data[i])
@T.prim_func
def ml_inference(features: T.handle, params: T.handle, output: T.handle):
T.func_attr({"tir.noalias": True})
Features = T.match_buffer(features, (10,), "float32")
Params = T.match_buffer(params, (10,), "float32")
Output = T.match_buffer(output, (5,), "float32")
for i in range(5):
Output[i] = Features[i] * Params[i] + Features[i + 5] * Params[i + 5]
@T.prim_func
def post_process(predictions: T.handle, final: T.handle):
T.func_attr({"tir.noalias": True})
Predictions = T.match_buffer(predictions, (5,), "float32")
Final = T.match_buffer(final, (5,), "float32")
for i in range(5):
Final[i] = T.max(Predictions[i], 0.0)
@T.prim_func
def normalize_data(data: T.handle, normalized: T.handle):
T.func_attr({"tir.noalias": True})
Data = T.match_buffer(data, (10,), "float32")
Normalized = T.match_buffer(normalized, (10,), "float32")
for i in range(10):
Normalized[i] = Data[i] / 255.0
@I.ir_module
class EdgeCasePyFuncModule(BasePyModule):
"""Test edge cases and boundary conditions."""
@I.pyfunc
def empty_func(self):
"""Empty function with no operations."""
pass
@I.pyfunc
def single_return(self, x):
"""Function with immediate return."""
return x
@I.pyfunc
def nested_conditionals(self, data, threshold):
"""Function with complex nested conditional logic."""
if data is None:
return None
if hasattr(data, "shape"):
if len(data.shape) == 1:
if data.shape[0] > threshold:
return self._process_large_data(data)
else:
return self._process_small_data(data)
elif len(data.shape) == 2:
return self._process_2d_data(data)
else:
return self._process_nd_data(data)
else:
return self._process_scalar_data(data)
@I.pyfunc
def loop_with_break(self, data, max_iter):
"""Function with loop and break statement."""
result = []
for i, item in enumerate(data):
if i >= max_iter:
break
if item > 0:
result.append(item * 2)
else:
result.append(0)
return result
@T.prim_func
def dummy_tir(data: T.handle, output: T.handle):
T.func_attr({"tir.noalias": True})
Data = T.match_buffer(data, (1,), "float32")
Output = T.match_buffer(output, (1,), "float32")
Output[0] = Data[0]
@I.ir_module
class PerformancePyFuncModule(BasePyModule):
"""Test performance optimization patterns."""
@I.pyfunc
def vectorized_operation(self, x, y):
"""Vectorized operation with numpy fallback."""
try:
# Try vectorized operation first
if hasattr(x, "numpy") and hasattr(y, "numpy"):
x_np = x.numpy()
y_np = y.numpy()
result_np = x_np + y_np
return self._convert_numpy_to_pytorch(result_np)
except Exception:
pass
# Fallback to TVM processing
x_tvm = self._convert_pytorch_to_tvm(x)
y_tvm = self._convert_pytorch_to_tvm(y)
result = self.call_tir(
self.vectorized_add, [x_tvm, y_tvm], out_sinfo=R.Tensor((10,), "float32")
)
return self._convert_tvm_to_pytorch(result)
@I.pyfunc
def batch_processing(self, batch_data):
"""Batch processing with memory optimization."""
batch_size = len(batch_data)
results = []
# Process in chunks to optimize memory usage
chunk_size = min(batch_size, 100)
for i in range(0, batch_size, chunk_size):
chunk = batch_data[i : i + chunk_size]
chunk_result = self._process_chunk(chunk)
results.extend(chunk_result)
return results
@I.pyfunc
def memory_efficient_transform(self, large_tensor):
"""Memory-efficient tensor transformation."""
# Use in-place operations when possible
if hasattr(large_tensor, "requires_grad") and not large_tensor.requires_grad:
# In-place operation for efficiency
large_tensor.add_(1.0)
return large_tensor
else:
# Create new tensor if gradients are needed
return large_tensor + 1.0
@T.prim_func
def vectorized_add(a: T.handle, b: T.handle, c: T.handle):
T.func_attr({"tir.noalias": True})
A = T.match_buffer(a, (10,), "float32")
B = T.match_buffer(b, (10,), "float32")
C = T.match_buffer(c, (10,), "float32")
for i in range(10):
C[i] = A[i] + B[i]
@I.ir_module
class IntegrationPyFuncModule(BasePyModule):
"""Test integration with external libraries and complex workflows."""
@I.pyfunc
def sklearn_integration(self, input_data, scaler_params):
"""Integration with scikit-learn preprocessing."""
try:
# Import sklearn components
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
# Create and fit scaler
scaler = StandardScaler()
if scaler_params is not None:
scaler.mean_ = scaler_params["mean"]
scaler.scale_ = scaler_params["scale"]
else:
scaler.fit(input_data)
# Transform data
scaled_data = scaler.transform(input_data)
# Apply PCA if needed
if input_data.shape[1] > 10:
pca = PCA(n_components=10)
reduced_data = pca.fit_transform(scaled_data)
else:
reduced_data = scaled_data
# Convert to TVM and process
tvm_data = self._convert_pytorch_to_tvm(reduced_data)
result = self.call_tir(
self.final_transform,
[tvm_data],
out_sinfo=R.Tensor((reduced_data.shape[0], 10), "float32"),
)
return self._convert_tvm_to_pytorch(result)
except ImportError:
# Fallback if sklearn is not available
return self._fallback_preprocessing(input_data)
@I.pyfunc
def multi_stage_pipeline(self, raw_input):
"""Multi-stage processing pipeline."""
# Stage 1: Data cleaning
cleaned = self._clean_data(raw_input)
# Stage 2: Feature extraction
features = self._extract_features(cleaned)
# Stage 3: Model inference
predictions = self._run_inference(features)
# Stage 4: Post-processing
final_result = self._post_process_output(predictions)
return final_result
@T.prim_func
def final_transform(data: T.handle, output: T.handle):
T.func_attr({"tir.noalias": True})
Data = T.match_buffer(data, (10, 10), "float32")
Output = T.match_buffer(output, (10, 10), "float32")
for i in range(10):
for j in range(10):
Output[i, j] = T.tanh(Data[i, j])
@I.ir_module
class ErrorHandlingPyFuncModule(BasePyModule):
"""Test comprehensive error handling and validation."""
@I.pyfunc
def robust_data_processing(self, input_data, config):
"""Robust data processing with comprehensive error handling."""
try:
# Validate inputs
if not self._validate_inputs(input_data, config):
raise ValueError("Invalid input data or configuration")
# Check data types
if not self._check_data_types(input_data):
raise TypeError("Unsupported data types")
# Process data with retry logic
max_retries = config.get("max_retries", 3)
for attempt in range(max_retries):
try:
result = self._process_with_validation(input_data, config)
if self._validate_output(result):
return result
else:
raise RuntimeError("Output validation failed")
except Exception as e:
if attempt == max_retries - 1:
raise
self._log_warning(f"Attempt {attempt + 1} failed: {e}")
continue
except Exception as e:
self._log_error(f"Data processing failed: {e}")
return self._get_safe_fallback(input_data, config)
@I.pyfunc
def graceful_degradation(self, primary_input, fallback_input):
"""Function that gracefully degrades when primary path fails."""
try:
# Try primary processing path
result = self._primary_processing(primary_input)
return result
except Exception as e:
self._log_warning(f"Primary processing failed: {e}")
try:
# Try fallback path
result = self._fallback_processing(fallback_input)
return result
except Exception as e2:
self._log_error(f"Fallback processing also failed: {e2}")
# Return safe default
return self._get_safe_default()
@T.prim_func
def safe_transform(data: T.handle, output: T.handle):
T.func_attr({"tir.noalias": True})
Data = T.match_buffer(data, (5,), "float32")
Output = T.match_buffer(output, (5,), "float32")
for i in range(5):
# Safe operation that handles edge cases
if Data[i] > 0:
Output[i] = T.sqrt(Data[i])
else:
Output[i] = 0.0
# Pytest test functions to verify the classes work correctly
def test_simple_pyfunc_module_creation():
"""Test that SimplePyFuncModule can be created."""
# Get the IRModule instance from the TVMScript decorated class
ir_mod = SimplePyFuncModule
device = tvm.cpu()
# Create BasePyModule instance
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
    # Python functions are registered in the IRModule's pyfuncs map,
    # not as direct attributes on the module.
if hasattr(ir_mod, "pyfuncs"):
assert "add" in ir_mod.pyfuncs
assert "multiply" in ir_mod.pyfuncs
# Check that TIR functions exist
assert hasattr(module, "add_tir")
assert hasattr(module, "multiply_tir")
    # Note: this TVMScript module is for testing purposes only and cannot compile,
    # so Relax functions may not be available.
    print("Note: this TVMScript module is for testing purposes only and cannot compile")
def test_complex_pyfunc_module_creation():
"""Test that ComplexPyFuncModule can be created."""
ir_mod = ComplexPyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
# Check Python functions in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "ml_pipeline" in ir_mod.pyfuncs
assert "data_preprocessing" in ir_mod.pyfuncs
# Check TIR functions
assert hasattr(module, "extract_features")
assert hasattr(module, "ml_inference")
assert hasattr(module, "post_process")
assert hasattr(module, "normalize_data")
def test_edge_case_pyfunc_module_creation():
"""Test that EdgeCasePyFuncModule can be created."""
ir_mod = EdgeCasePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
# Check Python functions in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "empty_func" in ir_mod.pyfuncs
assert "single_return" in ir_mod.pyfuncs
assert "nested_conditionals" in ir_mod.pyfuncs
assert "loop_with_break" in ir_mod.pyfuncs
# Check TIR function
assert hasattr(module, "dummy_tir")
def test_performance_pyfunc_module_creation():
"""Test that PerformancePyFuncModule can be created."""
ir_mod = PerformancePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
# Check Python functions in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "vectorized_operation" in ir_mod.pyfuncs
assert "batch_processing" in ir_mod.pyfuncs
assert "memory_efficient_transform" in ir_mod.pyfuncs
# Check TIR function
assert hasattr(module, "vectorized_add")
def test_integration_pyfunc_module_creation():
"""Test that IntegrationPyFuncModule can be created."""
ir_mod = IntegrationPyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
# Check Python functions in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "sklearn_integration" in ir_mod.pyfuncs
assert "multi_stage_pipeline" in ir_mod.pyfuncs
# Check TIR function
assert hasattr(module, "final_transform")
def test_error_handling_pyfunc_module_creation():
"""Test that ErrorHandlingPyFuncModule can be created."""
ir_mod = ErrorHandlingPyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
# Check Python functions in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "robust_data_processing" in ir_mod.pyfuncs
assert "graceful_degradation" in ir_mod.pyfuncs
# Check TIR function
assert hasattr(module, "safe_transform")
def test_all_modules_inherit_from_base():
"""Test that all modules properly inherit from BasePyModule."""
modules = [
SimplePyFuncModule,
ComplexPyFuncModule,
EdgeCasePyFuncModule,
PerformancePyFuncModule,
IntegrationPyFuncModule,
ErrorHandlingPyFuncModule,
]
device = tvm.cpu()
for ir_mod in modules:
module = BasePyModule(ir_mod, device)
assert isinstance(module, BasePyModule)
assert hasattr(module, "script")
assert hasattr(module, "show")
def test_pyfunc_decorators():
"""Test that all @I.pyfunc decorated functions are present."""
ir_mod = SimplePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
# Check that the functions exist in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "add" in ir_mod.pyfuncs
assert "multiply" in ir_mod.pyfuncs
# Get the actual function objects
add_func = ir_mod.pyfuncs["add"]
multiply_func = ir_mod.pyfuncs["multiply"]
# Check that they are callable
assert callable(add_func)
assert callable(multiply_func)
# Check function signatures
import inspect
add_sig = inspect.signature(add_func)
assert len(add_sig.parameters) == 3 # self, x, y
multiply_sig = inspect.signature(multiply_func)
assert len(multiply_sig.parameters) == 3 # self, x, y
def test_tir_functions():
"""Test that TIR functions are properly defined."""
ir_mod = SimplePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
# Check TIR function attributes
assert hasattr(module, "add_tir")
assert hasattr(module, "multiply_tir")
# These should be callable (though they're TIR functions)
assert callable(module.add_tir)
assert callable(module.multiply_tir)
def test_relax_functions():
"""Test that Relax functions are properly defined."""
ir_mod = SimplePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
    # Note: this TVMScript module is for testing purposes only and cannot compile,
    # so Relax functions may not be available.
    print("Note: this TVMScript module is for testing purposes only and cannot compile")
# We can still check that the module was created successfully
assert isinstance(module, BasePyModule)
assert hasattr(module, "script")
assert hasattr(module, "show")
def test_module_docstrings():
"""Test that all modules have proper docstrings."""
modules = [
SimplePyFuncModule,
ComplexPyFuncModule,
EdgeCasePyFuncModule,
PerformancePyFuncModule,
IntegrationPyFuncModule,
ErrorHandlingPyFuncModule,
]
for module_class in modules:
        # The TVMScript decorator replaces the class, so instead of inspecting
        # docstrings we verify that each module is callable and can back a
        # BasePyModule instance.
assert callable(module_class)
# We can't directly instantiate TVMScript decorated classes
# but we can create BasePyModule instances with them
device = tvm.cpu()
instance = BasePyModule(module_class, device)
assert isinstance(instance, BasePyModule)
def test_python_function_complexity():
"""Test that complex Python functions have the expected structure."""
ir_mod = ComplexPyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
# Check that complex functions exist in pyfuncs
if hasattr(ir_mod, "pyfuncs"):
assert "ml_pipeline" in ir_mod.pyfuncs
assert "data_preprocessing" in ir_mod.pyfuncs
# Get the actual function objects
ml_func = ir_mod.pyfuncs["ml_pipeline"]
preprocess_func = ir_mod.pyfuncs["data_preprocessing"]
# These should be callable
assert callable(ml_func)
assert callable(preprocess_func)
# Check function signatures
import inspect
ml_sig = inspect.signature(ml_func)
assert len(ml_sig.parameters) == 3 # self, input_data, model_params
preprocess_sig = inspect.signature(preprocess_func)
assert len(preprocess_sig.parameters) == 2 # self, raw_data
def test_script_and_show_methods():
"""Test that script() and show() methods work correctly."""
ir_mod = SimplePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
# Test script() method
script_output = module.script()
assert isinstance(script_output, str)
assert len(script_output) > 0
# Test show() method
    try:
        module.show()
    except Exception as e:
        # If show() fails, the feature is not working properly.
        pytest.fail(f"show() method failed: {e}")
def test_python_functions_in_irmodule():
"""Test that Python functions are properly stored in IRModule pyfuncs."""
ir_mod = SimplePyFuncModule
device = tvm.cpu()
module = BasePyModule(ir_mod, device)
# Check that pyfuncs attribute exists and contains our functions
if hasattr(ir_mod, "pyfuncs"):
pyfuncs = ir_mod.pyfuncs
assert isinstance(pyfuncs, dict)
assert "add" in pyfuncs
assert "multiply" in pyfuncs
# Check that the functions are callable
assert callable(pyfuncs["add"])
assert callable(pyfuncs["multiply"])
# Check function names
assert pyfuncs["add"].__name__ == "add"
assert pyfuncs["multiply"].__name__ == "multiply"
else:
pytest.fail("pyfuncs attribute not found in IRModule")
def test_call_py_func_with_base_py_module():
"""Test R.call_py_func with BasePyModule."""
import torch
import numpy as np
from tvm.relax.op import call_py_func
from tvm.relax.expr import StringImm
from tvm.relax import Var, TensorStructInfo
# Test 1: Operator creation and basic properties
x = Var("x", TensorStructInfo((5,), "float32"))
y = Var("y", TensorStructInfo((5,), "float32"))
call_expr = call_py_func(StringImm("test_func"), (x, y), out_sinfo=R.Tensor((5,), "float32"))
assert call_expr.op.name == "relax.call_py_func"
assert call_expr.args[0].value == "test_func"
assert len(call_expr.args) == 2
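    # Inferred from the asserts above (not a spec): the call packs its operands
    # as [StringImm name, Tuple(inputs)], hence the two-element args list.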
# Test 2: Compilation validation
    # Passing a plain str instead of a StringImm must be rejected. Using
    # pytest.raises avoids the original pattern, where a bare `assert False`
    # inside the try block was itself swallowed by the `except` clause.
    with pytest.raises(Exception, match="Mismatched type|Expected"):
        call_py_func(
            "invalid",
            (Var("x", TensorStructInfo((5,), "float32")),),
            out_sinfo=R.Tensor((5,), "float32"),
        )
# Test 3: Validation and error handling
@I.ir_module
class ValidationTestModule(BasePyModule):
@R.function
def test_invalid_call(x: R.Tensor((5,), "float32")) -> R.Tensor((5,), "float32"):
result = R.call_py_func("non_existent_func", (x,), out_sinfo=R.Tensor((5,), "float32"))
return result
device = tvm.cpu()
module = ValidationTestModule(device)
x = torch.randn(5, dtype=torch.float32)
with pytest.raises(ValueError, match="Python function 'non_existent_func' not found"):
module.call_py_func("non_existent_func", [x])
# Test 4: Using call_py_func within Relax functions
@I.ir_module
class RelaxCallPyFuncModule(BasePyModule):
@I.pyfunc
def torch_relu(self, x):
"""PyTorch ReLU implementation."""
return torch.relu(x)
@I.pyfunc
def torch_softmax(self, x, dim=0):
"""PyTorch softmax implementation."""
return torch.softmax(x, dim=dim)
@R.function
def mixed_computation(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), "float32"):
relu_result = R.call_py_func("torch_relu", (x,), out_sinfo=R.Tensor((10,), "float32"))
final_result = R.call_py_func(
"torch_softmax", (relu_result,), out_sinfo=R.Tensor((10,), "float32")
)
return final_result
device = tvm.cpu()
module = RelaxCallPyFuncModule(device)
x = torch.randn(10, dtype=torch.float32)
expected = torch.softmax(torch.relu(x), dim=0)
relu_result = module.call_py_func("torch_relu", [x])
final_result = module.call_py_func("torch_softmax", [relu_result])
# Convert to numpy for comparison
if isinstance(final_result, tvm.runtime.Tensor):
final_result_np = final_result.numpy()
else:
final_result_np = final_result
if isinstance(expected, torch.Tensor):
expected_np = expected.numpy()
else:
expected_np = expected
# Use numpy for comparison since we have numpy arrays
np.testing.assert_allclose(final_result_np, expected_np, rtol=1e-5, atol=1e-5)
if __name__ == "__main__":
tvm.testing.main()