# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
# ruff: noqa: E501, F401, F841

import re

import pytest

import tvm.testing
from tvm import ir, tirx
from tvm.ir import Range
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tirx as T


def _assert_print(obj, expected):
    """Assert that ``obj`` prints as ``expected`` in verbose TVMScript.

    ``obj.script(verbose_expr=True)`` is compared against ``expected``
    after stripping leading/trailing whitespace from both, so expected
    strings may freely begin with a newline or end with indentation.

    Parameters
    ----------
    obj
        Any object exposing a ``script(verbose_expr=...)`` method
        (IR nodes, PrimFuncs, types, ...).
    expected : str
        The expected printed script.

    Raises
    ------
    AssertionError
        If the stripped texts differ. The message contains both full
        texts so that multi-line mismatches are readable.
    """
    actual = obj.script(verbose_expr=True).strip()
    expected = expected.strip()
    assert actual == expected, (
        "Printed script does not match the expected output.\n"
        f"--- expected ---\n{expected}\n"
        f"--- actual ---\n{actual}"
    )


def test_prim_func():
    """Buffer-mapped handle params are printed as ``T.Buffer`` annotations."""
    handle_a = tirx.Var("a", "handle")
    handle_b = tirx.Var("b", "handle")
    buf_a = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A")
    buf_b = tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B")
    prim_func = tirx.PrimFunc(
        params=[handle_a, handle_b],
        ret_type=None,
        buffer_map={handle_a: buf_a, handle_b: buf_b},
        body=tirx.Evaluate(0),
    )
    expected = """
# from tvm.script import tirx as T

@T.prim_func
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
    T.evaluate(0)"""
    _assert_print(prim_func.with_attr("global_symbol", "main"), expected=expected)


def test_prim_func_no_sugar_inlined_buffer():
    """A handle param that is also referenced in the body is not sugared.

    ``a`` appears in ``T.evaluate(a)``. The expected output keeps it as a
    ``T.handle`` parameter with an explicit ``T.match_buffer``, while ``b``
    is still printed as a ``T.Buffer`` annotation.
    """
    handle_a = tirx.Var("a", "handle")
    handle_b = tirx.Var("b", "handle")
    buf_a = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A")
    buf_b = tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B")
    prim_func = tirx.PrimFunc(
        params=[handle_a, handle_b],
        ret_type=None,
        buffer_map={handle_a: buf_a, handle_b: buf_b},
        body=tirx.Evaluate(handle_a),
    ).with_attr("global_symbol", "main")
    expected = """
# from tvm.script import tirx as T

@T.prim_func
def main(a: T.handle, B: T.Buffer((256, 256), "float32")):
    A = T.match_buffer(a, (128, 128))
    T.evaluate(a)
"""
    _assert_print(prim_func, expected=expected)


def test_prim_func_no_sugar_shared_buffer_data():
    """Buffers sharing one data pointer are printed with explicit ``match_buffer``.

    The second buffer references the first one's data via ``data=A.data``.
    """
    handle_a = tirx.Var("a", "handle")
    handle_b = tirx.Var("b", "handle")
    # A throwaway buffer whose data var is shared by both mapped buffers.
    shared_data = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A").data
    buf_a = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=shared_data)
    buf_b = tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=shared_data)
    prim_func = tirx.PrimFunc(
        params=[handle_a, handle_b],
        ret_type=None,
        buffer_map={handle_a: buf_a, handle_b: buf_b},
        body=tirx.Evaluate(0),
    ).with_attr("global_symbol", "main")
    expected = """
# from tvm.script import tirx as T

@T.prim_func
def main(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (128, 128))
    B = T.match_buffer(b, (256, 256), data=A.data)
    T.evaluate(0)
"""
    _assert_print(prim_func, expected=expected)


def test_block_realize():
    """Print a block realize produced by ``T.sblock(..., no_realize=False)``.

    The realize form shows each block iter var bound to an outer var
    (e.g. ``T.axis.spatial(128, i)``). The free vars ``i``, ``j`` and
    ``k`` are declared first as ``T.int32()``.
    """
    i = tirx.Var("i", "int32")
    j = tirx.Var("j", "int32")
    k = tirx.Var("k", "int32")
    with IRBuilder() as ib:
        with T.sblock(name="block", no_realize=False):
            # ib.name gives each axis var the name shown in the output.
            vi = ib.name("vi", T.axis.spatial(128, i))
            vj = ib.name("vj", T.axis.spatial(64, j))
            vk = ib.name("vk", T.axis.reduce(32, k))
            T.reads()
            T.writes()
            T.evaluate(0)
    obj = ib.get()
    _assert_print(
        obj,
        """
i = T.int32()
j = T.int32()
k = T.int32()
with T.sblock("block"):
    vi = T.axis.spatial(128, i)
    vj = T.axis.spatial(64, j)
    vk = T.axis.reduce(32, k)
    T.reads()
    T.writes()
    T.evaluate(0)""",
    )


def test_block():
    """Print only the inner block (``.block``) of a realize.

    Without the enclosing realize, the expected output uses
    ``no_realize=True`` and the axes carry only their extents, with no
    bindings to ``i``/``j``/``k``.
    """
    i = tirx.Var("i", "int32")
    j = tirx.Var("j", "int32")
    k = tirx.Var("k", "int32")
    with IRBuilder() as ib:
        with T.sblock(name="block", no_realize=False):
            vi = ib.name("vi", T.axis.spatial(128, i))
            vj = ib.name("vj", T.axis.spatial(64, j))
            vk = ib.name("vk", T.axis.reduce(32, k))
            T.reads()
            T.writes()
            T.evaluate(0)
    # Take the block out of the realize node returned by the builder.
    obj = ib.get().block
    _assert_print(
        obj,
        """
with T.sblock("block", no_realize=True):
    vi = T.axis.spatial(128)
    vj = T.axis.spatial(64)
    vk = T.axis.reduce(32)
    T.reads()
    T.writes()
    T.evaluate(0)""",
    )


def test_match_buffer_region():
    """A MatchBufferRegion prints as ``T.match_buffer(src[region], shape)``."""
    source = tirx.decl_buffer((128, 128), "float32", name="src")
    target = tirx.decl_buffer((64, 64), "float32", name="tgt")
    region = [Range(64, 128) for _ in range(2)]
    node = tirx.MatchBufferRegion(target, tirx.BufferRegion(source, region))
    _assert_print(
        node,
        """
src = T.Buffer((128, 128))
tgt = T.match_buffer(src[64:128, 64:128], (64, 64))
""",
    )


def test_buffer():
    """A bare buffer prints its declaration followed by its name."""
    _assert_print(
        tirx.decl_buffer((128, 128), "float16", name="A"),
        """A = T.Buffer((128, 128), "float16")
A""",
    )


def test_buffer_region():
    """A BufferRegion prints as slice notation on the declared buffer."""
    source = tirx.decl_buffer((128, 128), "float32", name="src")
    ranges = [Range(64, 128) for _ in range(2)]
    _assert_print(
        tirx.BufferRegion(source, ranges),
        """
src = T.Buffer((128, 128))
src[64:128, 64:128]
""",
    )


def test_buffer_load():
    """A scalar BufferLoad prints as plain subscript syntax."""
    buf = tirx.decl_buffer((128, 128), "float16", name="A")
    _assert_print(
        tirx.BufferLoad(buf, [128, 128]),
        """
A = T.Buffer((128, 128), "float16")
A[128, 128]
""",
    )


def test_buffer_store():
    """A BufferStore built with the IRBuilder prints as subscript assignment.

    The integer ``1`` added to a float16 load is printed as ``T.float16(1.0)``.
    """
    a = tirx.decl_buffer((128, 128), "float16", name="A")
    with IRBuilder() as ib:
        T.buffer_store(a, a[128, 128] + 1, [128, 128])
    obj = ib.get()
    _assert_print(
        obj,
        """
A = T.Buffer((128, 128), "float16")
A[128, 128] = A[128, 128] + T.float16(1.0)
""",
    )


def test_for():
    """Nested loops from ``T.grid`` are printed back as a single ``T.grid``."""
    with IRBuilder() as ib:
        with T.grid(128, 128, 128) as (i, j, k):
            # Name the loop vars so the output reads ``for i, j, k``.
            ib.name_many(["i", "j", "k"], [i, j, k])
            T.evaluate(0)
    obj = ib.get()
    _assert_print(
        obj,
        """
for i, j, k in T.grid(128, 128, 128):
    T.evaluate(0)
""",
    )


def test_bind():
    """``T.bind`` prints as a type-annotated assignment inside the function.

    No function name or global symbol is set on the builder, and the
    expected header is ``@T.prim_func(private=True)``.
    """
    with IRBuilder() as ib:
        with T.prim_func():
            v = T.bind(T.float32(10))
            ib.name("v", v)
            T.evaluate(1)
    obj = ib.get()
    _assert_print(
        obj,
        """
# from tvm.script import tirx as T

@T.prim_func(private=True)
def main():
    v: T.float32 = T.float32(10.0)
    T.evaluate(1)
""",
    )


def test_attr_stmt():
    """An AttrStmt prints as a ``with T.attr(node, key, value):`` scope."""
    with IRBuilder() as ib:
        with T.attr("pragma", "unroll", 1):
            T.evaluate(0)
    obj = ib.get()
    _assert_print(
        obj,
        """
with T.attr("pragma", "unroll", 1):
    T.evaluate(0)
""",
    )


def test_assert_stmt():
    """``T.Assert`` prints as a Python ``assert`` with a (kind, [messages]) tuple.

    Per the expected output, the statement nested under the ``with`` is
    printed after the assert at the same indentation level.
    """
    with IRBuilder() as ib:
        with T.Assert(True, "assertion"):
            T.evaluate(0)
    obj = ib.get()
    _assert_print(
        obj,
        """
assert T.bool(True), ("RuntimeError", ["assertion"])
T.evaluate(0)
""",
    )


def test_while():
    """A While loop prints as a Python ``while`` statement.

    ``x`` is created without a name, and the expected output calls it
    ``v`` (presumably the printer's fallback name for unnamed vars).
    """
    with IRBuilder() as ib:
        x = T.int32()
        with T.While(x < 10):
            T.evaluate(0)
    obj = ib.get()
    _assert_print(
        obj,
        """
v = T.int32()
while v < 10:
    T.evaluate(0)
""",
    )


def test_allocate():
    """``T.alloc_buffer`` inside a prim_func prints as a flat allocation.

    Only ``obj.body`` is printed, which skips the PrimFunc wrapper. The
    unnamed buffer is printed as ``buffer``.
    """
    with IRBuilder() as ib:
        with T.prim_func():
            T.func_name("test")
            buf = T.alloc_buffer([128, 128], "float32")
            T.evaluate(1)
    obj = ib.get()
    _assert_print(
        obj.body,
        """
buffer = T.alloc_buffer((128, 128))
T.evaluate(1)
""",
    )


def test_allocate_with_decl_buffer_sugar():
    """A decl_buffer aliasing an allocation's data prints with ``data=buffer.data``."""
    # AllocBuffer and DeclBuffer are flat siblings
    with IRBuilder() as ib:
        with T.prim_func():
            T.func_name("test")
            buf = T.alloc_buffer([128, 128], "float32")
            buf2 = T.decl_buffer([128, 128], "float32", data=buf.data)
            T.evaluate(1)
    obj = ib.get()
    # Print the function body only; the second buffer gets the name ``buffer_1``.
    _assert_print(
        obj.body,
        """
buffer = T.alloc_buffer((128, 128))
buffer_1 = T.decl_buffer((128, 128), data=buffer.data)
T.evaluate(1)
""",
    )


def test_allocate_with_decl_buffer_sugar_multi_usage():
    """The allocation's data var is also used directly in the body.

    The expected output refers to it as ``buffer.data``.
    """
    # AllocBuffer and DeclBuffer are flat siblings
    with IRBuilder() as ib:
        with T.prim_func():
            T.func_name("test")
            buf = T.alloc_buffer([128, 128], "float32")
            buf2 = T.decl_buffer([128, 128], "float32", data=buf.data)
            T.evaluate(buf.data)
    obj = ib.get()
    _assert_print(
        obj.body,
        """
buffer = T.alloc_buffer((128, 128))
buffer_1 = T.decl_buffer((128, 128), data=buffer.data)
T.evaluate(buffer.data)
""",
    )


def test_allocate_with_decl_buffer_no_sugar_mismatch():
    """A decl_buffer whose shape differs from the allocation is still printed explicitly.

    The declared shape (256, 256) is kept alongside ``data=buffer.data``.
    """
    with IRBuilder() as ib:
        with T.prim_func():
            T.func_name("test")
            buf = T.alloc_buffer([128, 128], "float32")
            buf2 = T.decl_buffer([256, 256], "float32", data=buf.data)
            T.evaluate(buf.data)
    obj = ib.get()
    _assert_print(
        obj.body,
        """
buffer = T.alloc_buffer((128, 128))
buffer_1 = T.decl_buffer((256, 256), data=buffer.data)
T.evaluate(buffer.data)
""",
    )


def test_decl_buffer():
    """A decl_buffer over a fresh pointer prints the pointer var first.

    The ``T.ptr("float32")`` var is printed as
    ``v = T.handle("float32", "global")``.
    """
    # DeclBuffer is flat: we need a frame to hold multiple stmts
    with IRBuilder() as ib:
        with T.prim_func():
            T.func_name("test")
            buf = T.decl_buffer((10, 10), data=T.ptr("float32"))
            T.evaluate(1)
    obj = ib.get()
    # Print only the body (skip PrimFunc wrapper)
    _assert_print(
        obj.body,
        """
v = T.handle("float32", "global")
buffer = T.decl_buffer((10, 10), data=v)
T.evaluate(1)
""",
    )


def test_seq_stmt():
    """A sequence of statements prints one statement per line."""
    with IRBuilder() as ib:
        with T.serial(10):
            T.evaluate(1)
            T.evaluate(2)
    # The builder returns the loop; its body holds the two evaluates.
    obj = ib.get().body
    _assert_print(
        obj,
        """
T.evaluate(1)
T.evaluate(2)
""",
    )


def test_if_then_else():
    """An IfThenElse with only a then-branch prints as a plain ``if``.

    The unnamed condition var is printed as ``v``.
    """
    with IRBuilder() as ib:
        with T.If(T.int32() == 1):
            with T.Then():
                T.evaluate(0)

    obj = ib.get()
    _assert_print(
        obj,
        """
v = T.int32()
if v == 1:
    T.evaluate(0)
""",
    )


def test_evaluate():
    """A single Evaluate statement built with the IRBuilder prints as ``T.evaluate``."""
    with IRBuilder() as ib:
        T.evaluate(0)
    obj = ib.get()
    _assert_print(
        obj,
        """
T.evaluate(0)
""",
    )


def test_var():
    """A free Var prints its declaration, then its name."""
    _assert_print(
        tirx.Var("a", "float32"),
        """
a = T.float32()
a""",
    )


def test_size_var():
    """A SizeVar is declared with ``is_size_var=True``."""
    _assert_print(
        tirx.SizeVar("a", "float32"),
        """
a = T.float32(is_size_var=True)
a""",
    )


def test_iter_var():
    """An IterVar prints as ``T.iter_var(var, range, kind, thread_tag)``."""
    _assert_print(
        tirx.IterVar((0, 8), "a", iter_type=tirx.IterVar.DataPar),
        """
a = T.int32()
T.iter_var(a, T.Range(0, 8), "DataPar", "")
""",
    )


def test_string_imm():
    """A StringImm prints as a double-quoted Python string literal."""
    _assert_print(tirx.StringImm("str"), '"str"')


def test_cast():
    """A Cast prints as ``T.Cast(target_dtype, value)``."""
    source = tirx.Var("a", "float32")
    _assert_print(
        tirx.Cast("float64", source),
        """
a = T.float32()
T.Cast("float64", a)
""",
    )


def test_llvm_intrin_imm():
    """LLVM intrinsic calls keep their ``call_llvm[_pure]_intrin`` form when printed."""
    cases = [
        (tirx.call_llvm_intrin, 'T.call_llvm_intrin("int32x4", "llvm.donothing")'),
        (tirx.call_llvm_pure_intrin, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing")'),
    ]
    for make_call, expected in cases:
        _assert_print(make_call("int32x4", "llvm.donothing"), expected)


def test_binary_arith():
    """Binary ops on vars print as infix operators or, for ``truncmod``, a ``T.`` call.

    An alphabetic token in the table means the op has no Python operator
    and is printed as a function call instead.
    """
    lhs = tirx.Var("a", "int32")
    rhs = tirx.Var("b", "int32")
    cases = [
        (tirx.Add, "+"),
        (tirx.Sub, "-"),
        (tirx.Mul, "*"),
        (tirx.Mod, "truncmod"),
        (tirx.FloorDiv, "//"),
        (tirx.FloorMod, "%"),
        (tirx.LT, "<"),
        (tirx.LE, "<="),
        (tirx.EQ, "=="),
        (tirx.NE, "!="),
        (tirx.GT, ">"),
        (tirx.GE, ">="),
    ]
    for node_cls, symbol in cases:
        if not symbol.isalpha():
            expected = f"""
a = T.int32()
b = T.int32()
a {symbol} b"""
        else:
            expected = f"""
a = T.int32()
b = T.int32()
T.{symbol}(a, b)"""
        _assert_print(node_cls(lhs, rhs), expected)


def test_binary_arith_const():
    """Binary ops on int64 constants print in explicit ``T.<Op>(lhs, rhs)`` form.

    The operand text is taken from ``str()`` of each IntImm.
    """
    lhs = tirx.IntImm("int64", 3)
    rhs = tirx.IntImm("int64", 4)
    cases = [
        (tirx.Add, "Add"),
        (tirx.Sub, "Sub"),
        (tirx.Mul, "Mul"),
        (tirx.Div, "Div"),
        (tirx.Mod, "truncmod"),
        (tirx.FloorDiv, "FloorDiv"),
        (tirx.FloorMod, "FloorMod"),
        (tirx.LT, "LT"),
        (tirx.LE, "LE"),
        (tirx.EQ, "EQ"),
        (tirx.NE, "NE"),
        (tirx.GT, "GT"),
        (tirx.GE, "GE"),
    ]
    for node_cls, op_name in cases:
        _assert_print(
            node_cls(lhs, rhs),
            f"""
T.{op_name}({lhs!s}, {rhs!s})""",
        )


def test_int_div():
    """Integer (truncating) Div on vars prints as ``T.Div(a, b)``, not ``/``."""
    lhs = tirx.Var("a", "int32")
    rhs = tirx.Var("b", "int32")
    node = tirx.Div(lhs, rhs)
    _assert_print(
        node,
        """
a = T.int32()
b = T.int32()
T.Div(a, b)
""",
    )


def test_logical():
    """And / Or / Not print as Python ``and`` / ``or`` / ``not``."""
    lhs = tirx.Var("a", "bool")
    rhs = tirx.Var("b", "bool")
    cases = [
        (
            tirx.And(lhs, rhs),
            """
a = T.bool()
b = T.bool()
a and b
""",
        ),
        (
            tirx.Or(lhs, rhs),
            """
a = T.bool()
b = T.bool()
a or b
""",
        ),
        (
            tirx.Not(lhs),
            """
a = T.bool()
not a
""",
        ),
    ]
    for node, expected in cases:
        _assert_print(node, expected)


def test_select():
    """A Select prints as ``T.Select`` with the Python bool wrapped in ``T.bool``."""
    _assert_print(
        tirx.Select(True, 0, 2),
        """T.Select(T.bool(True), 0, 2)
""",
    )


@pytest.mark.parametrize(
    "lanes, scripted_lanes", [(32, "32"), (tvm.tirx.vscale() * 8, "T.vscale() * 8")]
)
def test_ramp(lanes, scripted_lanes):
    """A Ramp prints its lane count, whether constant or a scalable ``vscale`` expression."""
    base = tirx.Var("a", "int32")
    _assert_print(
        tirx.Ramp(base, 1, lanes),
        f"""
a = T.int32()
T.Ramp(a, 1, {scripted_lanes})
""",
    )


@pytest.mark.parametrize(
    "lanes, scripted_lanes", [(4, "4"), (tvm.tirx.vscale() * 4, "T.vscale() * 4")]
)
def test_broadcast(lanes, scripted_lanes):
    """A Broadcast prints its lane count, whether constant or a scalable ``vscale`` expression."""
    _assert_print(
        tirx.Broadcast(0, lanes),
        f"""
T.Broadcast(0, {scripted_lanes})
""",
    )


def test_let_expr():
    """A Let expression prints as ``T.Let(body, where={var: value})``."""
    bound = tirx.Var("x", "int32")
    let_node = tirx.Let(bound, 1, bound + 1)
    _assert_print(
        let_node,
        """
x = T.int32()
T.Let(x + 1, where={x: 1})
""",
    )


def test_call():
    """A math intrinsic call prints as the corresponding ``T.<op>(...)`` function."""
    _assert_print(
        tirx.atan(T.float32(1.0)),
        """
T.atan(T.float32(1.0))
""",
    )


def test_comm_reducer():
    """A CommReducer prints as ``T.comm_reducer(lambda ..., [identity])``.

    The lambda parameters ``x``/``y`` appear verbatim in the expected
    output, so they must not be renamed.
    """
    obj = T.comm_reducer(lambda x, y: x + y, identity=[T.float32(0)])
    _assert_print(
        obj,
        """
T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)])
""",
    )


def test_int_imm():
    """A non-default-width IntImm prints with its dtype constructor."""
    _assert_print(
        T.int16(1),
        """
T.int16(1)
""",
    )


def test_float_imm():
    """A FloatImm built from an int prints with a float literal (``1.0``)."""
    _assert_print(
        T.float16(1),
        """
T.float16(1.0)
""",
    )


def test_range():
    """An ``ir.Range`` prints under the ``I`` (IR) namespace, not ``T``."""
    _assert_print(
        Range(0, 10),
        """
I.Range(0, 10)
""",
    )


def test_prim_type():
    """A PrimType prints as the bare dtype attribute ``T.<dtype>``."""
    _assert_print(ir.PrimType("float32"), "T.float32")


def test_pointer_type():
    """A PointerType prints as ``T.handle(element_dtype, storage_scope)``."""
    pointee = ir.PrimType("int32")
    _assert_print(ir.PointerType(pointee, "global"), 'T.handle("int32", "global")')


def test_tuple_type():
    """A TupleType prints as ``T.Tuple(...)`` of its field types in order."""
    fields = [ir.PrimType(dtype) for dtype in ("float32", "int32")]
    _assert_print(ir.TupleType(fields), "T.Tuple(T.float32, T.int32)")


def test_remap():
    """Consecutive iter vars bound directly to loop vars print as ``T.axis.remap``.

    The implicit form (one ``T.axis.spatial``/``reduce`` per var) and the
    explicit ``T.axis.remap`` form must print identically. Axes bound to
    non-trivial expressions (``i0 + 1``, ``i3 - 1``) stay separate. The
    parser adds empty ``T.reads()``/``T.writes()`` and a ``T.evaluate(0)``
    body, as shown in the expected output.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def block_with_remap_implicitly():
        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
            with T.sblock("update"):
                v0 = T.axis.spatial(128, i0 + 1)
                v1 = T.axis.spatial(128, i1)
                v2 = T.axis.reduce(128, i2)
                v3 = T.axis.spatial(128, i3 - 1)
                v4 = T.axis.reduce(128, i4)
                v5 = T.axis.spatial(128, i5)

    @T.prim_func
    def block_with_remap_explicitly():
        for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
            with T.sblock("update"):
                v0 = T.axis.spatial(128, i0 + 1)
                v1, v2 = T.axis.remap("SR", [i1, i2])
                v3 = T.axis.spatial(128, i3 - 1)
                v4, v5 = T.axis.remap("RS", [i4, i5])

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def main():
    # with T.sblock("root"):
    for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128):
        with T.sblock("update"):
            v0 = T.axis.spatial(128, i0 + 1)
            v1, v2 = T.axis.remap("SR", [i1, i2])
            v3 = T.axis.spatial(128, i3 - 1)
            v4, v5 = T.axis.remap("RS", [i4, i5])
            T.reads()
            T.writes()
            T.evaluate(0)"""
    _assert_print(block_with_remap_explicitly.with_attr("global_symbol", "main"), expected_output)
    _assert_print(block_with_remap_implicitly.with_attr("global_symbol", "main"), expected_output)


def test_root_block():
    """Implicit and explicit root blocks print identically.

    In both cases the root block is shown only as the comment
    ``# with T.sblock("root"):``, and the unnamed inner block prints
    as ``T.sblock("")``.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def root_block_implicitly():
        a = T.sblock_alloc_buffer([128, 128])
        for i, j in T.grid(128, 128):
            with T.sblock():
                T.evaluate(0)

    @T.prim_func
    def root_block_explicitly():
        with T.sblock("root"):
            a = T.sblock_alloc_buffer([128, 128])
            for i, j in T.grid(128, 128):
                with T.sblock():
                    T.evaluate(0)

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def main():
    # with T.sblock("root"):
    a = T.sblock_alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.sblock(""):
            T.reads()
            T.writes()
            T.evaluate(0)
    """
    _assert_print(root_block_implicitly.with_attr("global_symbol", "main"), expected_output)
    _assert_print(root_block_explicitly.with_attr("global_symbol", "main"), expected_output)


def test_private_primfunc():
    """A PrimFunc without a ``global_symbol`` attribute prints as ``private=True``."""
    from tvm.script import tirx as T

    handle_a = tirx.Var("a", "handle")
    handle_b = tirx.Var("b", "handle")
    buf_a = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A")
    buf_b = tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B")
    prim_func = tirx.PrimFunc(
        params=[handle_a, handle_b],
        ret_type=None,
        buffer_map={handle_a: buf_a, handle_b: buf_b},
        body=tirx.Evaluate(0),
    )
    expected = """
# from tvm.script import tirx as T

@T.prim_func(private=True)
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
    T.evaluate(0)"""
    _assert_print(prim_func, expected=expected)


def test_prim_func_different_symbol():
    """The printed function name comes from ``global_symbol``, not the Python name.

    ``main`` declares ``global_symbol`` = ``"func"`` and is printed as
    ``def func``. The attribute itself does not appear in the expected
    output.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
        T.func_attr({"global_symbol": "func"})
        T.evaluate(0)

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
    T.evaluate(0)
    """
    _assert_print(main, expected_output)


def test_variable_with_cpp_address():
    """The show_object_address option displays the C++ addresses

    Because the C++ address may vary with each execution, the output
    produced with this option cannot be compared to a fixed string.
    Instead, this test uses the normal script output to generate a
    regular expression which the test output must match in full.  The
    regular expression validates that all names have been appended
    with "_0x" followed by a hexadecimal number, and that every
    occurrence of a given variable carries the same address.
    """
    from tvm.script import tirx as T

    # The test function has all named objects suffixed with "_name",
    # to avoid spurious replacement when generating the expected
    # regex.
    @T.prim_func
    def func(a_name: T.handle):
        N_name = T.int64()
        A_name = T.match_buffer(a_name, N_name, "float32")
        for i_name in range(N_name):
            A_name[i_name] = A_name[i_name] + 1.0

    without_address = func.script(show_object_address=False)
    script = func.script(show_object_address=True)

    expected_regex = re.escape(without_address)
    for name in ["a_name", "A_name", "N_name", "i_name"]:
        # Replace all occurrences with a backref to an earlier match
        expected_regex = expected_regex.replace(name, rf"(?P={name})")
        # Then replace the first such backref with a capturing group.
        expected_regex = expected_regex.replace(
            rf"(?P={name})", rf"(?P<{name}>{name}_0x[A-Fa-f0-9]+)", 1
        )

    # fullmatch (not match) so the output after the last name is checked too;
    # re.match would silently accept any trailing mismatch.
    assert re.fullmatch(expected_regex, script), script


def test_return_statement():
    """``T.evaluate(T.ret(value))`` prints back as a Python ``return`` statement."""
    from tvm.script import tirx as T

    @T.prim_func
    def func():
        T.evaluate(T.ret(5))

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def func():
    return 5
    """
    _assert_print(func, expected_output)


# Low-precision float dtype names exercised by test_custom_float_types.
# Each name must exist as a constructor attribute on the tirx script
# module ``T`` (the test calls ``getattr(T, dtype)``). List order is the
# parametrization order.
CUSTOM_FLOAT_DTYPES = [
    # Float8 variants
    "float8_e3m4",
    "float8_e4m3",
    "float8_e4m3b11fnuz",
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
    "float8_e8m0fnu",
    # Float6 variants
    "float6_e2m3fn",
    "float6_e3m2fn",
    # Float4 variant
    "float4_e2m1fn",
]


@pytest.mark.parametrize("dtype", CUSTOM_FLOAT_DTYPES)
def test_custom_float_types(dtype):
    """A zero immediate of each custom float dtype round-trips as ``T.<dtype>(0.0)``.

    ``@T.prim_func()`` (called with parentheses) prints as plain
    ``@T.prim_func``.
    """
    from tvm.script import tirx as T

    @T.prim_func()
    def func():
        T.evaluate(getattr(T, dtype)(0.0))

    expected_output = f"""
# from tvm.script import tirx as T

@T.prim_func
def func():
    T.evaluate(T.{dtype}(0.0))
"""
    _assert_print(func, expected_output)


def test_predicated_load_store():
    """Predicated ``vload``/``vstore`` written in TVMScript print with ``predicate=``.

    ``a_load`` is a ``T.meta_var``, so the expected output inlines the load
    into the store rather than binding it to a name. ``B`` is unused, and
    the function prints as ``func`` because of its global_symbol.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128, 128), "float32")
        B = T.match_buffer(b, (256, 256), "float32")
        T.func_attr({"global_symbol": "func"})
        a_load = T.meta_var(A.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)))
        A.vstore([0, T.Ramp(0, 2, 4)], a_load, predicate=T.Broadcast(T.bool(False), 4))

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
    A.vstore([0, T.Ramp(0, 2, 4)], A.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)), predicate=T.Broadcast(T.bool(False), 4))
    """
    _assert_print(main, expected_output)


def test_predicated_buffer_load_store():
    """Predicated BufferLoad/BufferStore nodes print as ``vload``/``vstore`` with ``predicate=``.

    The IR is built directly rather than parsed, and without a
    global_symbol the function prints as ``private=True``.
    """
    handle_a = tirx.Var("a", "handle")
    handle_b = tirx.Var("b", "handle")
    buf_a = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A")
    buf_b = tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B")
    buffer_map = {handle_a: buf_a, handle_b: buf_b}
    # 4-lane all-false predicate built from a bool IntImm of value 0.
    load = tirx.BufferLoad(
        buffer=buf_b,
        indices=[0, tirx.Ramp(0, 4, 4)],
        predicate=tirx.Broadcast(tirx.IntImm("bool", 0), 4),
    )
    store = tirx.BufferStore(
        buffer=buf_a,
        value=load,
        indices=[0, tirx.Ramp(0, 2, 4)],
        predicate=tirx.Broadcast(tirx.IntImm("bool", 0), 4),
    )
    prim_func = tirx.PrimFunc(
        params=[handle_a, handle_b],
        ret_type=None,
        buffer_map=buffer_map,
        body=store,
    )

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func(private=True)
def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
    A.vstore([0, T.Ramp(0, 2, 4)], B.vload([0, T.Ramp(0, 4, 4)], predicate=T.Broadcast(T.bool(False), 4)), predicate=T.Broadcast(T.bool(False), 4))
    """
    _assert_print(prim_func, expected_output)


def test_predicated_scalable_load_store():
    """Scalable (``vscale``) predicated vload/vstore print with the mask inlined.

    ``mask`` and ``a_load`` are ``T.meta_var`` values. The expected
    output repeats the ``T.get_active_lane_mask(...)`` call in both
    predicate positions instead of naming it.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128, 128), "float32")
        B = T.match_buffer(b, (256, 256), "float32")
        T.func_attr({"global_symbol": "func"})
        mask = T.meta_var(T.get_active_lane_mask("uint1xvscalex4", 0, 13))
        a_load = T.meta_var(A.vload([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=mask))
        A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], a_load, predicate=mask)

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")):
    A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], A.vload([0, T.Ramp(0, 4, T.vscale() * 4)], predicate=T.get_active_lane_mask("uint1xvscalex4", 0, 13)), predicate=T.get_active_lane_mask("uint1xvscalex4", 0, 13))
    """
    _assert_print(main, expected_output)


def test_vload_with_explicit_scalable_data_type():
    """A unit-stride scalable vload with an explicit dtype prints as slice syntax.

    The ``T.Ramp(0, 1, T.vscale() * 4)`` access is printed back as
    ``A[0:T.vscale() * 4]``.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (128,), "float32")
        B = T.match_buffer(b, (128,), "float32")
        B[0 : T.vscale() * 4] = A.vload([T.Ramp(0, 1, T.vscale() * 4)], dtype="float32xvscalex4")

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")):
    B[0:T.vscale() * 4] = A[0:T.vscale() * 4]
    """
    _assert_print(main, expected_output)


def test_vectorize_llvm_pure_intrin():
    """A vectorized LLVM pure intrinsic prints with unit-stride ramps as slices.

    ``T.Ramp(0, 1, 4)`` indices are printed back as ``[0:4]``.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (4,), "float32")
        B = T.match_buffer(b, (4,), "float32")
        A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", B[T.Ramp(0, 1, 4)])

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
    A[0:4] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", B[0:4])
    """
    _assert_print(main, expected_output)


def test_func_with_loop_jumps():
    """Python ``continue``/``break`` print as ``T.continue_loop()``/``T.break_loop()``.

    The integer literal added to a float32 load prints as ``T.float32(1.0)``.
    """
    from tvm.script import tirx as T

    @T.prim_func
    def main(a: T.handle, b: T.handle):
        A = T.match_buffer(a, (4,), "float32")
        B = T.match_buffer(b, (4,), "float32")
        for i in range(1000):
            if i % 13 == 0:
                A[1] = A[1] + 1
                continue
            if A[0] >= B[0]:
                break

    expected_output = """
# from tvm.script import tirx as T

@T.prim_func
def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
    for i in range(1000):
        if i % 13 == 0:
            A[1] = A[1] + T.float32(1.0)
            T.continue_loop()
        if A[0] >= B[0]:
            T.break_loop()
    """
    _assert_print(main, expected_output)


if __name__ == "__main__":
    # Allow running this file directly; tvm.testing.main() runs the tests in
    # this module (presumably by delegating to pytest).
    tvm.testing.main()
