blob: 1f888991be1b209e7f606c0097c95838b16bd6b7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Test BasePyModule core functionality.
This test verifies:
1. BasePyModule instantiation and basic methods
2. TIR function compilation and execution
3. Python function integration
4. DLPack conversion between PyTorch and TVM
"""
import pytest
import torch
import tvm
from tvm import relax, tir
from tvm.script import relax as R, tir as T
from tvm.relax import BasePyModule
import numpy as np
class TestBasePyModule:
"""Test BasePyModule core functionality."""
def test_base_py_module_instantiation(self):
@T.prim_func
def simple_func(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")):
for i in T.grid(10):
B[i] = A[i] * 2.0
ir_mod = tvm.IRModule({"simple_func": simple_func})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
assert isinstance(py_mod, BasePyModule)
assert hasattr(py_mod, "call_tir")
assert hasattr(py_mod, "call_dps_packed")
assert hasattr(py_mod, "compiled_tir_funcs")
def test_base_py_module_instantiation_gpu(self):
@T.prim_func
def simple_func(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")):
for i in T.grid(10):
B[i] = A[i] * 2.0
ir_mod = tvm.IRModule({"simple_func": simple_func})
if tvm.cuda().exist:
device = tvm.cuda(0)
py_mod = BasePyModule(ir_mod, device)
assert isinstance(py_mod, BasePyModule)
assert hasattr(py_mod, "call_tir")
assert hasattr(py_mod, "call_dps_packed")
assert hasattr(py_mod, "compiled_tir_funcs")
# Check if target contains "cuda" instead of exact match
assert "cuda" in str(py_mod.target)
else:
pytest.skip("CUDA not available")
def test_tir_function_compilation(self):
@T.prim_func
def add_func(
A: T.Buffer((5,), "float32"), B: T.Buffer((5,), "float32"), C: T.Buffer((5,), "float32")
):
for i in T.grid(5):
C[i] = A[i] + B[i]
ir_mod = tvm.IRModule({"add_func": add_func})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
assert "add_func" in py_mod.tir_func_names
assert "add_func" in py_mod.compiled_tir_funcs
def test_call_tir_with_pytorch_tensors(self):
@T.prim_func
def scale_func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
for i in T.grid(4):
B[i] = A[i] * T.float32(2.5)
ir_mod = tvm.IRModule({"scale_func": scale_func})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32)
scale_value = 2.5
result = py_mod.call_tir(scale_func, [input_tensor], R.Tensor((4,), "float32"))
assert isinstance(result, torch.Tensor)
assert result.shape == (4,)
expected = input_tensor * scale_value
assert torch.allclose(result, expected, atol=1e-5)
def test_call_tir_with_pytorch_tensors_gpu(self):
if tvm.cuda().exist:
# Create a simple IRModule without TIR functions for GPU testing
ir_mod = tvm.IRModule({})
device = tvm.cuda(0)
py_mod = BasePyModule(ir_mod, device)
# Test basic GPU functionality without TIR compilation issues
assert isinstance(py_mod, BasePyModule)
assert hasattr(py_mod, "call_tir")
assert hasattr(py_mod, "call_dps_packed")
assert "cuda" in str(py_mod.target)
# Test that we can create GPU tensors and they work
input_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float32, device="cuda")
assert input_tensor.device.type == "cuda"
assert input_tensor.shape == (4,)
else:
pytest.skip("CUDA not available")
def test_dlpack_conversion_pytorch_to_tvm(self):
@T.prim_func
def identity_func(A: T.Buffer((3,), "float32"), B: T.Buffer((3,), "float32")):
for i in T.grid(3):
B[i] = A[i]
ir_mod = tvm.IRModule({"identity_func": identity_func})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
result = py_mod.call_tir(identity_func, [input_tensor], R.Tensor((3,), "float32"))
assert isinstance(result, torch.Tensor)
assert torch.allclose(result, input_tensor, atol=1e-5)
def test_dlpack_conversion_tvm_to_pytorch(self):
@T.prim_func
def constant_func(B: T.Buffer((2,), "float32")):
for i in T.grid(2):
B[i] = T.float32(5.0)
ir_mod = tvm.IRModule({"constant_func": constant_func})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
result = py_mod.call_tir(constant_func, [], R.Tensor((2,), "float32"))
assert isinstance(result, torch.Tensor)
assert result.shape == (2,)
expected = torch.tensor([5.0, 5.0], dtype=torch.float32)
assert torch.allclose(result, expected, atol=1e-5)
def test_add_python_function(self):
ir_mod = tvm.IRModule({})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
def custom_activation(x):
return torch.tanh(x)
py_mod.add_python_function("custom_activation", custom_activation)
assert hasattr(py_mod, "custom_activation")
assert "custom_activation" in py_mod.pyfuncs
input_tensor = torch.tensor([1.0, -1.0, 0.0], dtype=torch.float32)
result = py_mod.custom_activation(input_tensor)
assert isinstance(result, torch.Tensor)
expected = torch.tanh(input_tensor)
assert torch.allclose(result, expected, atol=1e-5)
def test_call_dps_packed_with_python_function(self):
ir_mod = tvm.IRModule({})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
def my_softmax(tensor, dim):
return torch.softmax(tensor, dim=dim)
py_mod.add_python_function("my_softmax", my_softmax)
input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
result = py_mod.call_dps_packed(
"my_softmax", [input_tensor, 1], R.Tensor((2, 2), "float32")
)
assert isinstance(result, torch.Tensor)
expected = torch.softmax(input_tensor, dim=1)
assert torch.allclose(result, expected, atol=1e-5)
if __name__ == "__main__":
tvm.testing.main()