blob: 393d640b237bf02d086cf9fac4dd19adfd7fa9fb [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: F821
"""Tests for ``@T.jit`` + ``T.constexpr``."""
from __future__ import annotations
import pytest
import tvm
from tvm.ir import assert_structural_equal
from tvm.script import tirx as T
def test_int_constexpr_specializes_loop_bound():
    """Specializing ``N`` fixes both the buffer extents and the loop bound."""

    @T.jit(private=True)
    def add(
        A: T.Buffer((N,), "int32"),
        B: T.Buffer((N,), "int32"),
        C: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
    ):
        for i in range(N):
            C[i] = A[i] + B[i]

    @T.prim_func(private=True)
    def expected(
        A: T.Buffer((128,), "int32"),
        B: T.Buffer((128,), "int32"),
        C: T.Buffer((128,), "int32"),
    ):
        for i in range(128):
            C[i] = A[i] + B[i]

    # Every occurrence of N (shapes and range) should become the literal 128.
    specialized = add.specialize(N=128)
    assert_structural_equal(specialized, expected, map_free_vars=True)
def test_constexpr_in_2d_buffer_shape():
    """Two independent constexprs can parameterize a 2-D buffer shape and nested loops."""

    @T.jit(private=True)
    def matadd(
        A: T.Buffer((M, K), "int32"),
        B: T.Buffer((M, K), "int32"),
        C: T.Buffer((M, K), "int32"),
        *,
        M: T.constexpr,
        K: T.constexpr,
    ):
        for m in range(M):
            for k in range(K):
                C[m, k] = A[m, k] + B[m, k]

    @T.prim_func(private=True)
    def expected(
        A: T.Buffer((4, 8), "int32"),
        B: T.Buffer((4, 8), "int32"),
        C: T.Buffer((4, 8), "int32"),
    ):
        for m in range(4):
            for k in range(8):
                C[m, k] = A[m, k] + B[m, k]

    shape_args = {"M": 4, "K": 8}
    specialized = matadd.specialize(**shape_args)
    assert_structural_equal(specialized, expected, map_free_vars=True)
def test_constexpr_in_body_expression():
    """A constexpr used only inside the body (SCALE) is folded into the expression."""

    @T.jit(private=True)
    def scaled_copy(
        A: T.Buffer((N,), "int32"),
        B: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
        SCALE: T.constexpr,
    ):
        for i in range(N):
            B[i] = A[i] * SCALE

    @T.prim_func(private=True)
    def expected(
        A: T.Buffer((16,), "int32"),
        B: T.Buffer((16,), "int32"),
    ):
        for i in range(16):
            B[i] = A[i] * 3

    constexpr_values = {"N": 16, "SCALE": 3}
    specialized = scaled_copy.specialize(**constexpr_values)
    assert_structural_equal(specialized, expected, map_free_vars=True)
def test_specialize_cache_returns_same_instance():
    """Repeating ``specialize`` with identical arguments returns the cached object."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
    ):
        for i in range(N):
            A[i] = 0

    first, second = k.specialize(N=8), k.specialize(N=8)
    assert first is second
def test_specialize_different_args_produce_different_funcs():
    """Distinct constexpr values must yield distinct, structurally different PrimFuncs."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
    ):
        for i in range(N):
            A[i] = 0

    spec_8 = k.specialize(N=8)
    spec_16 = k.specialize(N=16)
    assert spec_8 is not spec_16
    # Identity alone would pass even if specialize() built a fresh object per
    # call without substituting N; require the bodies/shapes to actually differ.
    assert not tvm.ir.structural_equal(spec_8, spec_16, map_free_vars=True)
def test_specialize_missing_constexpr_raises():
    """Omitting a constexpr that has no default (SCALE) raises a TypeError mentioning 'missing'."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
        SCALE: T.constexpr,
    ):
        for i in range(N):
            A[i] = SCALE

    incomplete_args = {"N": 8}
    with pytest.raises(TypeError, match="missing"):
        k.specialize(**incomplete_args)
def test_specialize_extra_kwarg_raises():
    """Passing a keyword that is not a declared constexpr raises a TypeError mentioning 'unexpected'."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
    ):
        for i in range(N):
            A[i] = 0

    args_with_unknown = {"N": 8, "BOGUS": 42}
    with pytest.raises(TypeError, match="unexpected"):
        k.specialize(**args_with_unknown)
def test_jit_kernel_with_nested_inline_helper():
    """A ``@T.inline`` helper defined inside a jit kernel is expanded at each call site."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
    ):
        @T.inline
        def double(x):
            A[x] = A[x] * 2

        for i in range(N):
            double(i)

    @T.prim_func(private=True)
    def expected(
        A: T.Buffer((4,), "int32"),
    ):
        for i in range(4):
            A[i] = A[i] * 2

    # After inlining, no call to `double` should remain in the specialized body.
    specialized = k.specialize(N=4)
    assert_structural_equal(specialized, expected, map_free_vars=True)
def test_constexpr_default_value():
    """A constexpr with a default may be omitted, and an explicit value overrides it."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
        SCALE: T.constexpr = 7,
    ):
        for i in range(N):
            A[i] = SCALE

    @T.prim_func(private=True)
    def expected(
        A: T.Buffer((8,), "int32"),
    ):
        for i in range(8):
            A[i] = 7

    @T.prim_func(private=True)
    def expected_overridden(
        A: T.Buffer((8,), "int32"),
    ):
        for i in range(8):
            A[i] = 99

    # Omitted SCALE falls back to its default of 7.
    assert_structural_equal(k.specialize(N=8), expected, map_free_vars=True)

    # Override the default. Checking identity alone only proves a different
    # cache entry was used; also verify the override actually reaches the body.
    overridden = k.specialize(N=8, SCALE=99)
    assert k.specialize(N=8) is not overridden
    assert_structural_equal(overridden, expected_overridden, map_free_vars=True)
def test_specialize_returns_primfunc():
    """``specialize`` yields a ``tvm.tirx.PrimFunc`` whose params exclude constexpr args."""

    @T.jit(private=True)
    def k(
        A: T.Buffer((N,), "int32"),
        *,
        N: T.constexpr,
    ):
        for i in range(N):
            A[i] = 0

    prim_func = k.specialize(N=8)
    assert isinstance(prim_func, tvm.tirx.PrimFunc)
    # Only the runtime buffer A survives; the keyword-only constexpr N is stripped.
    runtime_params = prim_func.params
    assert len(runtime_params) == 1
if __name__ == "__main__":
    # pytest.main() returns an exit code instead of raising; propagate it so a
    # failing run is reported as a non-zero process exit status.
    raise SystemExit(pytest.main([__file__, "-v"]))