# blob: 88d4e615cfd80970c35f2355d0216a1d3f8c9c1b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Test PyTorch integration with TVM Relax.
This test verifies:
1. Seamless PyTorch tensor I/O with TVM backend
2. Cross-function calls between Python, TIR, and Relax functions
3. Dynamic Python function addition and execution
4. End-to-end pipeline testing
5. Error handling and edge cases
"""
import pytest
import torch
import torch.nn.functional as F
import tvm
from tvm import relax, tir
from tvm.script import ir as I, relax as R, tir as T
from tvm.relax import BasePyModule
import numpy as np
@I.ir_module
class PyTorchIntegrationModule(BasePyModule):
    """Test module mixing Python, TIR, and packed functions under BasePyModule."""

    @I.pyfunc
    def main(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """Cross-function pipeline: TIR matmul -> ReLU -> packed softmax -> identity.

        Parameters
        ----------
        x : torch.Tensor
            Input of shape (n, 16), float32 (n is symbolic).
        w : torch.Tensor
            Weight of shape (16, 20), float32.

        Returns
        -------
        torch.Tensor
            Result of shape (n, 20), float32.
        """
        n = x.shape[0]
        # Call TIR function
        lv = self.call_tir(self.matmul, [x, w], out_sinfo=R.Tensor((n, 20), "float32"))
        # Apply ReLU in plain PyTorch
        lv1 = F.relu(lv)
        # Call packed function "my_softmax" (added dynamically by the tests)
        lv2 = self.call_dps_packed("my_softmax", [lv1, 1], out_sinfo=R.Tensor((n, 20), "float32"))
        # Call Python function
        lv3 = self.my_identity_func(lv2)
        return lv3

    @T.prim_func
    def matmul(
        var_A: T.handle,
        var_B: T.handle,
        var_C: T.handle,
    ):
        """TIR matmul: C[n, 20] = A[n, 16] @ B[16, 20], float32, symbolic n."""
        n = T.int32()
        A = T.match_buffer(var_A, (n, 16), "float32")
        B = T.match_buffer(var_B, (16, 20), "float32")
        C = T.match_buffer(var_C, (n, 20), "float32")
        for i, j, k in T.grid(n, 20, 16):
            # BUG FIX: `T.sblock` is not part of TVMScript; the block construct
            # is `T.block(...)`. With `T.sblock` this prim_func fails to parse.
            with T.block("block"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

    @I.pyfunc
    def my_identity_func(self, x: torch.Tensor) -> torch.Tensor:
        """Identity pass-through used to exercise Python-function dispatch."""
        return x
class TestPyTorchIntegration:
def test_module_creation_and_instantiation(self):
module = PyTorchIntegrationModule
assert hasattr(module, "__call__"), "Module should be callable"
device = tvm.cpu(0)
instance = module(device)
assert isinstance(instance, BasePyModule), "Instance should be BasePyModule"
required_methods = ["main", "call_tir", "call_dps_packed"]
for method in required_methods:
assert hasattr(instance, method), f"Instance should have method: {method}"
def test_module_creation_and_instantiation_gpu(self):
module = PyTorchIntegrationModule
if tvm.cuda().exist:
assert hasattr(module, "__call__"), "Module should be callable"
device = tvm.cuda(0)
instance = module(device)
assert isinstance(instance, BasePyModule), "Instance should be BasePyModule"
required_methods = ["main", "call_tir", "call_dps_packed"]
for method in required_methods:
assert hasattr(instance, method), f"Instance should have method: {method}"
assert "cuda" in str(instance.target)
else:
pytest.skip("CUDA not available")
def test_python_function_execution(self):
"""Test that Python functions execute correctly."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Test my_identity_func
input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = instance.my_identity_func(input_tensor)
assert isinstance(result, torch.Tensor)
assert torch.allclose(result, input_tensor, atol=1e-5)
def test_tir_function_execution(self):
"""Test that TIR functions execute correctly."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Test matmul function
n = 3
x = torch.randn(n, 16, dtype=torch.float32)
w = torch.randn(16, 20, dtype=torch.float32)
result = instance.call_tir(instance.matmul, [x, w], R.Tensor((n, 20), "float32"))
assert isinstance(result, torch.Tensor)
assert result.shape == (n, 20)
# Verify result with PyTorch matmul
expected = torch.matmul(x, w)
assert torch.allclose(result, expected, atol=1e-3)
def test_dynamic_python_function_addition(self):
"""Test adding Python functions dynamically."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Define a custom function
def custom_activation(x):
return torch.sigmoid(x)
# Add the function
instance.add_python_function("custom_activation", custom_activation)
# Verify function is added
assert hasattr(instance, "custom_activation")
assert "custom_activation" in instance.pyfuncs
# Test function execution
input_tensor = torch.tensor([1.0, -1.0, 0.0], dtype=torch.float32)
result = instance.custom_activation(input_tensor)
assert isinstance(result, torch.Tensor)
expected = torch.sigmoid(input_tensor)
assert torch.allclose(result, expected, atol=1e-5)
def test_call_dps_packed_with_dynamic_function(self):
"""Test call_dps_packed with dynamically added function."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Define my_softmax function
def my_softmax(tensor, dim):
"""Custom softmax function for testing call_dps_packed."""
# Convert TVM Tensor to PyTorch tensor if needed
if hasattr(tensor, "numpy"):
tensor = torch.from_numpy(tensor.numpy())
return F.softmax(tensor, dim=dim)
# Add the function
instance.my_softmax = my_softmax
# Test call_dps_packed
input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
result = instance.call_dps_packed(
"my_softmax", [input_tensor, 1], R.Tensor((2, 2), "float32")
)
assert isinstance(result, torch.Tensor)
expected = F.softmax(input_tensor, dim=1)
assert torch.allclose(result, expected, atol=1e-5)
def test_end_to_end_pipeline(self):
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
def my_softmax(tensor, dim):
if hasattr(tensor, "numpy"):
tensor = torch.from_numpy(tensor.numpy())
return F.softmax(tensor, dim=dim)
instance.my_softmax = my_softmax
n = 5
x = torch.randn(n, 16, dtype=torch.float32)
w = torch.randn(16, 20, dtype=torch.float32)
result = instance.main(x, w)
assert isinstance(result, torch.Tensor)
assert result.shape == (n, 20)
assert result.dtype == torch.float32
def test_end_to_end_pipeline_gpu(self):
module = PyTorchIntegrationModule
if tvm.cuda().exist:
device = tvm.cuda(0)
instance = module(device)
# Test basic GPU functionality without complex TIR operations
assert isinstance(instance, BasePyModule)
assert "cuda" in str(instance.target)
# Test that we can create and work with GPU tensors
n = 5
x = torch.randn(n, 16, dtype=torch.float32, device="cuda")
w = torch.randn(16, 20, dtype=torch.float32, device="cuda")
assert x.device.type == "cuda"
assert w.device.type == "cuda"
assert x.shape == (n, 16)
assert w.shape == (16, 20)
# Test basic PyTorch operations on GPU
result = torch.matmul(x, w)
assert isinstance(result, torch.Tensor)
assert result.shape == (n, 20)
assert result.dtype == torch.float32
assert result.device.type == "cuda"
else:
pytest.skip("CUDA not available")
def test_cross_function_data_flow(self):
"""Test data flow between different function types."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Add required functions
def my_softmax(tensor, dim):
if hasattr(tensor, "numpy"):
tensor = torch.from_numpy(tensor.numpy())
return F.softmax(tensor, dim=dim)
instance.my_softmax = my_softmax
# Create test data
n = 4
x = torch.randn(n, 16, dtype=torch.float32)
w = torch.randn(16, 20, dtype=torch.float32)
# Execute step by step to verify data flow
# Step 1: TIR matmul
lv = instance.call_tir(instance.matmul, [x, w], R.Tensor((n, 20), "float32"))
assert isinstance(lv, torch.Tensor)
assert lv.shape == (n, 20)
# Step 2: ReLU
lv1 = F.relu(lv)
assert isinstance(lv1, torch.Tensor)
assert lv1.shape == (n, 20)
# Step 3: Softmax via call_dps_packed
lv2 = instance.call_dps_packed("my_softmax", [lv1, 1], R.Tensor((n, 20), "float32"))
assert isinstance(lv2, torch.Tensor)
assert lv2.shape == (n, 20)
# Step 4: Identity function
lv3 = instance.my_identity_func(lv2)
assert isinstance(lv3, torch.Tensor)
assert lv3.shape == (n, 20)
# Verify final result matches expected
expected = F.softmax(F.relu(torch.matmul(x, w)), dim=1)
assert torch.allclose(lv3, expected, atol=1e-3)
def test_error_handling(self):
"""Test error handling for various edge cases."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Test with missing function
with pytest.raises(Exception):
instance.call_dps_packed(
"non_existent_function", [torch.tensor([1.0])], R.Tensor((1,), "float32")
)
# Test with wrong tensor shapes
x = torch.randn(3, 16, dtype=torch.float32)
w = torch.randn(15, 20, dtype=torch.float32) # Wrong shape
with pytest.raises(Exception):
instance.call_tir(instance.matmul, [x, w], R.Tensor((3, 20), "float32"))
def test_tensor_type_preservation(self):
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
def my_softmax(tensor, dim):
if hasattr(tensor, "numpy"):
tensor = torch.from_numpy(tensor.numpy())
return F.softmax(tensor, dim=dim)
instance.my_softmax = my_softmax
# Test with float32 data type (TIR function is hardcoded for float32)
test_dtype = torch.float32
n = 3
x = torch.randn(n, 16, dtype=test_dtype)
w = torch.randn(16, 20, dtype=test_dtype)
result = instance.main(x, w)
# Verify type preservation
assert result.dtype == test_dtype
assert isinstance(result, torch.Tensor)
assert result.shape == (n, 20)
assert result.dtype == torch.float32
def test_batch_processing(self):
"""Test processing multiple inputs in batch."""
module = PyTorchIntegrationModule
device = tvm.cpu(0)
instance = module(device)
# Add required functions
def my_softmax(tensor, dim):
if hasattr(tensor, "numpy"):
tensor = torch.from_numpy(tensor.numpy())
return F.softmax(tensor, dim=dim)
instance.my_softmax = my_softmax
# Process multiple inputs
batch_size = 5
results = []
for i in range(batch_size):
n = 3 + i # Varying batch sizes
x = torch.randn(n, 16, dtype=torch.float32)
w = torch.randn(16, 20, dtype=torch.float32)
result = instance.main(x, w)
results.append(result)
assert isinstance(result, torch.Tensor)
assert result.shape == (n, 20)
# Verify all results are valid
assert len(results) == batch_size
for result in results:
assert isinstance(result, torch.Tensor)
if __name__ == "__main__":
    # Allow running this file directly: delegate to pytest's CLI on this file.
    pytest.main([__file__])