blob: dc539b9327c8f02989a1c6de467d0b8428ca8020 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import tvm.testing
from tvm.script import tir as T
@T.prim_func
def get_valid_counts(
data: T.handle,
valid_count: T.handle,
out: T.handle,
out_indices: T.handle,
score_threshold: T.float32,
id_index: T.int32,
score_index: T.int32,
) -> None:
data_buf = T.match_buffer(data, (1, 2500, 6), "float32")
valid_count_buf = T.match_buffer(valid_count, (1,), "int32")
out_buf = T.match_buffer(out, (1, 2500, 6), "float32")
out_indices_buf = T.match_buffer(out_indices, (1, 2500), "int32")
with T.block("init"):
vi = T.axis.S(1, 0)
valid_count_buf[vi] = T.int32(0)
for j in range(2500):
with T.block("update"):
vj = T.axis.S(2500, j)
T.reads([data_buf[vi, vj, 6]])
T.writes([valid_count_buf[vi], out_indices_buf[vi, vj], out_buf[vi, vj, 6]])
if (data_buf[vi, vj, score_index] > score_threshold) and (
(id_index < 0) or (data_buf[vi, vj, id_index] >= T.float32(0))
):
for k in T.serial(0, 6):
out_buf[vi, valid_count_buf[vi], k] = data_buf[vi, vj, k]
out_indices_buf[vi, valid_count_buf[vi]] = vj
valid_count_buf[vi] = valid_count_buf[vi] + 1
if vj >= valid_count_buf[vi]:
for k in T.serial(0, 6):
out_buf[vi, vj, k] = T.float32(-1)
out_indices_buf[vi, vj] = T.int32(-1)
def _check_get_valid_counts_with_numpy(f, dshape, score_threshold, id_index, score_index):
dtype = "float32"
ctx = tvm.cpu()
batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,), dtype="int32")
np_out2 = np.zeros(shape=dshape).astype(dtype)
np_out3 = np.zeros(shape=(batch_size, num_anchor), dtype="int32")
for i in range(batch_size):
np_out1[i] = 0
inter_idx = 0
for j in range(num_anchor):
score = np_data[i, j, score_index]
if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0):
for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
np_out3[i, inter_idx] = j
inter_idx += 1
if j >= np_out1[i]:
for k in range(elem_length):
np_out2[i, j, k] = -1.0
np_out3[i, j] = -1
in_data = tvm.nd.array(np_data, ctx)
out1 = tvm.nd.array(np_out1, ctx)
out2 = tvm.nd.array(np_out2, ctx)
out3 = tvm.nd.array(np_out3, ctx)
f(in_data, out1, out2, out3, score_threshold, id_index, score_index)
tvm.testing.assert_allclose(out1.numpy(), np_out1, rtol=1e-5)
tvm.testing.assert_allclose(out2.numpy(), np_out2, rtol=1e-5)
tvm.testing.assert_allclose(out3.numpy(), np_out3, rtol=1e-5)
print("test get_valid_counts end")
def test_get_valid_counts_script_func():
device = "llvm"
# check lowering
print(get_valid_counts.script())
mod = tvm.ir.IRModule({"get_valid_counts": get_valid_counts})
print(mod.script())
# check building
f = tvm.build(mod["get_valid_counts"], target=device)
_check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1)
@T.prim_func
def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, [], dtype="float32")
B = T.match_buffer(b, [], dtype="float32")
# body
# tir.with block("root")
C = T.alloc_buffer([], dtype="float32")
A[()] = T.float32(2)
C[()] = A[()] + B[()]
B[()] = C[()]
@T.prim_func
def alloc_zero_dim_buffer_block(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (), "float32")
B = T.match_buffer(b, (), "float32")
with T.block("root"):
T.reads([])
T.writes([])
C = T.alloc_buffer((), "float32")
A[()] = T.float32(2)
C[()] = A[()] + B[()]
B[()] = C[()]
def _check_alloc_zero_dim_buffer(f):
dtype = "float32"
ctx = tvm.cpu()
np_data = np.zeros(shape=()).astype(dtype)
np_out = np.zeros(shape=()).astype(dtype)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_out = tvm.nd.array(np_out, ctx)
# np func exection
np_inter = np.array(1)
np_data[()] = 2.0
np_inter[()] = np_data[()] + np_out[()]
np_out[()] = np_inter[()]
# tvm func execution
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.numpy(), np_out, rtol=1e-5)
def test_alloc_zero_dim_buffer_round_trip():
func = alloc_zero_dim_buffer
func_with_block = alloc_zero_dim_buffer_block
rt_func = tvm.script.from_source(func.script())
rt_func_with_block = tvm.script.from_source(func_with_block.script())
rt_mod = tvm.build(rt_func, "llvm")
rt_mod_with_block = tvm.build(rt_func_with_block, "llvm")
tvm.ir.assert_structural_equal(
func.with_attr("global_symbol", "main"), func_with_block.with_attr("global_symbol", "main")
)
tvm.ir.assert_structural_equal(
rt_func.with_attr("global_symbol", "main"),
rt_func_with_block.with_attr("global_symbol", "main"),
)
_check_alloc_zero_dim_buffer(rt_mod)
_check_alloc_zero_dim_buffer(rt_mod_with_block)
@T.prim_func
def ceildiv_test(A: T.Buffer(16, "int32")):
for i in range(16):
A[i] = T.ceildiv(A[i], 4)
@tvm.testing.requires_llvm
def test_ceildiv():
f = tvm.build(ceildiv_test, "llvm")
a = tvm.nd.array(np.arange(16).astype("int32"))
f(a)
ref = (np.arange(16) + 3) // 4
tvm.testing.assert_allclose(a.numpy(), ref)
@T.prim_func
def slice_op_test(
A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32")
):
B[0:5] = A[0:5] + B[0:5]
B[0:5] = A[0:5] - B[0:5]
B[0:5] = A[0:5] * B[0:5]
B[0:5] = A[0:5] / B[0:5]
C[0:5] = C[0:5] % T.broadcast(T.uint32(5), 5)
B[0:5] = -B[0:5]
C[0:5] = C[0:5] >> 4
C[0:5] = C[0:5] << 4
C[0:5] = C[0:5] << C[0:5]
C[0:5] = C[0:5] >> C[0:5]
T.evaluate(A[0:5] > B[0:5])
T.evaluate(A[0:5] > 5)
T.evaluate(A[0:5] >= B[0:5])
T.evaluate(A[0:5] >= 5)
T.evaluate(A[0:5] < B[0:5])
T.evaluate(A[0:5] < 5)
T.evaluate(A[0:5] <= B[0:5])
T.evaluate(A[0:5] <= 5)
T.evaluate(A[0:5] == B[0:5])
T.evaluate(A[0:5] == 5)
T.evaluate(A[0:5] != B[0:5])
T.evaluate(A[0:5] != 5)
T.evaluate((A[0:5] > 0) and (B[0:5] > 0))
T.evaluate((A[0:5] > 0) or (B[0:5] > 0))
T.evaluate((A[0:5] < 0) and (1 > 0))
T.evaluate((A[0:5] > 0) or (1 > 0))
@T.prim_func
def slice_op_test_ref(
A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32"), C: T.Buffer((10,), "uint32")
):
B[0:5] = A[0:5] + B[0:5]
B[0:5] = A[0:5] - B[0:5]
B[0:5] = A[0:5] * B[0:5]
B[0:5] = A[0:5] / B[0:5]
C[0:5] = C[0:5] % T.Broadcast(T.uint32(5), 5)
B[0:5] = B[0:5] * T.Broadcast(T.float32(-1), 5)
C[0:5] = T.shift_right(C[0:5], T.Broadcast(T.uint32(4), 5))
C[0:5] = T.shift_left(C[0:5], T.Broadcast(T.uint32(4), 5))
C[0:5] = T.shift_left(C[0:5], C[0:5])
C[0:5] = T.shift_right(C[0:5], C[0:5])
T.evaluate(A[0:5] > B[0:5])
T.evaluate(A[0:5] > T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] >= B[0:5])
T.evaluate(A[0:5] >= T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] < B[0:5])
T.evaluate(A[0:5] < T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] <= B[0:5])
T.evaluate(A[0:5] <= T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] == B[0:5])
T.evaluate(A[0:5] == T.Broadcast(T.float32(5), 5))
T.evaluate(A[0:5] != B[0:5])
T.evaluate(A[0:5] != T.Broadcast(T.float32(5), 5))
T.bitwise_and(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5))
T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), B[0:5] > T.Broadcast(T.float32(0), 5))
T.bitwise_and(A[0:5] < T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5))
T.bitwise_or(A[0:5] > T.Broadcast(T.float32(0), 5), T.Broadcast(T.bool(1), 5))
def test_slice_op():
tvm.ir.assert_structural_equal(
slice_op_test.with_attr("global_symbol", "main"),
slice_op_test_ref.with_attr("global_symbol", "main"),
)
if __name__ == "__main__":
test_get_valid_counts_script_func()
test_alloc_zero_dim_buffer_round_trip()
test_slice_op()