# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Test DLPack integration between PyTorch and TVM.
This test verifies:
1. DLPack conversion from PyTorch to TVM
2. DLPack conversion from TVM to PyTorch
3. Data integrity preservation during conversion
4. Functional equivalence between the DLPack path and the numpy fallback
5. Error handling for unsupported data types
"""

import pytest
import torch
import tvm
from tvm.script import relax as R, tir as T
from tvm.relax import BasePyModule
import numpy as np


class TestDLPackIntegration:
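    """Tests for DLPack-based tensor exchange between PyTorch and TVM."""
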
def test_dlpack_pytorch_to_tvm_conversion(self):
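        """Convert a CPU PyTorch tensor to TVM via DLPack and verify metadata and contents."""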
pytorch_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
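        # Import the PyTorch tensor into TVM via DLPack (shares the underlying buffer).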
tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor)
assert isinstance(tvm_tensor, tvm.runtime.Tensor)
assert tvm_tensor.shape == pytorch_tensor.shape
assert str(tvm_tensor.dtype) == str(pytorch_tensor.dtype).replace("torch.", "")
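        # Compare contents through numpy.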
tvm_numpy = tvm_tensor.numpy()
pytorch_numpy = pytorch_tensor.numpy()
np.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5)

    def test_dlpack_pytorch_to_tvm_conversion_gpu(self):
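        """Convert a CUDA PyTorch tensor to TVM via DLPack, preserving the device."""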
        if not tvm.cuda().exist:
            pytest.skip("CUDA not available")
        pytorch_tensor = torch.tensor(
            [1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32, device="cuda"
        )
        tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor)
        assert isinstance(tvm_tensor, tvm.runtime.Tensor)
        assert tvm_tensor.shape == pytorch_tensor.shape
        assert str(tvm_tensor.dtype) == str(pytorch_tensor.dtype).replace("torch.", "")
        assert str(tvm_tensor.device) == "cuda:0"
        # Move to CPU for numpy conversion
        tvm_numpy = tvm_tensor.numpy()
        pytorch_numpy = pytorch_tensor.cpu().numpy()
        np.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5)

    def test_dlpack_tvm_to_pytorch_conversion(self):
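        """Convert a TVM tensor to PyTorch via DLPack and verify metadata and contents."""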
        data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype="float32")
tvm_tensor = tvm.runtime.tensor(data)
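        # Import the TVM tensor into PyTorch via DLPack.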
pytorch_tensor = torch.from_dlpack(tvm_tensor)
assert isinstance(pytorch_tensor, torch.Tensor)
assert pytorch_tensor.shape == tvm_tensor.shape
assert pytorch_tensor.dtype == torch.float32
tvm_numpy = tvm_tensor.numpy()
pytorch_numpy = pytorch_tensor.numpy()
np.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5)

    def test_dlpack_tvm_to_pytorch_conversion_gpu(self):
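        """Convert a CUDA TVM tensor to PyTorch via DLPack, preserving the device."""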
        if not tvm.cuda().exist:
            pytest.skip("CUDA not available")
        data = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype="float32")
        tvm_tensor = tvm.runtime.tensor(data, device=tvm.cuda(0))
        pytorch_tensor = torch.from_dlpack(tvm_tensor)
        assert isinstance(pytorch_tensor, torch.Tensor)
        assert pytorch_tensor.shape == tvm_tensor.shape
        assert pytorch_tensor.dtype == torch.float32
        assert pytorch_tensor.device.type == "cuda"
        tvm_numpy = tvm_tensor.numpy()
        pytorch_numpy = pytorch_tensor.cpu().numpy()
        np.testing.assert_allclose(tvm_numpy, pytorch_numpy, atol=1e-5)

    def test_dlpack_roundtrip_conversion(self):
"""Test roundtrip conversion: PyTorch -> TVM -> PyTorch."""
# Create PyTorch tensor
original_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
# Convert to TVM
tvm_tensor = tvm.runtime.from_dlpack(original_tensor)
# Convert back to PyTorch
result_tensor = torch.from_dlpack(tvm_tensor)
# Verify roundtrip integrity
assert torch.allclose(original_tensor, result_tensor, atol=1e-5)
assert original_tensor.dtype == result_tensor.dtype
assert original_tensor.shape == result_tensor.shape

    def test_dlpack_different_data_types(self):
"""Test DLPack conversion with different data types."""
test_types = [
(torch.float32, "float32"),
(torch.float64, "float64"),
(torch.int32, "int32"),
(torch.int64, "int64"),
]
for torch_dtype, tvm_dtype in test_types:
# Create PyTorch tensor
pytorch_tensor = torch.tensor([1, 2, 3], dtype=torch_dtype)
            # Convert to TVM and verify the dtype mapping
            tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor)
            assert str(tvm_tensor.dtype) == tvm_dtype
# Convert back to PyTorch
result_tensor = torch.from_dlpack(tvm_tensor)
# Verify conversion
assert torch.allclose(pytorch_tensor, result_tensor, atol=1e-5)
assert pytorch_tensor.dtype == result_tensor.dtype

    def test_dlpack_different_shapes(self):
"""Test DLPack conversion with different tensor shapes."""
test_shapes = [
(1,),
(2, 3),
(4, 5, 6),
(1, 1, 1, 1),
]
for shape in test_shapes:
# Create PyTorch tensor
pytorch_tensor = torch.randn(shape, dtype=torch.float32)
# Convert to TVM
tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor)
# Convert back to PyTorch
result_tensor = torch.from_dlpack(tvm_tensor)
# Verify conversion
assert torch.allclose(pytorch_tensor, result_tensor, atol=1e-5)
assert pytorch_tensor.shape == result_tensor.shape

    def test_dlpack_functionality_verification(self):
"""Test that DLPack and numpy conversions produce identical results."""
# Create large PyTorch tensor
size = 1000000
pytorch_tensor = torch.randn(size, dtype=torch.float32)
# Test DLPack conversion
tvm_tensor_dlpack = tvm.runtime.from_dlpack(pytorch_tensor)
# Test numpy conversion
numpy_array = pytorch_tensor.detach().cpu().numpy()
tvm_tensor_numpy = tvm.runtime.tensor(numpy_array)
# Verify both methods produce same result
result_dlpack = torch.from_dlpack(tvm_tensor_dlpack)
result_numpy = torch.from_numpy(tvm_tensor_numpy.numpy())
assert torch.allclose(result_dlpack, result_numpy, atol=1e-5)
# Verify data integrity
assert torch.allclose(result_dlpack, pytorch_tensor, atol=1e-5)
assert result_dlpack.shape == pytorch_tensor.shape
assert result_dlpack.dtype == pytorch_tensor.dtype

    def test_dlpack_error_handling(self):
"""Test DLPack error handling for unsupported operations."""
# Test with non-contiguous tensor
pytorch_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
non_contiguous = pytorch_tensor[::2] # Create non-contiguous view
        # PyTorch can export non-contiguous tensors via DLPack (strides are
        # preserved), but the consumer is allowed to reject them.
        try:
            tvm_tensor = tvm.runtime.from_dlpack(non_contiguous)
            result_tensor = torch.from_dlpack(tvm_tensor)
            assert torch.allclose(non_contiguous, result_tensor, atol=1e-5)
        except Exception:
            # Rejecting a non-compact tensor is also acceptable behavior.
            pass

    def test_dlpack_with_base_py_module(self):
"""Test DLPack conversion within BasePyModule context."""
# Create a simple IRModule
@T.prim_func
def identity_func(A: T.Buffer((3,), "float32"), B: T.Buffer((3,), "float32")):
for i in T.grid(3):
B[i] = A[i]
ir_mod = tvm.IRModule({"identity_func": identity_func})
device = tvm.cpu(0)
py_mod = BasePyModule(ir_mod, device)
# Create PyTorch tensor
input_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
# Call TIR function (this will trigger DLPack conversion)
result = py_mod.call_tir(identity_func, [input_tensor], R.Tensor((3,), "float32"))
# Verify result
assert isinstance(result, torch.Tensor)
assert torch.allclose(result, input_tensor, atol=1e-5)

    def test_dlpack_device_consistency(self):
"""Test DLPack conversion maintains device consistency."""
# Test CPU tensor
cpu_tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
cpu_tvm = tvm.runtime.from_dlpack(cpu_tensor)
cpu_result = torch.from_dlpack(cpu_tvm)
assert cpu_result.device.type == "cpu"
assert torch.allclose(cpu_tensor, cpu_result, atol=1e-5)
# Note: GPU testing would require CUDA/OpenCL setup
# This is a basic test that CPU works correctly

    def test_dlpack_memory_sharing(self):
"""Test that DLPack conversion shares memory when possible."""
# Create PyTorch tensor
pytorch_tensor = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float32)
# Convert to TVM
tvm_tensor = tvm.runtime.from_dlpack(pytorch_tensor)
# Modify the original tensor
pytorch_tensor[0] = 10.0
# Convert back to PyTorch
result_tensor = torch.from_dlpack(tvm_tensor)
# The result should reflect the modification (memory sharing)
assert result_tensor[0] == 10.0
assert torch.allclose(pytorch_tensor, result_tensor, atol=1e-5)

    def test_dlpack_batch_operations(self):
"""Test DLPack conversion with batch operations."""
# Create batch of tensors
batch_size = 10
pytorch_tensors = [torch.randn(5, dtype=torch.float32) for _ in range(batch_size)]
# Convert all to TVM
tvm_tensors = [tvm.runtime.from_dlpack(t) for t in pytorch_tensors]
# Convert all back to PyTorch
result_tensors = [torch.from_dlpack(t) for t in tvm_tensors]
# Verify all conversions
for i in range(batch_size):
assert torch.allclose(pytorch_tensors[i], result_tensors[i], atol=1e-5)

    def test_dlpack_edge_cases(self):
"""Test DLPack conversion with edge cases."""
# Empty tensor
empty_tensor = torch.tensor([], dtype=torch.float32)
empty_tvm = tvm.runtime.from_dlpack(empty_tensor)
empty_result = torch.from_dlpack(empty_tvm)
assert empty_result.shape == empty_tensor.shape
assert empty_result.dtype == empty_tensor.dtype
# Single element tensor
single_tensor = torch.tensor([42.0], dtype=torch.float32)
single_tvm = tvm.runtime.from_dlpack(single_tensor)
single_result = torch.from_dlpack(single_tvm)
assert single_result.shape == single_tensor.shape
assert single_result[0] == 42.0


if __name__ == "__main__":
    pytest.main([__file__])