blob: ecde549b4afa8ab17cdbbe2b64ec7b09669f0012 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement
import sys
from typing import Any
import pytest
import tvm.testing
from tvm.script import from_source
from tvm.script import tir as T
from tvm.tir.schedule.testing import assert_structural_equal_ignore_global_symbol
@T.prim_func
def transformed_matmul_no_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i0, i1])
vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner)
T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
T.writes([C[vi, vj], A[vi, vk]])
with T.init():
C[vi, vj] = 0.0
A[vi, vk] = A[vi, vk] + B[vj, vk]
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
@T.prim_func
def transformed_matmul_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
B = T.match_buffer(b, [128, 128])
C = T.match_buffer(c, [128, 128])
for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
with T.block("update"):
vi, vj = T.axis.remap("SS", [i0, i1])
vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner)
T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
T.writes(C[vi, vj], A[vi, vk])
with T.init():
C[vi, vj] = 0.0
A[vi, vk] = A[vi, vk] + B[vj, vk]
C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
def test_reads_writes_syntax_sugar():
assert_structural_equal_ignore_global_symbol(
transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar
)
@T.prim_func
def loop_no_syntax_sugar(a: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
for i in T.serial(0, 128):
for j in T.parallel(0, 128):
for k in T.vectorized(0, 128):
for x in T.unroll(0, 128):
for y in T.thread_binding(0, 128, thread="threadIdx.x"):
for z in T.thread_binding(0, 128, thread="threadIdx.x"):
A[i, j, k, x] = A[i, j, k, x] * 2.0
@T.prim_func
def loop_syntax_sugar(a: T.handle) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
for i in T.serial(128):
for j in T.parallel(128):
for k in T.vectorized(128):
for x in T.unroll(128):
for y in T.thread_binding(128, "threadIdx.x"):
for z in T.thread_binding(128, thread="threadIdx.x"):
A[i, j, k, x] = A[i, j, k, x] * 2.0
def test_loop_syntax_sugar():
assert_structural_equal_ignore_global_symbol(loop_no_syntax_sugar, loop_syntax_sugar)
# match buffer - use kwargs
@T.prim_func
def elementwise_handle(
a: T.handle,
b: T.handle,
) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
# match buffer - use buffer with kwargs
@T.prim_func
def elementwise_buffer_kwargs(
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32"),
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
# match buffer - use buffer without kwargs
@T.prim_func
def elementwise_buffer_no_kwargs(
a: T.Buffer((128, 128, 128, 128), "float32"),
b: T.Buffer((128, 128, 128, 128), "float32"),
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0
def test_match_buffer_syntax_sugar():
# with kwargs
assert_structural_equal_ignore_global_symbol(elementwise_handle, elementwise_buffer_kwargs)
# without kwargs
assert_structural_equal_ignore_global_symbol(elementwise_handle, elementwise_buffer_no_kwargs)
def test_match_buffer_1d():
@T.prim_func
def func_no_sugar(a: T.handle):
A = T.match_buffer(a, shape=(16,))
for i in T.serial(16):
A[i] = 0.0
@T.prim_func
def func_with_sugar(A: T.Buffer(16, "float32")):
for i in T.serial(16):
A[i] = 0.0
assert_structural_equal_ignore_global_symbol(func_no_sugar, func_with_sugar)
# dynamic shape gemm
@T.prim_func
def gemm_dyn_shape(a: T.handle, b: T.handle, c: T.handle):
N = T.int32()
M = T.int32()
K = T.int32()
A = T.match_buffer(a, (N, K), "float32")
B = T.match_buffer(b, (K, M), "float32")
C = T.match_buffer(c, (N, M), "float32")
for i, j, k in T.grid(N, M, K):
with T.block("gemm"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
def test_dynamic_shape_gemm():
gemm_dyn_shape_roundtrip = from_source(gemm_dyn_shape.script())
assert_structural_equal_ignore_global_symbol(gemm_dyn_shape, gemm_dyn_shape_roundtrip)
@T.prim_func
def match_buffer_int64(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (T.int64(128), T.int64(128)), dtype="float32")
B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32")
C = T.match_buffer(c, (T.int64(128), T.int64(128)), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
@T.prim_func
def match_buffer_int64_after_roundtrip(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32"),
) -> None:
B = T.alloc_buffer((T.int64(128), T.int64(128)), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
def test_match_buffer_int64():
original = match_buffer_int64
after_roundtrip = match_buffer_int64_after_roundtrip
assert_structural_equal_ignore_global_symbol(original, after_roundtrip, True)
def test_match_buffer_region_has_implicit_shape_dtype():
@T.prim_func
def explicit_shape_dtype(A: T.Buffer((16, 64), "int32")):
with T.block():
B = T.match_buffer(A[8:16, 32:64], shape=(8, 32), dtype="int32")
T.evaluate(0)
@T.prim_func
def implicit_shape_dtype(A: T.Buffer((16, 64), "int32")):
with T.block():
B = T.match_buffer(A[8:16, 32:64])
T.evaluate(0)
assert_structural_equal_ignore_global_symbol(explicit_shape_dtype, implicit_shape_dtype)
def test_match_buffer_input_requires_shape_arg():
with pytest.raises(tvm.error.DiagnosticError):
@T.prim_func
def func(a: T.handle):
A = T.match_buffer(a, dtype="int32")
T.evaluate(0)
def test_letstmt_bufferload_without_type_annotation():
# Variable assignment of PrimExpr types uses the dtype of the
# PrimExpr to determine the variable's dtype. Parsing of
# buf[indices] is done by generating a BufferSlice object, which
# handles both store and load cases. BufferSlice is not a
# PrimExpr, and implements BufferSlice.dtype explicitly.
# Failure occurred during parsing of the tvmscript.
@T.prim_func
def func_without_type_annotation(A: T.Buffer((1,), "int32")):
x = A[0]
T.evaluate(x)
def test_letstmt_bind_with_constant():
@T.prim_func
def constant_binds():
x = T.meta_var(1)
y = T.meta_var(42.0)
T.evaluate(T.cast(x, "float32") + y)
@T.prim_func
def constant_binds_wrapped():
x = T.meta_var(T.int32(1))
y = T.meta_var(T.float32(42.0))
T.evaluate(T.cast(x, "float32") + y)
assert_structural_equal_ignore_global_symbol(constant_binds, constant_binds_wrapped)
def test_func_call():
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
thread_id = (i % 8) * 4 + (j % 8) // 2
return T.meta_var((thread_id, (j // 8) * 4 + (i // 8) * 2 + (j % 2)))
@T.prim_func
def mma_sync_m16n16k16_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=64, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=64, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=16, scope="warp")
with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)
T.reads(
C[thread_id_C, local_id_C],
A[thread_id_A, local_id_A],
B[thread_id_B, local_id_B],
)
T.writes(C[thread_id_C, local_id_C])
C[thread_id_C, local_id_C] += (
A[thread_id_A, local_id_A] * B[thread_id_B, local_id_B]
)
@T.prim_func
def mma_sync_m16n16k16_desc_manual(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (32, 8), "float16", align=64, offset_factor=16, scope="warp")
B = T.match_buffer(b, (32, 8), "float16", align=64, offset_factor=16, scope="warp")
C = T.match_buffer(c, (32, 8), "float16", align=64, offset_factor=16, scope="warp")
with T.block("root"):
T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
T.writes(C[0:32, 0:8])
for i, j, k in T.grid(16, 16, 16):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
T.reads(
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2],
A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2],
B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2],
)
T.writes(C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2])
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2] = (
C[i % 8 * 4 + j % 8 // 2, j // 8 * 4 + i // 8 * 2 + j % 2]
+ A[i % 8 * 4 + k % 8 // 2, k // 8 * 4 + i // 8 * 2 + k % 2]
* B[k % 8 * 4 + j % 8 // 2, j // 8 * 4 + k // 8 * 2 + j % 2]
)
assert_structural_equal_ignore_global_symbol(
mma_sync_m16n16k16_desc, mma_sync_m16n16k16_desc_manual
)
# The following is an example of an error message from calling an invalid function
# error: Error occurred when invoking the function sqrt:
# loop of ufunc does not support argument 0 of type Var which has no callable sqrt method
# --> test_tvmscript_syntax_sugar.py:334:19
# |
# 334 | ind = sqrt(i)
# | ^^^^^^^
# note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
# Uncomment to see the error above.
# def sqrt(x):
# import numpy as np
# return np.sqrt(x)
# @T.prim_func
# def loop(a: T.handle) -> None:
# A = T.match_buffer(a, (128,))
# for i in T.serial(128):
# ind = sqrt(i)
# A[i] = A[ind]
def test_int64_loop():
@T.prim_func
def int64_grid(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
) -> None:
for i, j in T.grid(T.int64(128), T.int64(128)):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + 1.0
@T.prim_func
def int64_grid_expanded(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
) -> None:
for i in range(T.int64(0), T.int64(128)):
for j in range(T.int64(0), T.int64(128)):
with T.block("C"):
vi = T.axis.spatial(T.int64(128), i)
vj = T.axis.spatial(T.int64(128), j)
B[vi, vj] = A[vi, vj] + 1.0
assert_structural_equal_ignore_global_symbol(int64_grid, int64_grid_expanded)
def test_implicit_evaluate_assume():
@T.prim_func
def explicit(A: T.Buffer(1, "int32")):
T.evaluate(T.assume(A[0] == 5))
A[0] = 10
@T.prim_func
def implicit(A: T.Buffer(1, "int32")):
T.assume(A[0] == 5)
A[0] = 10
assert_structural_equal_ignore_global_symbol(implicit, explicit)
def test_implicit_evaluate_call_extern():
@T.prim_func
def explicit(A: T.Buffer(1, "int32")):
T.evaluate(T.call_extern("extern_func", A.data, dtype="int32"))
@T.prim_func
def implicit(A: T.Buffer(1, "int32")):
T.call_extern("extern_func", A.data, dtype="int32")
assert_structural_equal_ignore_global_symbol(implicit, explicit)
def test_preserve_trivial_let_binding():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
T.LetStmt(i, var=j)
T.evaluate(j)
@T.prim_func
def implicit(i: T.int32):
j = i
T.evaluate(j)
assert_structural_equal_ignore_global_symbol(implicit, explicit)
def test_preserve_trivial_let_binding_of_value():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
T.LetStmt(42, var=j)
T.evaluate(j)
@T.prim_func
def implicit(i: T.int32):
j = 42
T.evaluate(j)
assert_structural_equal_ignore_global_symbol(implicit, explicit)
def test_preserve_parameter_name():
@T.prim_func
def func(i: T.int32):
j = i
T.evaluate(j)
param_name = func.params[0].name
assert param_name == "i"
def test_preserve_variable_name():
"""Use variable name when generating tir::LetStmt"""
@T.prim_func
def func():
for i in T.serial(16):
j = i // 4
T.evaluate(j)
var_name = func.body.body.var.name
assert var_name == "j"
def test_boolean_constant():
"""Python booleans should become T.Bool objects"""
@T.prim_func
def explicit():
T.evaluate(T.bool(True))
@T.prim_func
def implicit():
T.evaluate(True)
assert_structural_equal_ignore_global_symbol(implicit, explicit)
def test_foldable_boolean_in_assert():
"""Foldable booleans T.Bool objects
The condition of an assert statement should be a boolean
expression. Previously, this test failed because the FFI does not
distinguish between integer primitives and boolean primitives.
"""
@T.prim_func
def explicit():
assert T.bool(False), "Message"
T.evaluate(0)
@T.prim_func
def implicit():
assert 0 == 1, "Message"
T.evaluate(0)
assert_structural_equal_ignore_global_symbol(implicit, explicit)
if __name__ == "__main__":
tvm.testing.main()