blob: 8930159481cb5d00dcc40599da039b052c3156c7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Codegen tests for VLA (vector length agnostic) extensions
"""
import re
import pytest
import tvm
from tvm import te
from tvm.script import tir as T
from tvm.target.codegen import llvm_version_major
@pytest.mark.skipif(
    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
@tvm.testing.parametrize_targets(
    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_codegen_vscale(target):
    """Check that a TIR ``vscale`` expression shows up as ``llvm.vscale.i32`` in the LLVM IR.

    Parameters
    ----------
    target
        One of the targets listed in the ``parametrize_targets`` decorator; it is
        wrapped in ``tvm.target.Target`` and used as the build context.
    """
    # Created outside the prim_func and referenced by name inside it.
    vscale = tvm.tir.vscale()

    @T.prim_func
    def main(A: T.Buffer((5,), "int32")):
        for i in range(5):
            A[i] = 2 * vscale

    with tvm.target.Target(target):
        build_mod = tvm.tir.build(main)

    # Ask for the "ll" format explicitly, like the other tests in this file.
    llvm = build_mod.inspect_source("ll")
    # The dots are escaped so that only the literal intrinsic name matches; an
    # unescaped "." would accept any character in those positions.
    assert re.search(r"llvm\.vscale\.i32", llvm), "No vscale in generated LLVM."
@pytest.mark.skipif(
    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
@tvm.testing.parametrize_targets(
    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_scalable_buffer_load_store(target):
    """Check that a ``4 * vscale``-lane ramp copy produces scalable loads and stores.

    The prim_func copies ``A[ramp]`` into ``B[ramp]`` where the ramp has
    ``4 * T.vscale()`` lanes; the generated LLVM IR must contain both a load and
    a store of ``<vscale x 4 x float>``.
    """

    # The TVMScript body is kept exactly as written.
    @T.prim_func
    def my_func(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128,), "float32")
        B = T.match_buffer(b, (128,), "float32")
        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]

    with tvm.target.Target(target):
        built = tvm.tir.build(my_func)
    llvm_ir = built.inspect_source("ll")

    scalable_load = re.search(r"load <vscale x 4 x float>", llvm_ir)
    assert scalable_load, "No scalable load in generated LLVM."

    scalable_store = re.search(r" store <vscale x 4 x float>", llvm_ir)
    assert scalable_store, "No scalable store in generated LLVM."
@pytest.mark.skipif(
    llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
)
@tvm.testing.parametrize_targets(
    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_scalable_broadcast(target):
    """Check that broadcasting a scalar over ``4 * vscale`` lanes stays scalable in LLVM IR.

    The prim_func stores ``T.broadcast(1, 4 * T.vscale())`` into a ramp-indexed
    slice of ``A``; the IR must contain the shufflevector/insertelement pattern
    on ``<vscale x 4 x float>`` as well as a scalable store.
    """

    # The TVMScript body is kept exactly as written.
    @T.prim_func
    def my_func(a: T.handle):
        A = T.match_buffer(a, (128,), "float32")
        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())

    with tvm.target.Target(target):
        built = tvm.tir.build(my_func)
    llvm_ir = built.inspect_source("ll")

    scalable_splat = re.search(
        r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>", llvm_ir
    )
    assert scalable_splat, "No scalable broadcast in generated LLVM."

    scalable_store = re.search(r" store <vscale x 4 x float>", llvm_ir)
    assert scalable_store, "No scalable store in generated LLVM."
# NOTE(review): this skip is unconditional even though its reason talks about
# "earlier versions of LLVM" -- confirm whether a version-gated skipif was intended.
@pytest.mark.skip(
    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
)
@tvm.testing.parametrize_targets(
    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_get_active_lane_mask(target):
    """Check that ``T.get_active_lane_mask`` appears as ``get.active.lane.mask`` in LLVM IR.

    The prim_func writes the mask for each ``4 * vscale``-wide slice of a
    30-element ``int1`` buffer.
    """

    # The TVMScript body is kept exactly as written.
    @T.prim_func
    def before(a: T.handle):
        A = T.match_buffer(a, (30,), "int1")
        for i in range(T.ceildiv(30, T.vscale() * 4)):
            A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30)

    with tvm.target.Target(target):
        built = tvm.tir.build(before)
    llvm_ir = built.inspect_source("ll")

    intrinsic = "get.active.lane.mask"
    assert intrinsic in llvm_ir
# NOTE(review): this skip is unconditional even though its reason talks about
# "earlier versions of LLVM" -- confirm whether a version-gated skipif was intended.
@pytest.mark.skip(
    reason="Vscale and get.active.lane.mask are not supported in earlier versions of LLVM",
)
@tvm.testing.parametrize_targets(
    "llvm -mtriple=aarch64-linux-gnu -mattr=+sve",
    "llvm -device=riscv_cpu -mtriple=riscv64-linux-gnu -mcpu=generic-rv64 -mattr=+64bit,+a,+c,+d,+f,+m,+v",
)
def test_predicated_scalable_buffer(target):
    """Check that a bounds-guarded scalable vectorized loop uses masked memory accesses.

    The inner loop is vectorized over ``4 * T.vscale()`` lanes and guarded by an
    ``< 14`` index check; the LLVM IR must mention ``get.active.lane.mask``,
    ``llvm.masked.load`` and ``llvm.masked.store``.
    """

    # The TVMScript body is kept exactly as written.
    @T.prim_func
    def before(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (16,), "float32")
        B = T.match_buffer(b, (16,), "float32")
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())):
            for i_1 in T.vectorized(4 * T.vscale()):
                if i_0 * 4 * T.vscale() + i_1 < 14:
                    B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0

    with tvm.target.Target(target):
        built = tvm.tir.build(before)
    llvm_ir = built.inspect_source("ll")

    # Checked in the same order as before: mask first, then load, then store.
    for intrinsic in ("get.active.lane.mask", "llvm.masked.load", "llvm.masked.store"):
        assert intrinsic in llvm_ir
# Call tvm.testing.main() when this file is executed directly as a script.
if __name__ == "__main__":
    tvm.testing.main()