blob: f5c25ecded6dffdc323d2e245c2857c9e4d657db [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Codegen tests for AArch64
"""
import re
import pytest
import tvm
from tvm.script import tir as T, ir as I
from tvm.target.codegen import llvm_version_major
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_mul(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] * B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"mul\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_add(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] + B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"add\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_sub(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] - B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"sub\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_muladd(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_D: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
D = T.match_buffer(var_D, (m,), dtype=dtype)
for i in range(m):
with T.sblock("D"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i], C[v_i])
T.writes(D[v_i])
D[v_i] = A[v_i] * B[v_i] + C[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"mad|mla\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_max(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = T.max(A[v_i], B[v_i])
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
compare = re.findall(
r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly)
max_instr = re.findall(
r"max\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert (len(compare) > 1 and len(select) == len(compare)) or len(max_instr) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_min(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = T.min(A[v_i], B[v_i])
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
compare = re.findall(
r"cmgt\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
select = re.findall("sel\tz[0-9].[shdb], p[0-9], z[0-9].[shdb], z[0-9].[shdb]", assembly)
min_instr = re.findall(
r"min\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert (len(compare) > 1 and len(select) == len(compare)) or len(min_instr) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_div(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = tvm.tir.div(A[v_i], B[v_i])
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"div\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) >= 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"]
)
def test_mod(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = T.floormod(A[v_i], B[v_i])
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"mls\tz[0-9].[shdb],( p[0-9]/[m],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 0
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_eq(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), "bool")
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] == B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"cm(p)?eq\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype",
["float", "float16", "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"],
)
def test_neq(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), "bool")
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] != B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"cm(p)?(gt|ne)\tp[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"]
)
def test_or(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] | B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"orr\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"]
)
def test_and(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i], B[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] & B[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"and\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.parametrize(
"dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"]
)
def test_not(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i])
T.writes(C[v_i])
C[v_i] = ~A[v_i]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
matches = re.findall(
r"eor\tz[0-9].[shdb],( p[0-9]/[zm],)? z[0-9].[shdb], z[0-9].[shdb]", assembly
)
assert len(loads) > 1
assert len(matches) > 1
@pytest.mark.skipif(
llvm_version_major() < 15, reason="Test requires an LLVM version of at least 15 to target SVE"
)
@pytest.mark.xfail(
reason="Awaiting llvm support for gathered loads",
strict=True,
)
@pytest.mark.parametrize(
"dtype", ["uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"]
)
def test_memcpy(dtype):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,), dtype=dtype)
B = T.match_buffer(var_B, (m,), "int32")
C = T.match_buffer(var_C, (m,), dtype=dtype)
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(B[v_i], A[B[v_i]])
T.writes(C[v_i])
C[v_i] = A[B[v_i]]
with tvm.target.Target(target):
f = tvm.tir.build(Module)
assembly = f.inspect_source("asm")
loads = re.findall("ld1[whdb] { z", assembly)
assert len(loads) > 0
@pytest.mark.skipif(
llvm_version_major() < 13,
reason="Function attribute vscale_range() is not supported in earlier versions of LLVM",
)
@pytest.mark.parametrize(
"mattr,expect_attr",
[
("+neon", False),
("+sve", True),
("+v9a", True),
("+sme", True),
],
)
def test_vscale_range_function_attribute(mattr, expect_attr):
target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": [mattr]}
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_C: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int32()
A = T.match_buffer(var_A, (m,))
C = T.match_buffer(var_C, (m,))
for i in range(m):
with T.sblock("C"):
v_i = T.axis.spatial(m, i)
T.reads(A[v_i])
T.writes(C[v_i])
C[v_i] = A[v_i] + T.float32(1)
with tvm.target.Target(target):
f = tvm.tir.build(Module)
# Check if the vscale_range() attribute exists
ll = f.inspect_source("ll")
attr = re.findall(rf".*vscale_range\(\d+,\d+\)*.", ll)
if expect_attr:
assert (
len(attr) > 0
), f"Function attribute vscale_range() was not found in generated LLVM IR"
else:
assert (
len(attr) == 0
), f"Unexpected function attribute vscale_range() was found in generated LLVM IR"
if __name__ == "__main__":
tvm.testing.main()