blob: 53d2a60f7abd83aa932ce0b321f49d7ec71b61bd [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: E711, F821, F841
import itertools
import numpy as np
import pytest
import tvm
from tvm import tir
from tvm.base import TVMError
from tvm.ir.transform import PassContext
from tvm.script import tir as T
def build_tir_func(func):
    """Compile a single TIR PrimFunc into a callable module.

    The function is exported under the symbol ``"main"``. Unless the current
    PassContext sets ``tir.noalias`` to a falsy value, the PrimFunc is also
    tagged with ``tir.noalias=True`` before compilation.

    Parameters
    ----------
    func : tir.PrimFunc
        The function to compile.

    Returns
    -------
    The result of ``tvm.compile`` on a one-function IRModule.
    """
    prim_func = func.with_attr("global_symbol", "main")
    noalias_enabled = PassContext.current().config.get("tir.noalias", True)
    if noalias_enabled:
        prim_func = prim_func.with_attr("tir.noalias", True)
    return tvm.compile(tvm.IRModule({"main": prim_func}))
def test_scalar_add():
    """Scalar addition across mixed dtypes should promote the operands.

    All these dtypes should be interchangeable with each other. For example,
    float16 + float32 upconverts the float16 operand to float32. When an int
    and a float are added together, the int is cast to the float type. Every
    (lhs, rhs) combination below must therefore compute 1 + 2 == 3.
    """
    lhs_types = ["float32", "float16", "int32", "int64"]
    rhs_types = ["float32", "float16"]
    for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types):
        # Both inputs are float32 parameters; casting them to the dtypes under
        # test exercises promotion between those dtypes in the addition.
        lhs_input = tir.Var("lhs", "float32")
        rhs_input = tir.Var("rhs", "float32")
        lhs = tir.Cast(lhs_type, lhs_input)
        rhs = tir.Cast(rhs_type, rhs_input)
        body = tir.Evaluate(tir.ret(lhs + rhs))
        func = build_tir_func(tir.PrimFunc([lhs_input, rhs_input], body))
        out = func(1.0, 2.0)
        # Name the dtype pair so a failure inside the loop is diagnosable.
        assert out == 3.0, f"{lhs_type} + {rhs_type} returned {out}, expected 3.0"
def assignment_helper(store_dtype, value_dtype):
    """Construct ``tir.Let(store, value, body=store)`` and discard the result.

    The helper is called only for the side effect of building the Let node.
    The tests in this file expect that construction to raise ``TVMError`` when
    ``store_dtype`` and ``value_dtype`` differ.
    """
    store_var = tir.Var("store", dtype=store_dtype)
    value_var = tir.Var("value", dtype=value_dtype)
    tir.Let(store_var, value_var, body=store_var)
def test_fail_implicit_downcasts_same_type():
    """Binding a wider value to a narrower var of the same kind must fail.

    For each dtype kind, every adjacent pair of bit widths is checked: the
    store uses the narrower width and the value uses the wider one.
    """
    # `bits` must be sorted ascending so each adjacent pair is (narrow, wide).
    bits = [8, 16, 32, 64]
    # `kind` rather than `type`, to avoid shadowing the builtin.
    for kind in ["float", "int", "uint"]:
        for narrow, wide in zip(bits, bits[1:]):
            with pytest.raises(TVMError):
                assignment_helper(
                    store_dtype=f"{kind}{narrow}", value_dtype=f"{kind}{wide}"
                )
def test_cast_between_types():
    """A Let binding is only accepted when both dtypes are exactly equal."""
    bit_widths = [16, 32]
    kinds = ["float", "int", "uint"]
    # Same ordering as a 4-way product of (kind, bits, kind, bits).
    dtypes = [f"{kind}{width}" for kind, width in itertools.product(kinds, bit_widths)]
    for store_dtype, value_dtype in itertools.product(dtypes, dtypes):
        if store_dtype != value_dtype:
            # TODO: we might want to allow casts between uint and int types
            with pytest.raises(TVMError):
                assignment_helper(store_dtype, value_dtype)
            continue
        assignment_helper(store_dtype, value_dtype)
def test_ret_const():
    """A parameterless PrimFunc that returns the constant 0 yields 0."""
    body = tir.Evaluate(tir.ret(tir.const(0)))
    compiled = build_tir_func(tir.PrimFunc([], body))
    assert compiled() == 0
def test_control_flow_jump():
    """The first T.ret reached must end execution, so `a` is returned."""

    @T.prim_func
    def func(a: T.float32, b: T.float32):
        if True:
            T.evaluate(T.ret(a))
        T.evaluate(T.ret(b))

    compiled = build_tir_func(func)
    assert compiled(1.0, 2.0) == 1.0
def test_break_loop():
    """`break` exits the innermost for/while loop in compiled TIR.

    With In = [49, 8]:
    - The nested for-loops stop at i=4, j=9 (4*10+9 == 49), so Out[0] = 13.
    - The while-loop increments Out[1] from 1 until it exceeds 8, so Out[1] = 9.
    """

    @T.prim_func
    def func(In: T.Buffer[(2,), "int32"], Out: T.Buffer[(2,), "int32"]):
        Out[0] = 0
        Out[1] = 1
        for i in range(10):
            for j in range(10):
                if i * 10 + j == In[0]:
                    Out[0] = i + j
                    break
            if Out[0] > 0:
                break
        while Out[1] > 0:
            Out[1] = Out[1] + 1
            if Out[1] > In[1]:
                break

    # Compile before the environment check so compilation is still exercised.
    func = build_tir_func(func)
    a = np.asarray([49, 8], "int32")
    b = np.zeros([2], "int32")
    if not hasattr(b, "__dlpack__"):
        # Report a skip instead of a silent (vacuous) pass.
        pytest.skip("numpy.ndarray has no __dlpack__; cannot call the compiled function")
    func(a, b)
    assert b[0] == 13
    assert b[1] == 9
def test_continue_loop():
    """`continue` skips the rest of a loop iteration in compiled TIR.

    - Out[0] counts the values in [0, 100) divisible by 3, which is 34.
    - Out[1] counts the multiples of 6 in [1, 34], which is 5 (6, 12, 18, 24, 30).
    """

    @T.prim_func
    def func(Out: T.Buffer[(2,), "int32"]):
        T.func_attr({"global_symbol": "main"})
        Out[0] = 0
        Out[1] = 0
        for i in range(10):
            for j in range(10):
                if (i * 10 + j) % 3 != 0:
                    continue
                Out[0] = Out[0] + 1
        k = T.decl_buffer([], "int32")
        k[()] = 0
        while k[()] < Out[0]:
            k[()] = k[()] + 1
            if k[()] % 6 == 0:
                Out[1] = Out[1] + 1
            continue

    # Compile before the environment check so compilation is still exercised.
    func = build_tir_func(func)
    b = np.zeros([2], "int32")
    if not hasattr(b, "__dlpack__"):
        # Report a skip instead of a silent (vacuous) pass.
        pytest.skip("numpy.ndarray has no __dlpack__; cannot call the compiled function")
    func(b)
    assert b[0] == 34
    assert b[1] == 5  # 6, 12, 18, 24, 30
def test_exception():
    """tir.Var rejects a non-string name with TypeError."""
    with pytest.raises(TypeError):
        tir.Var(name=1, dtype="int")
def test_eq_ops():
    """Check how ==/!= against None behave for IntImm versus StringImm."""
    int_imm = tir.IntImm("int8", 1)
    # The overloaded comparison operators are exercised here on purpose, so
    # `!= None` / `== None` are intentional (not `is` checks).
    # For IntImm, either comparison with None raises ValueError.
    with pytest.raises(ValueError):
        assert int_imm != None
    with pytest.raises(ValueError):
        assert not int_imm == None

    string_imm = tir.StringImm("abc")
    # For StringImm, comparison with None evaluates normally to not-equal.
    assert string_imm != None
    assert not string_imm == None
if __name__ == "__main__":
    # Delegate to pytest so that running this file directly executes every
    # test in it (the hand-maintained call list previously skipped several),
    # with pytest.raises / pytest.skip handled as they are under pytest.
    raise SystemExit(pytest.main([__file__]))