# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-function-docstring,missing-module-docstring
import pytest
import tvm
import tvm.testing
from tvm import te, tir
from tvm.script import tir as T
from tvm.tir.schedule.testing import (
assert_structural_equal_ignore_global_symbol,
verify_trace_roundtrip,
)
from tvm.tir.tensor_intrin.arm_cpu import (
DP4A_S8S8S32_INTRIN,
DP4A_U8U8U32_INTRIN,
DP4A_U8S8S32_INTRIN,
DP4A_S8U8S32_INTRIN,
ARM_DOT_4x4_i8_NEON_INTRIN,
ARM_DOT_4x4_i8_SDOT_INTRIN,
)
from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN, AVX512_DOT_16x4_INTRIN
from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN, VDMPY_i16i16i32_INTRIN
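
# Tests for tir.Schedule.tensorize. Each case below defines a "description"
# PrimFunc (the computation pattern tensorize must match) and an "intrinsic"
# PrimFunc (the lowered body that replaces it), registers the pair, then
# checks that tensorizing a scheduled workload produces the expected module.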
# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
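# mma_desc describes a 16x16x16 matmul update; mma_intrin is its replacement,
# a single tvm_mma_sync call. elem_offset // 256 turns an element offset into
# a 16x16 (256-element) fragment index.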
@T.prim_func
def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=64, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=64, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=64, offset_factor=1)
with T.block("root"):
T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
for i, j, k in T.grid(16, 16, 16):
with T.block("update"):
vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
@T.prim_func
def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=64, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=64, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=64, offset_factor=1)
with T.block("root"):
T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
T.evaluate(
T.tvm_mma_sync(
C.data,
C.elem_offset // 256,
A.data,
A.elem_offset // 256,
B.data,
B.elem_offset // 256,
C.data,
C.elem_offset // 256,
dtype="handle",
)
)
@T.prim_func
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,))
B = T.match_buffer(b, (4,))
C = T.match_buffer(c, ())
with T.block("root"):
T.reads(C[()], A[0 : 4], B[0 : 4])
T.writes(C[()])
for i in range(0, 4):
with T.block("update"):
vi = T.axis.remap("R", [i])
C[()] = C[()] + A[vi] * B[vi]
@T.prim_func
def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), offset_factor=1)
B = T.match_buffer(b, (4,), offset_factor=1)
C = T.match_buffer(c, (), offset_factor=1)
with T.block("root"):
T.reads(C[()], A[0 : 4], B[0 : 4])
T.writes(C[()])
T.evaluate(
T.call_extern(
"vec4add",
C.data,
C.elem_offset,
A.data,
A.elem_offset,
B.data,
B.elem_offset,
dtype="int32",
)
)
@T.prim_func
def dot_product_intrin_annotated(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), offset_factor=1)
B = T.match_buffer(b, (4,), offset_factor=1)
C = T.match_buffer(c, (), offset_factor=1)
with T.block("root"):
T.reads(C[()], A[0 : 4], B[0 : 4])
T.writes(C[()])
T.block_attr({"test_annotation": True})
T.evaluate(
T.call_extern(
"vec4add",
C.data,
C.elem_offset,
A.data,
A.elem_offset,
B.data,
B.elem_offset,
dtype="int32",
)
)
@T.prim_func
def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 1), offset_factor=1)
B = T.match_buffer(b, (16, 1), offset_factor=1)
C = T.match_buffer(c, (16, 16), offset_factor=1)
with T.block("root"):
T.reads(
C[0 : 16, 0 : 16],
A[0 : 16, 0 : 1],
B[0 : 16, 0 : 1],
)
T.writes(C[0 : 16, 0 : 16])
for i, j in T.grid(16, 16):
with T.block("update"):
vii, vjj = T.axis.remap("SS", [i, j])
C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0]
@T.prim_func
def outer_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 1), offset_factor=1)
B = T.match_buffer(b, (16, 1), offset_factor=1)
C = T.match_buffer(c, (16, 16), offset_factor=1)
with T.block("root"):
T.reads(
C[0 : 16, 0 : 16],
A[0 : 16, 0 : 1],
B[0 : 16, 0 : 1],
)
T.writes(C[0 : 16, 0 : 16])
T.evaluate(
T.call_extern(
"outer_product",
C.data,
C.elem_offset,
A.data,
A.elem_offset,
B.data,
B.elem_offset,
dtype="int32",
)
)
@T.prim_func
def matmul(
A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32"),
) -> None:
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
def tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1)
B = T.match_buffer(b, [128, 128], elem_offset=0, align=64, offset_factor=1)
A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1)
for i_outer, j_outer in T.grid(8, 8):
for i_inner_init, j_inner_init in T.grid(16, 16):
with T.block("init"):
vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
C[vi_init, vj_init] = T.float32(0)
for k_outer in T.grid(8):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
T.reads(
[
C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
]
)
T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
A_elem_offset = T.int32()
B_elem_offset = T.int32()
C_elem_offset = T.int32()
A_sub = T.match_buffer(
A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
[16, 16],
elem_offset=A_elem_offset,
)
B_sub = T.match_buffer(
B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
[16, 16],
elem_offset=B_elem_offset,
)
C_sub = T.match_buffer(
C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
[16, 16],
elem_offset=C_elem_offset,
)
T.evaluate(
T.tvm_mma_sync(
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
A_sub.data,
T.floordiv(A_sub.elem_offset, 256),
B_sub.data,
T.floordiv(B_sub.elem_offset, 256),
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
dtype="handle",
)
)
@T.prim_func
def batch_matmul(
A: T.Buffer((16, 128, 128), "float32"),
B: T.Buffer((16, 128, 128), "float32"),
C: T.Buffer((16, 128, 128), "float32"),
) -> None:
for n, i, j in T.grid(16, 128, 128):
with T.block("init"):
vn, vi, vj = T.axis.remap("SSS", [n, i, j])
C[vn, vi, vj] = T.float32(0)
for n, i, j, k in T.grid(16, 128, 128, 128):
with T.block("update"):
vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
C[vn, vi, vj] = C[vn, vi, vj] + A[vn, vi, vk] * B[vn, vj, vk]
@T.prim_func
def tensorized_batch_matmul_mma(
A: T.Buffer((16, 128, 128), "float32"),
B: T.Buffer((16, 128, 128), "float32"),
C: T.Buffer((16, 128, 128), "float32"),
) -> None:
for n, i, j in T.grid(16, 128, 128):
with T.block("init"):
vn, vi, vj = T.axis.remap("SSS", [n, i, j])
T.reads()
T.writes(C[vn, vi, vj])
C[vn, vi, vj] = T.float32(0)
for n in range(0, 16):
for i, j, k in T.grid(8, 8, 8):
with T.block("update"):
vn, vi, vj, vk = T.axis.remap("SSSR", [n, i, j, k])
T.reads(
C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
)
T.writes(C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
A_elem_offset = T.int32()
B_elem_offset = T.int32()
C_elem_offset = T.int32()
A_sub = T.match_buffer(
A[vn : vn + 1, vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
(16, 16),
elem_offset=A_elem_offset,
)
B_sub = T.match_buffer(
B[vn : vn + 1, vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
(16, 16),
elem_offset=B_elem_offset,
)
C_sub = T.match_buffer(
C[vn : vn + 1, vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
(16, 16),
elem_offset=C_elem_offset,
)
T.evaluate(
T.tvm_mma_sync(
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
A_sub.data,
T.floordiv(A_sub.elem_offset, 256),
B_sub.data,
T.floordiv(B_sub.elem_offset, 256),
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
dtype="handle",
)
)
@T.prim_func
def tensorized_batch_matmul_dot_product(
A: T.Buffer((16, 128, 128), "float32"),
B: T.Buffer((16, 128, 128), "float32"),
C: T.Buffer((16, 128, 128), "float32"),
) -> None:
for n, i, j in T.grid(16, 128, 128):
with T.block("init"):
vn, vi, vj = T.axis.remap("SSS", [n, i, j])
T.reads()
T.writes(C[vn, vi, vj])
C[vn, vi, vj] = T.float32(0)
for n, i, j, k_0 in T.grid(16, 128, 128, 32):
with T.block("blockized_update"):
vn, vi, vj, vko = T.axis.remap("SSSR", [n, i, j, k_0])
T.reads(
C[vn, vi, vj], A[vn, vi, vko * 4 : vko * 4 + 4], B[vn, vj, vko * 4 : vko * 4 + 4]
)
T.writes(C[vn, vi, vj])
A_1 = T.match_buffer(
A[vn, vi, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1
)
B_1 = T.match_buffer(
B[vn, vj, vko * 4 : vko * 4 + 4], [4], dtype="float32", offset_factor=1
)
C_1 = T.match_buffer(C[vn, vi, vj], [], dtype="float32", offset_factor=1)
T.evaluate(
T.call_extern(
"vec4add",
C_1.data,
C_1.elem_offset,
A_1.data,
A_1.elem_offset,
B_1.data,
B_1.elem_offset,
dtype="int32",
)
)
@T.prim_func
def tensorized_batch_matmul_outer_product(
A: T.Buffer((16, 128, 128), "float32"),
B: T.Buffer((16, 128, 128), "float32"),
C: T.Buffer((16, 128, 128), "float32"),
) -> None:
for n, i, j in T.grid(16, 128, 128):
with T.block("init"):
vn, vi, vj = T.axis.remap("SSS", [n, i, j])
T.reads()
T.writes(C[vn, vi, vj])
C[vn, vi, vj] = T.float32(0)
for n, i_0, j_0, k in T.grid(16, 8, 8, 128):
with T.block("blockized_update"):
vn, vio, vjo, vk = T.axis.remap("SSSR", [n, i_0, j_0, k])
T.reads(
C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
A[vn, vio * 16 : vio * 16 + 16, vk],
B[vn, vjo * 16 : vjo * 16 + 16, vk],
)
T.writes(C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16])
            A_1 = T.match_buffer(
                A[vn, vio * 16 : vio * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1
            )
            B_1 = T.match_buffer(
                B[vn, vjo * 16 : vjo * 16 + 16, vk], [16, 1], dtype="float32", offset_factor=1
            )
            C_1 = T.match_buffer(
                C[vn, vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16],
                [16, 16],
                dtype="float32",
                offset_factor=1,
            )
            T.evaluate(
                T.call_extern(
                    "outer_product",
                    C_1.data,
                    C_1.elem_offset,
                    A_1.data,
                    A_1.elem_offset,
                    B_1.data,
                    B_1.elem_offset,
                    dtype="int32",
                )
            )
@T.prim_func
def annotated_mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), align=64, offset_factor=1)
B = T.match_buffer(b, (16, 16), align=64, offset_factor=1)
C = T.match_buffer(c, (16, 16), align=64, offset_factor=1)
with T.block("root"):
T.reads(C[0 : 16, 0 : 16], A[0 : 16, 0 : 16], B[0 : 16, 0 : 16])
T.writes(C[0 : 16, 0 : 16])
for i, j, k in T.grid(16, 16, 16):
with T.block("update"):
T.block_attr({"test_annotation": True})
vii, vjj, vkk = T.axis.remap("SSR", [i, j, k])
C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk]
@T.prim_func
def annotated_matmul(
A: T.Buffer((128, 128), "float32"),
B: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32"),
) -> None:
for i, j, k in T.grid(128, 128, 128):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
T.block_attr({"test_annotation": True})
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
C = T.match_buffer(c, [128, 128], elem_offset=0, align=64, offset_factor=1)
B = T.match_buffer(b, [128, 128], elem_offset=0, align=64, offset_factor=1)
A = T.match_buffer(a, [128, 128], elem_offset=0, align=64, offset_factor=1)
for i_outer, j_outer in T.grid(8, 8):
for i_inner_init, j_inner_init in T.grid(16, 16):
with T.block("init"):
vi_init = T.axis.S(128, ((i_outer * 16) + i_inner_init))
vj_init = T.axis.S(128, ((j_outer * 16) + j_inner_init))
T.block_attr({"test_annotation": True})
C[vi_init, vj_init] = T.float32(0)
for k_outer in T.grid(8):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
T.reads(
[
C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
]
)
T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16])
A_elem_offset = T.int32()
B_elem_offset = T.int32()
C_elem_offset = T.int32()
A_sub = T.match_buffer(
A[vi * 16 : vi * 16 + 16, vk * 16 : vk * 16 + 16],
[16, 16],
elem_offset=A_elem_offset,
)
B_sub = T.match_buffer(
B[vj * 16 : vj * 16 + 16, vk * 16 : vk * 16 + 16],
[16, 16],
elem_offset=B_elem_offset,
)
C_sub = T.match_buffer(
C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16],
[16, 16],
elem_offset=C_elem_offset,
)
T.evaluate(
T.tvm_mma_sync(
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
A_sub.data,
T.floordiv(A_sub.elem_offset, 256),
B_sub.data,
T.floordiv(B_sub.elem_offset, 256),
C_sub.data,
T.floordiv(C_sub.elem_offset, 256),
dtype="handle",
)
)
# fmt: on
# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
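
# Register each description/implementation pair under a name that
# sch.tensorize(loop, name) can look up.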
tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
tir.TensorIntrin.register("test_annotated_mma_intrin", annotated_mma_desc, mma_intrin)
tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin)
tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin)
tir.TensorIntrin.register("test_dot_product_intrin_annotated", dot_product_desc, dot_product_intrin_annotated)
def test_tensorize_matmul():
func = matmul
# schedule
s = tir.Schedule(func, debug_mask="all")
update = s.get_block("update")
i, j, k = s.get_loops(update)
io, ii = s.split(i, factors=[None, 16])
jo, ji = s.split(j, factors=[None, 16])
ko, ki = s.split(k, factors=[None, 16])
s.reorder(io, jo, ko, ii, ji, ki)
s.decompose_reduction(update, ko)
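    # With the init block split out, the remaining inner 16x16x16 update nest
    # is a pure accumulation that matches mma_desc exactly.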
s.tensorize(ii, "test_mma_intrin")
assert_structural_equal_ignore_global_symbol(tensorized_matmul, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
def test_tensorize_batch_matmul():
func = batch_matmul
s = tir.Schedule(func, debug_mask="all")
update = s.get_block("update")
_, i, j, k = s.get_loops(update)
io, ii = s.split(i, factors=[None, 16])
jo, ji = s.split(j, factors=[None, 16])
ko, ki = s.split(k, factors=[None, 16])
s.reorder(io, jo, ko, ii, ji, ki)
s.tensorize(ii, "test_mma_intrin")
assert_structural_equal_ignore_global_symbol(tensorized_batch_matmul_mma, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=batch_matmul)
def test_tensorize_dot_product():
func = batch_matmul
s = tir.Schedule(func, debug_mask="all")
C = s.get_block("update")
_, _, _, k = s.get_loops(C)
_, ki = s.split(k, factors=[None, 4])
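    # Only the innermost length-4 reduction loop needs to match
    # dot_product_desc; the surrounding loops stay untouched.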
s.tensorize(ki, "test_dot_product_intrin")
assert_structural_equal_ignore_global_symbol(tensorized_batch_matmul_dot_product, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
def test_tensorize_outer_product():
func = batch_matmul
s = tir.Schedule(func, debug_mask="all")
C = s.get_block("update")
_, i, j, k = s.get_loops(C)
io, ii = s.split(i, factors=[None, 16])
jo, ji = s.split(j, factors=[None, 16])
s.reorder(io, jo, k, ii, ji)
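    # outer_product_desc covers just the 16x16 spatial update, so the
    # reduction loop k stays outside the tensorized region.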
s.tensorize(ii, "test_outer_product_intrin")
assert_structural_equal_ignore_global_symbol(tensorized_batch_matmul_outer_product, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
def test_tensorize_with_annotation():
func = annotated_matmul
s = tir.Schedule(func, debug_mask="all")
update = s.get_block("update")
i, j, k = s.get_loops(update)
io, ii = s.split(i, factors=[None, 16])
jo, ji = s.split(j, factors=[None, 16])
ko, ki = s.split(k, factors=[None, 16])
s.reorder(io, jo, ko, ii, ji, ki)
s.decompose_reduction(update, ko)
s.tensorize(ii, "test_annotated_mma_intrin")
assert_structural_equal_ignore_global_symbol(annotated_tensorized_matmul, s.mod["main"])
verify_trace_roundtrip(sch=s, mod=func)
def test_tensorize_intrinsic_with_annotation():
func = matmul
s = tir.Schedule(func, debug_mask="all")
update = s.get_block("update")
_, _, k = s.get_loops(update)
ko, ki = s.split(k, factors=[None, 4])
s.decompose_reduction(update, ko)
s.tensorize(ki, "test_dot_product_intrin_annotated")
b = s.get(s.get_block("update_update_o"))
assert b.annotations["test_annotation"] == T.bool(True)
verify_trace_roundtrip(sch=s, mod=func)
def get_matmul_packed(m, n, k, lhs_dtype, rhs_dtype="int8"):
    X = te.placeholder((m, k), name="X", dtype=lhs_dtype)
W = te.placeholder((n, k), name="W", dtype=rhs_dtype)
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
(m, n),
lambda i, j: te.sum(
X[i, ak].astype("int32") * W[j, ak].astype("int32"),
axis=ak,
),
name="compute",
)
return te.create_prim_func([X, W, matmul])
def tensorize_16x4_test(intrin=VNNI_DOT_16x4_INTRIN):
m, n, k = 128, 128, 128
func = get_matmul_packed(m, n, k, "uint8")
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
sch.transform_layout(block, "W", lambda i, j: [i//16, j//4, i%16, j%4])
_, j, k = sch.get_loops(block)
_, ji = sch.split(j, factors=[None, 16])
ko, ki = sch.split(k, factors=[None, 4])
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
sch.tensorize(ji, intrin)
verify_trace_roundtrip(sch=sch, mod=func)
def test_tensorize_vnni():
tensorize_16x4_test()
def test_tensorize_avx512():
tensorize_16x4_test(AVX512_DOT_16x4_INTRIN)
def test_tensorize_arm_dot():
m, n, k = 128, 128, 128
func = get_matmul_packed(m, n, k, "int8")
for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
sch.transform_layout(block, "W", lambda i, j: [i//4, j//4, i%4, j%4])
_, j, k = sch.get_loops(block)
_, ji = sch.split(j, factors=[None, 4])
ko, ki = sch.split(k, factors=[None, 4])
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
sch.tensorize(ji, intrin)
verify_trace_roundtrip(sch=sch, mod=func)
def test_tensorize_vrmpy():
m, n, k = 128, 128, 128
func = get_matmul_packed(m, n, k, "uint8", "uint8")
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
sch.transform_layout(block, "W", lambda i, j: [i//32, j//4, i%32, j%4])
_, j, k = sch.get_loops(block)
_, ji = sch.split(j, factors=[None, 32])
ko, ki = sch.split(k, factors=[None, 4])
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)
verify_trace_roundtrip(sch=sch, mod=func)
def test_tensorize_vdmpy():
m, n, k = 128, 128, 128
func = get_matmul_packed(m, n, k, "int16", "int16")
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
sch.transform_layout(block, "W", lambda i, j: [i//32, j//2, i%32, j%2])
_, j, k = sch.get_loops(block)
_, ji = sch.split(j, factors=[None, 32])
ko, ki = sch.split(k, factors=[None, 2])
sch.reorder(ko, ji, ki)
sch.decompose_reduction(block, ko)
sch.tensorize(ji, VDMPY_i16i16i32_INTRIN)
verify_trace_roundtrip(sch=sch, mod=func)
def test_tensorize_dp4a():
# pylint: disable=too-many-locals
def _test_intrin(dtype_a, dtype_b, dtype_c, intrin):
m, n, k = 128, 128, 128
X = te.placeholder((m, k), name="X", dtype=dtype_a)
W = te.placeholder((n, k), name="W", dtype=dtype_b)
ak = te.reduce_axis((0, k), name="k")
matmul = te.compute(
(m, n),
lambda i, j: te.sum(
X[i, ak].astype(dtype_c) * W[j, ak].astype(dtype_c),
axis=ak,
),
name="compute",
)
func = te.create_prim_func([X, W, matmul])
sch = tir.Schedule(func, debug_mask="all")
block = sch.get_block("compute")
i, j, k = sch.get_loops(block)
by, ty, yi = sch.split(i, factors=sch.sample_perfect_tile(i, n=3))
bx, tx, xi = sch.split(j, factors=sch.sample_perfect_tile(j, n=3))
        ko, ki = sch.split(k, factors=[None, 4])
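        # The innermost length-4 reduction is what maps onto a single 4-way
        # dot-product instruction (dp4a / sdot4).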
ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
sch.reorder(by, bx, ty, tx, yi, xi)
CC = sch.cache_write(block, 0, "local")
sch.reverse_compute_at(CC, tx)
def fetch_to_shared(block, idx):
block_read = sch.cache_read(block, idx, "shared")
sch.compute_at(block_read, ko, True)
return block_read
fetch_to_shared(block, 0)
fetch_to_shared(block, 1)
sch.decompose_reduction(block, ko)
sch.tensorize(ki, intrin)
verify_trace_roundtrip(sch=sch, mod=func)
for args in [
("int8", "int8", "int32", AMDGPU_SDOT4_INTRIN),
("int8", "int8", "int32", DP4A_S8S8S32_INTRIN),
("int8", "uint8", "int32", DP4A_S8U8S32_INTRIN),
("uint8", "int8", "int32", DP4A_U8S8S32_INTRIN),
("uint8", "uint8", "uint32", DP4A_U8U8U32_INTRIN),
]:
_test_intrin(*args)
def test_tensor_intrin_look_up():
    intrin_name = "non_existent_intrin"
assert tir.TensorIntrin.get(intrin_name, allow_missing=True) is None
with pytest.raises(ValueError):
tir.TensorIntrin.get(intrin_name)
def test_tensorize_matmul_mixed_dtype():
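    # Shapes and loop variables use int64 while the registered mma intrinsic
    # was written with int32 indices; tensorize must match across the mixed
    # index dtypes.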
# fmt: off
@T.prim_func
def matmul_int64_shape(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32")
) -> None:
for i_0, j_0 in T.grid(T.int64(8), T.int64(8)):
for i_1_init, j_1_init in T.grid(T.int64(16), T.int64(16)):
with T.block("init"):
vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1_init)
vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1_init)
C[vi, vj] = T.float32(0)
for k_0, i_1, j_1, k_1 in T.grid(T.int64(8), T.int64(16), T.int64(16), T.int64(16)):
with T.block("update"):
vi = T.axis.spatial(T.int64(128), i_0 * T.int64(16) + i_1)
vj = T.axis.spatial(T.int64(128), j_0 * T.int64(16) + j_1)
vk = T.axis.reduce(T.int64(128), k_0 * T.int64(16) + k_1)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
@T.prim_func
def tensorized_matmul_int64_shape(
A: T.Buffer((T.int64(128), T.int64(128)), "float32"),
B: T.Buffer((T.int64(128), T.int64(128)), "float32"),
C: T.Buffer((T.int64(128), T.int64(128)), "float32")
) -> None:
for i_outer, j_outer in T.grid(T.int64(8), T.int64(8)):
for i_inner_init, j_inner_init in T.grid(T.int64(16), T.int64(16)):
with T.block("init"):
vi = T.axis.spatial(T.int64(128), i_outer * T.int64(16) + i_inner_init)
vj = T.axis.spatial(T.int64(128), j_outer * T.int64(16) + j_inner_init)
C[vi, vj] = T.float32(0)
for k_outer in T.grid(T.int64(8)):
with T.block("update"):
vi, vj, vk = T.axis.remap("SSR", [i_outer, j_outer, k_outer])
T.reads(
[
C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)],
A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
]
)
T.writes(C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)])
A_elem_offset = T.int64()
B_elem_offset = T.int64()
C_elem_offset = T.int64()
A_sub = T.match_buffer(
A[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
[T.int64(16), T.int64(16)],
elem_offset=A_elem_offset,
)
B_sub = T.match_buffer(
B[vj * T.int64(16) : vj * T.int64(16) + T.int64(16), vk * T.int64(16) : vk * T.int64(16) + T.int64(16)],
[T.int64(16), T.int64(16)],
elem_offset=B_elem_offset,
)
C_sub = T.match_buffer(
C[vi * T.int64(16) : vi * T.int64(16) + T.int64(16), vj * T.int64(16) : vj * T.int64(16) + T.int64(16)],
[T.int64(16), T.int64(16)],
elem_offset=C_elem_offset,
)
T.evaluate(
T.tvm_mma_sync(
C_sub.data,
T.floordiv(C_sub.elem_offset, T.int64(256)),
A_sub.data,
T.floordiv(A_sub.elem_offset, T.int64(256)),
B_sub.data,
T.floordiv(B_sub.elem_offset, T.int64(256)),
C_sub.data,
T.floordiv(C_sub.elem_offset, T.int64(256)),
dtype="handle",
)
)
# fmt: on
s = tir.Schedule(matmul_int64_shape, debug_mask="all")
update = s.get_block("update")
ii = s.get_loops(update)[-3]
s.tensorize(ii, "test_mma_intrin")
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape)
verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)
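
# Decode helper: extract an nbit-wide field from an int<storage_nbit> word,
# then sign-extend it by shifting left and arithmetically right by
# (32 - nbit). E.g. for nbit=4, a packed nibble of 0xF decodes to -1.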
def _tir_packed_int_to_int_to_float(storage_nbit: int):
storage_dtype = "int" + str(storage_nbit)
def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
assert val.dtype == storage_dtype
mask = tir.const((1 << nbit) - 1, "int32")
unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
return f_convert
@T.prim_func
def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None:
    Compressed = T.match_buffer(compressed, [1], dtype="int32", scope="local")
    Decompressed = T.match_buffer(decompressed, [8], dtype="float16", scope="local")
with T.block("root"):
T.reads(Compressed[0:1])
T.writes(Decompressed[0:8])
for i in T.grid(8):
with T.block("decode"):
vi = T.axis.remap("S", [i])
Decompressed[vi] = _tir_packed_int_to_int_to_float(32)(
4,
Compressed[vi // 8],
vi % 8,
dtype="float16",
)
@T.prim_func
def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None:
    Compressed = T.match_buffer(compressed, [1], dtype="int32", scope="local")
    Decompressed = T.match_buffer(decompressed, [8], dtype="float16", scope="local")
with T.block("root"):
T.reads(Compressed[0:1])
T.writes(Decompressed[0:8])
T.call_extern(
"handle",
"test_decode_i4s_to_f16",
Compressed.data,
Decompressed.data,
8,
)
tir.TensorIntrin.register("test_decode_i4s_to_f16_intrin", decode_i4s_to_f16_desc, decode_i4s_to_f16_impl)
def test_tensorize_arith_simplification():
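    # The scheduled indexing below yields expressions such as v1 % 8 * 4 that
    # only match decode_i4s_to_f16_desc after arithmetic simplification.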
# fmt: off
@T.prim_func
def decode_i4s_to_int32_to_f16():
B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
for ax1_0 in range(32):
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 8):
with T.block("B_decode_local"):
v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
T.reads(B_local[v0, v1 // 8])
T.writes(B_decode_local[v0, v1])
B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))
@T.prim_func
def tensorized_decode_i4s_to_int32_to_f16():
B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
for ax1_0 in range(32):
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
for ax0 in range(1):
with T.block("B_decode_local_o"):
v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1)
T.reads(B_local[v0_o, v1_o])
T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8])
Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local")
Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local")
T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8)
    # fmt: on
    s = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all")
update = s.get_block("B_decode_local")
ii = s.get_loops(update)[-1]
s.tensorize(ii, "test_decode_i4s_to_f16_intrin")
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_decode_i4s_to_int32_to_f16)
verify_trace_roundtrip(sch=s, mod=decode_i4s_to_int32_to_f16)
if __name__ == "__main__":
tvm.testing.main()