| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint: disable=missing-function-docstring,missing-module-docstring |
| import numpy as np |
| import tvm |
| import tvm.testing |
| from tvm import te, tir, topi |
| from tvm.script import tir as T |
| import pytest |
| |
| |
| def test_unique_name_complete_block(): |
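    # Both compute stages are deliberately named "main"; create_prim_func should
    # deduplicate the block names so the schedule exposes "main" and "main_1".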
| A = te.placeholder((16, 16), name="A") |
| B = te.compute((16, 16), lambda x, y: A[x, y] * 2, name="main") |
| C = te.compute((16, 16), lambda x, y: B[x, y] + 1, name="main") |
| func = te.create_prim_func([A, C]) |
| s = tir.Schedule(func, debug_mask="all") |
| assert isinstance(s.get_sref(s.get_block("main")), tir.schedule.StmtSRef) |
| assert isinstance(s.get_sref(s.get_block("main_1")), tir.schedule.StmtSRef) |
| |
| |
| def test_unique_name_reduction_block(): |
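    # Same check as above, but for reduction blocks: both stages are named "sum",
    # so the second block is expected to be renamed "sum_1".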
| k1 = te.reduce_axis((0, 16), "k1") |
| k2 = te.reduce_axis((0, 16), "k2") |
| A = te.placeholder((16, 16), name="A") |
| B = te.compute((16,), lambda i: te.sum(A[i, k1], axis=k1), name="sum") |
| C = te.compute((), lambda: te.sum(B[k2], axis=k2), name="sum") |
| func = te.create_prim_func([A, C]) |
| s = tir.Schedule(func, debug_mask="all") |
| assert isinstance(s.get_sref(s.get_block("sum")), tir.schedule.StmtSRef) |
| assert isinstance(s.get_sref(s.get_block("sum_1")), tir.schedule.StmtSRef) |
| |
| |
| def _check_workload(te_workload, tir_workload, index_dtype_override=None, do_simplify=False): |
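    """Lower the TE workload with te.create_prim_func and compare it structurally
    against the expected TIR, optionally overriding the index dtype and running
    Simplify on both sides first."""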
| func = te.create_prim_func(te_workload(), index_dtype_override) |
| if do_simplify: |
| simplify = tir.transform.Simplify() |
| func = simplify(tvm.IRModule.from_expr(func))["main"] |
| tir_workload = simplify(tvm.IRModule.from_expr(tir_workload))["main"] |
| tvm.ir.assert_structural_equal(func, tir_workload) |
    # make sure that we can create a schedule from the func
| s = tir.Schedule(func, debug_mask="all") |
| assert s |
| |
| |
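# Each workload below pairs a TE definition with the TVMScript PrimFunc that
# te.create_prim_func is expected to produce; _check_workload compares them structurally.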
| def te_matmul(): |
| k = te.reduce_axis((0, 128), "k") |
| A = te.placeholder((128, 128), name="A") |
| B = te.placeholder((128, 128), name="B") |
| C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") |
| return [A, B, C] |
| |
| |
| @T.prim_func |
| def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| A = T.match_buffer(a, (128, 128)) |
| B = T.match_buffer(b, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| |
| for i0, j0, k0 in T.grid(128, 128, 128): |
| with T.block(): |
| i, j, k = T.axis.remap("SSR", [i0, j0, k0]) |
| with T.init(): |
| C[i, j] = 0.0 |
| C[i, j] += A[i, k] * B[j, k] |
| |
| |
| @T.prim_func |
| def tir_matmul_int64( |
| A: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| B: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| C: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| ) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| for i0, j0, k0 in T.grid(T.int64(128), T.int64(128), T.int64(128)): |
| with T.block(): |
| i, j, k = T.axis.remap("SSR", [i0, j0, k0]) |
| with T.init(): |
| C[i, j] = 0.0 |
| C[i, j] += A[i, k] * B[j, k] |
| |
| |
| def test_matmul(): |
| _check_workload(te_matmul, tir_matmul) |
| |
| |
| def test_matmul_int64(): |
| _check_workload(te_matmul, tir_matmul_int64, index_dtype_override="int64") |
| |
| |
| def te_element_wise(): |
| A = te.placeholder((128, 128), name="A") |
| B = te.compute((128, 128), lambda x, y: A[x, y] * 2, name="B") |
| C = te.compute((128, 128), lambda x, y: B[x, y] + 1, name="C") |
| return [A, C] |
| |
| |
| @T.prim_func |
| def tir_element_wise(a: T.handle, c: T.handle) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| A = T.match_buffer(a, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| B = T.alloc_buffer((128, 128)) |
| |
| for i0, j0 in T.grid(128, 128): |
| with T.block(): |
| i, j = T.axis.remap("SS", [i0, j0]) |
| B[i, j] = A[i, j] * 2.0 |
| for i0, j0 in T.grid(128, 128): |
| with T.block(): |
| i, j = T.axis.remap("SS", [i0, j0]) |
| C[i, j] = B[i, j] + 1.0 |
| |
| |
| def test_element_wise(): |
| _check_workload(te_element_wise, tir_element_wise) |
| |
| |
| def te_conv2d(): |
| batch = 16 |
| in_channel = 16 |
| out_channel = 32 |
| size = 14 |
| kernel = 3 |
| |
| A = te.placeholder((batch, in_channel, size, size), name="A") |
| W = te.placeholder((in_channel, kernel, kernel, out_channel), name="W") |
| Apad = te.compute( |
| (batch, in_channel, size + 2, size + 2), |
| lambda nn, cc, yy, xx: tvm.tir.if_then_else( |
| tvm.tir.all(yy >= 1, yy - 1 < size, xx >= 1, xx - 1 < size), |
| A[nn, cc, yy - 1, xx - 1], |
| 0.0, |
| ), |
| name="Apad", |
| ) |
| rc = te.reduce_axis((0, in_channel), name="rc") |
| ry = te.reduce_axis((0, kernel), name="ry") |
| rx = te.reduce_axis((0, kernel), name="rx") |
| B = te.compute( |
| (batch, out_channel, size, size), |
| lambda nn, ff, yy, xx: te.sum( |
| Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff], axis=[rc, ry, rx] |
| ), |
| name="B", |
| ) |
| return [A, W, B] |
| |
| |
| @T.prim_func |
| def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| A = T.match_buffer(a, [16, 16, 14, 14]) |
| W = T.match_buffer(w, [16, 3, 3, 32]) |
| B = T.match_buffer(b, [16, 32, 14, 14]) |
| Apad = T.alloc_buffer([16, 16, 16, 16]) |
| |
| for n, c, y, x in T.grid(16, 16, 16, 16): |
| with T.block("Apad"): |
| nn, cc, yy, xx = T.axis.remap("SSSS", [n, c, y, x]) |
| Apad[nn, cc, yy, xx] = T.if_then_else( |
| 1 <= yy and yy < 15 and 1 <= xx and xx < 15, |
| A[nn, cc, yy - 1, xx - 1], |
| 0.0, |
| dtype="float32", |
| ) |
| for n, f, y, x, kc, ky, kx in T.grid(16, 32, 14, 14, 16, 3, 3): |
| with T.block("B"): |
| nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [n, f, y, x, kc, ky, kx]) |
| with T.init(): |
| B[nn, ff, yy, xx] = 0.0 |
| B[nn, ff, yy, xx] += Apad[nn, rc, yy + ry, xx + rx] * W[rc, ry, rx, ff] |
| |
| |
| def test_conv2d(): |
| _check_workload(te_conv2d, tir_conv2d) |
| |
| |
| def te_multi_output(): |
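    # A single te.compute returning a tuple yields two output tensors; they are
    # lowered into two blocks within the same loop nest ("B.v0" and "B.v1").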
| n = te.var("n") |
| m = te.var("m") |
| A0 = te.placeholder((m, n), name="A0") |
| A1 = te.placeholder((m, n), name="A1") |
| B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B") |
| return [A0, A1, B0, B1] |
| |
| |
| @T.prim_func |
| def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| m = T.int32() |
| n = T.int32() |
| A0 = T.match_buffer(a0, (m, n)) |
| A1 = T.match_buffer(a1, (m, n)) |
| B0 = T.match_buffer(b0, (m, n)) |
| B1 = T.match_buffer(b1, (m, n)) |
| |
| for i0, i1 in T.grid(m, n): |
| with T.block("B.v0"): |
| i, j = T.axis.remap("SS", [i0, i1]) |
| B0[i, j] = A0[i, j] + 2.0 |
| with T.block("B.v1"): |
| i, j = T.axis.remap("SS", [i0, i1]) |
| B1[i, j] = A1[i, j] * 3.0 |
| |
| |
| def test_multi_output(): |
| _check_workload(te_multi_output, tir_multi_output) |
| |
| |
| def te_extern(): |
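    # te.extern wraps an opaque packed-function call; the expected lowering is a
    # single opaque block that stages the buffers via tvm_stack_make_array
    # (see tir_extern below).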
| A = te.placeholder((128, 128), name="A") |
| B = te.placeholder((128, 128), name="B") |
| C = te.extern( |
| (128, 128), |
| [A, B], |
| lambda ins, outs: tvm.tir.call_packed( |
| "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], 0, 0 |
| ), |
| name="C", |
| ) |
| return [A, B, C] |
| |
| |
| @T.prim_func |
| def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| off1 = te.var("elem_offset") |
| off2 = te.var("elem_offset_1") |
| off3 = te.var("elem_offset_2") |
| A = T.match_buffer(a, (128, 128), elem_offset=off1) |
| B = T.match_buffer(b, (128, 128), elem_offset=off2) |
| C = T.match_buffer(c, (128, 128), elem_offset=off3) |
| # body |
| with T.block("C"): |
| T.reads() |
| T.writes() |
| T.evaluate( |
| T.tvm_call_packed( |
| "tvm.contrib.cblas.matmul", |
| T.tvm_stack_make_array( |
| A.data, |
| T.tvm_stack_make_shape(128, 128, dtype="handle"), |
| 0, |
| 2, |
| 0.0, |
| off1, |
| dtype="handle", |
| ), |
| T.tvm_stack_make_array( |
| B.data, |
| T.tvm_stack_make_shape(128, 128, dtype="handle"), |
| 0, |
| 2, |
| 0.0, |
| off2, |
| dtype="handle", |
| ), |
| T.tvm_stack_make_array( |
| C.data, |
| T.tvm_stack_make_shape(128, 128, dtype="handle"), |
| 0, |
| 2, |
| 0.0, |
| off3, |
| dtype="handle", |
| ), |
| 0, |
| 0, |
| dtype="int32", |
| ) |
| ) |
| |
| |
| def test_extern(): |
| _check_workload(te_extern, tir_extern) |
| |
| |
| def te_reordered_matmul(): |
| k = te.reduce_axis((0, 128), "k") |
| A = te.placeholder((128, 128), name="A") |
| B = te.placeholder((128, 128), name="B") |
| C = te.compute((128, 128), lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), name="C") |
| return [C, A, B] |
| |
| |
| @T.prim_func |
| def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| A = T.match_buffer(a, (128, 128)) |
| B = T.match_buffer(b, (128, 128)) |
| C = T.match_buffer(c, (128, 128)) |
| |
| for i0, j0, k0 in T.grid(128, 128, 128): |
| with T.block(): |
| i, j, k = T.axis.remap("SSR", [i0, j0, k0]) |
| with T.init(): |
| C[i, j] = 0.0 |
| C[i, j] += A[i, k] * B[j, k] |
| |
| |
| def test_arg_order(): |
| _check_workload(te_reordered_matmul, tir_reordered_matmul) |
| |
| |
| def te_scan(): |
| m = te.var("m") |
| n = te.var("n") |
| X = te.placeholder((m, n), name="X") |
| s_state = te.placeholder((m, n)) |
| s_init = te.compute((1, n), lambda _, i: X[0, i]) |
| s_update = te.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i]) |
| s_scan = tvm.te.scan(s_init, s_update, s_state, inputs=[X]) |
| return [X, s_scan] |
| |
| |
def test_error_reporting():
    # te.ScanOp is not supported by te.create_prim_func; expect a TypeError.
    with pytest.raises(TypeError, match="Unsupported Operation: te.ScanOp."):
        te.create_prim_func(te_scan())
| |
| |
| def test_constant(): |
| M = 11 |
| A = te.placeholder((M,), name="A") |
| B = te.compute(tuple(), lambda: 2, name="B") |
| # Manually craft ProducerLoad because `B[]` is not allowed. |
| C = te.compute( |
| (M,), lambda x: A[x] + tvm.tir.expr.ProducerLoad(B, []), name="C", tag="broadcast" |
| ) |
| |
| func = te.create_prim_func([C, A]) |
| func = tvm.compile(func) |
| a_np = np.random.uniform(size=(M,)).astype(A.dtype) |
| c = tvm.runtime.tensor(np.zeros(M, dtype=C.dtype)) |
    func(c, tvm.runtime.tensor(a_np))
| tvm.testing.assert_allclose(a_np + 2, c.numpy()) |
| |
| |
| def test_data_dependent_access(): |
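    # C gathers from A using indices stored in B; the generated PrimFunc must
    # still compile and produce correct results.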
| A = te.placeholder((10,), name="A") |
| B = te.placeholder((10,), name="B", dtype="int32") |
| C = te.compute((10,), lambda i: A[B[i]]) |
| |
| func = te.create_prim_func([C, A, B]) |
| func = tvm.compile(func) |
| |
| a_np = np.random.uniform(size=(10,)).astype(A.dtype) |
| b_np = np.arange(10, dtype=B.dtype) |
| c = tvm.runtime.tensor(np.zeros(10, dtype=C.dtype)) |
| func(c, tvm.runtime.tensor(a_np), tvm.runtime.tensor(b_np)) |
| tvm.testing.assert_allclose(a_np[b_np], c.numpy()) |
| |
| |
| def test_select_simplify(): |
| placeholder = te.placeholder([1, 128, 10, 10, 4], dtype="float32") |
| tensor = topi.nn.adaptive_pool(placeholder, [1, 1], "avg", "NCHW4c") |
| result = te.create_prim_func([placeholder, tensor]) |
| script_func = result.script() |
| # There should be no Select |
| assert script_func.find("Select") == -1 |
| # There should be no undefined vars |
| assert script_func.find("Var") == -1 |
| |
| |
| def test_tensor_attr(): |
| k = te.reduce_axis((0, 128), "k") |
| A = te.placeholder((128, 128), name="A") |
| B = te.placeholder((128, 128), name="B") |
| C = te.compute( |
| (128, 128), |
| lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), |
| name="C", |
| attrs={"layout_free_placeholders": [B]}, |
| ) |
| func = te.create_prim_func([A, B, C]) |
| rt_func = tvm.script.from_source(func.script()) |
| tvm.ir.assert_structural_equal(func, rt_func) |
| |
| |
| @T.prim_func |
| def expected_layout_attr( |
| A: T.Buffer((128, 128), "float32"), |
| B: T.Buffer((128, 128), "float32"), |
| D: T.Buffer((128, 128), "float32"), |
| ) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) |
| C = T.alloc_buffer([128, 128], dtype="float32") |
| for i0, i1, i2 in T.grid(128, 128, 128): |
| with T.block("C"): |
| x, y, k = T.axis.remap("SSR", [i0, i1, i2]) |
| with T.init(): |
| C[x, y] = T.float32(0) |
| C[x, y] = C[x, y] + A[x, k] * B[y, k] |
| for i0, i1 in T.grid(128, 128): |
| with T.block("D"): |
| T.block_attr({"layout_free_placeholders": [C]}) |
| x, y = T.axis.remap("SS", [i0, i1]) |
| D[x, y] = C[x, y] + T.float32(1) |
| |
| |
| @T.prim_func |
| def expected_layout_attr_int64( |
| A: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| B: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| D: T.Buffer((T.int64(128), T.int64(128)), "float32"), |
| ): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) |
| C = T.alloc_buffer([T.int64(128), T.int64(128)], dtype="float32") |
| for x, y, k in T.grid(T.int64(128), T.int64(128), T.int64(128)): |
| with T.block("C"): |
| v_x, v_y, v_k = T.axis.remap("SSR", [x, y, k]) |
| T.reads(A[v_x, v_k], B[v_y, v_k]) |
| T.writes(C[v_x, v_y]) |
| with T.init(): |
| C[v_x, v_y] = T.float32(0) |
| C[v_x, v_y] = C[v_x, v_y] + A[v_x, v_k] * B[v_y, v_k] |
| for x, y in T.grid(T.int64(128), T.int64(128)): |
| with T.block("D"): |
| T.block_attr({"layout_free_placeholders": [C]}) |
| v_x, v_y = T.axis.remap("SS", [x, y]) |
| T.reads(C[v_x, v_y]) |
| T.writes(D[v_x, v_y]) |
| D[v_x, v_y] = C[v_x, v_y] + T.float32(1) |
| |
| |
| @pytest.mark.parametrize( |
| "index_dtype_override, expected", |
| [(None, expected_layout_attr), ("int64", expected_layout_attr_int64)], |
| ) |
| def test_tensor_layout_attr(index_dtype_override, expected): |
| k = te.reduce_axis((0, 128), "k") |
| A = te.placeholder((128, 128), name="A") |
| B = te.placeholder((128, 128), name="B") |
| C = te.compute( |
| (128, 128), |
| lambda x, y: te.sum(A[x, k] * B[y, k], axis=k), |
| name="C", |
| attrs={"layout_free_placeholders": [B]}, |
| ) |
| D = te.compute( |
| (128, 128), |
| lambda x, y: C[x, y] + 1, |
| name="D", |
| attrs={"layout_free_placeholders": [C]}, |
| ) |
| func = te.create_prim_func([A, B, D], index_dtype_override=index_dtype_override) |
| tvm.ir.assert_structural_equal(func, expected) |
| |
| |
| def te_argmax_idx_val(): |
| def f_combine(x, y): |
| lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) |
| rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) |
| return lhs, rhs |
| |
| def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): |
| return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) |
| |
| argmax = te.comm_reducer(f_combine, f_identity, name="argmax") |
| |
| m = te.var("m") |
| n = te.var("n") |
| idx = te.placeholder((m, n), name="idx", dtype="int32") |
| val = te.placeholder((m, n), name="val", dtype="float32") |
| k = te.reduce_axis((0, n), "k") |
| max_idx, max_val = te.compute( |
| (m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="argmax" |
| ) |
| return [idx, val, max_idx, max_val] |
| |
| |
| @T.prim_func |
| def tir_argmax_idx_val( |
| var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle |
| ) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| m = T.int32() |
| n = T.int32() |
| idx = T.match_buffer(var_idx, [m, n], dtype="int32") |
| val = T.match_buffer(var_val, [m, n], dtype="float32") |
| argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="int32") |
| argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="float32") |
| for i0, i1 in T.grid(m, n): |
| with T.block("argmax"): |
| i, k = T.axis.remap("SR", [i0, i1]) |
| T.reads(val[i, k], idx[i, k]) |
| T.writes(argmax_v0[i], argmax_v1[i]) |
| with T.init(): |
| argmax_v0[i] = T.int32(-1) |
| argmax_v1[i] = T.min_value("float32") |
| v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) |
| v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) |
| argmax_v0[i] = v_argmax_v0 |
| argmax_v1[i] = v_argmax_v1 |
| |
| |
| def te_argmax_val_idx(): |
| def f_combine(x, y): |
| lhs = tvm.tir.Select((x[0] >= y[0]), x[0], y[0]) |
| rhs = tvm.tir.Select((x[0] >= y[0]), x[1], y[1]) |
| return lhs, rhs |
| |
| def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): |
| return tvm.te.min_value(dtype0), tvm.tir.const(-1, dtype1) |
| |
| argmax = te.comm_reducer(f_combine, f_identity, name="argmax") |
| |
| m = te.var("m") |
| n = te.var("n") |
| val = te.placeholder((m, n), name="val", dtype="float32") |
| idx = te.placeholder((m, n), name="idx", dtype="int32") |
| k = te.reduce_axis((0, n), "k") |
| max_val, max_idx = te.compute( |
| (m,), lambda i: argmax((val[i, k], idx[i, k]), axis=k), name="argmax" |
| ) |
| return [val, idx, max_val, max_idx] |
| |
| |
| @T.prim_func |
| def tir_argmax_val_idx( |
| var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle |
| ) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| m = T.int32() |
| n = T.int32() |
| val = T.match_buffer(var_val, [m, n], dtype="float32") |
| idx = T.match_buffer(var_idx, [m, n], dtype="int32") |
| argmax_v0 = T.match_buffer(var_argmax_v0, [m], dtype="float32") |
| argmax_v1 = T.match_buffer(var_argmax_v1, [m], dtype="int32") |
| for i0, i1 in T.grid(m, n): |
| with T.block("argmax"): |
| i, k = T.axis.remap("SR", [i0, i1]) |
| T.reads(val[i, k], idx[i, k]) |
| T.writes(argmax_v0[i], argmax_v1[i]) |
| with T.init(): |
| argmax_v0[i] = T.min_value("float32") |
| argmax_v1[i] = T.int32(-1) |
| v_argmax_v0: T.float32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v0[i], val[i, k]) |
| v_argmax_v1: T.int32 = T.Select(argmax_v0[i] >= val[i, k], argmax_v1[i], idx[i, k]) |
| argmax_v0[i] = v_argmax_v0 |
| argmax_v1[i] = v_argmax_v1 |
| |
| |
| def test_argmax_idx_val(): |
| _check_workload(te_argmax_idx_val, tir_argmax_idx_val) |
| |
| |
| def test_argmax_val_idx(): |
| _check_workload(te_argmax_val_idx, tir_argmax_val_idx) |
| |
| |
| def test_int64_indices(): |
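    # With an int64 shape variable, the generated loop variable, min, and extent
    # should all be int64 as well.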
| n = te.var("n", "int64") |
| A = te.placeholder((n,), name="A") |
| B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") |
| prim_func = te.create_prim_func([A, B]) |
| loop = prim_func.body.block.body |
| assert loop.loop_var.dtype == "int64" |
| assert loop.min.dtype == "int64" |
| assert loop.extent.dtype == "int64" |
| |
| |
| def test_zero_dim_add(): |
| def te_func(): |
| a = te.placeholder((), name="a", dtype="int32") |
| b = te.placeholder((), name="b", dtype="int32") |
| c = te.compute(a.shape, lambda *i: a(*i) + b(*i), name="c") |
| return [a, b, c] |
| |
| @T.prim_func |
| def expected( |
| a: T.Buffer((), "int32"), |
| b: T.Buffer((), "int32"), |
| c: T.Buffer((), "int32"), |
| ) -> None: |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| with T.block("root"): |
| T.reads() |
| T.writes() |
| with T.block("c"): |
| vi = T.axis.spatial(1, 0) |
| T.reads(a[()], b[()]) |
| T.writes(c[()]) |
| c[()] = a[()] + b[()] |
| |
| _check_workload(te_func, expected) |
| |
| |
| def te_reshape(): |
    # TOPI can generate reshapes like the following, so we cover this case.
| A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A") |
| B = topi.reshape(A, (4, 2)) |
| return [A, B] |
| |
| |
| @T.prim_func |
| def tir_reshape( |
| A: T.Buffer((T.int64(2), T.int64(4)), "float32"), |
| T_reshape: T.Buffer((T.int64(4), T.int64(2)), "float32"), |
| ): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| for i0, i1 in T.grid(T.int64(4), T.int64(2)): |
| with T.block("T_reshape"): |
| ax0, ax1 = T.axis.remap("SS", [i0, i1]) |
| T.reads( |
| A[ |
| (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4), |
| (ax0 * T.int64(2) + ax1) % T.int64(4), |
| ] |
| ) |
| T.writes(T_reshape[ax0, ax1]) |
| T_reshape[ax0, ax1] = A[ |
| (ax0 * T.int64(2) + ax1) % T.int64(8) // T.int64(4), |
| (ax0 * T.int64(2) + ax1) % T.int64(4), |
| ] |
| |
| |
| def test_reshape(): |
| _check_workload(te_reshape, tir_reshape, index_dtype_override="int64") |
| |
| |
| def te_resize2d_symbolic(): |
| oh = tir.Var("oh", "int64") |
| ow = tir.Var("ow", "int64") |
| roi = (0.0, 0.0, 0.0, 0.0) |
| A = te.placeholder((2, 3, 128, 128), "float32", name="A") |
| B = topi.image.resize2d( |
| A, |
| roi, |
| size=(oh, ow), |
| method="nearest_neighbor", |
| coordinate_transformation_mode="asymmetric", |
| rounding_method="round", |
| ) |
| return [A, B] |
| |
| |
| @T.prim_func |
| def tir_resize2d_symbolic( |
| A: T.Buffer((T.int64(2), T.int64(3), T.int64(128), T.int64(128)), "float32"), |
| var_resize: T.handle, |
| ): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| oh = T.int64() |
| ow = T.int64() |
| resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], dtype="float32") |
| for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), oh, ow): |
| with T.block("resize"): |
| v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| T.reads(A[v_i0, v_i1, T.int64(0) : T.int64(128), T.int64(0) : T.int64(128)]) |
| T.writes(resize[v_i0, v_i1, v_i2, v_i3]) |
| resize[v_i0, v_i1, v_i2, v_i3] = A[ |
| v_i0, |
| v_i1, |
| T.max( |
| T.min( |
| T.Cast( |
| "int64", |
| T.round( |
| T.float32(128) / T.Cast("float32", oh) * T.Cast("float32", v_i2), |
| dtype="float32", |
| ), |
| ), |
| T.int64(127), |
| ), |
| T.int64(0), |
| ), |
| T.max( |
| T.min( |
| T.Cast( |
| "int64", |
| T.round( |
| T.float32(128) / T.Cast("float32", ow) * T.Cast("float32", v_i3), |
| dtype="float32", |
| ), |
| ), |
| T.int64(127), |
| ), |
| T.int64(0), |
| ), |
| ] |
| |
| |
| def test_resize2d_symbolic(): |
| _check_workload(te_resize2d_symbolic, tir_resize2d_symbolic, index_dtype_override="int64") |
| |
| |
| def test_extern_with_explicit_buffer_access(): |
| def te_extern(): |
| A = te.placeholder((128, 128), name="A") |
| B = te.placeholder((128, 128), name="B") |
| P = te.placeholder((1,), name="P") |
| C = te.extern( |
| (128, 128), |
| [A, B, P], |
| lambda ins, outs: tvm.tir.call_extern( |
| "", "myfunc", ins[0].data, ins[1].data, outs[0].data, ins[2][0] |
| ), |
| name="C", |
| ) |
| return [A, B, P, C] |
| |
| @T.prim_func |
| def tir_extern(var_A: T.handle, var_B: T.handle, var_P: T.handle, var_C: T.handle): |
| T.func_attr({"global_symbol": "main", "tir.noalias": True}) |
| A = T.match_buffer(var_A, [128, 128], dtype="float32", offset_factor=1) |
| B = T.match_buffer(var_B, [128, 128], dtype="float32", offset_factor=1) |
| P = T.match_buffer(var_P, [1], dtype="float32", offset_factor=1) |
| C = T.match_buffer(var_C, [128, 128], dtype="float32", offset_factor=1) |
| with T.block("C"): |
| T.reads() |
| T.writes() |
| T.call_extern("myfunc", A.data, B.data, C.data, P[0], dtype="") |
| |
| _check_workload(te_extern, tir_extern) |
| |
| |
| def te_slice_with_var_input(): |
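    # `idx` is a free te.var used as an output shape; it should become a scalar
    # parameter of the generated PrimFunc (see tir_slice_with_var_input).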
| idx = te.var("idx", dtype="int64") |
| m = te.var("m", dtype="int64") |
| n = te.var("n", dtype="int64") |
| tensor = te.placeholder((m, n), name="tensor") |
| slice0 = te.compute((idx, n), lambda i, j: tensor[i, j], name="slice") |
| return [tensor, idx, slice0] |
| |
| |
| @T.prim_func |
| def tir_slice_with_var_input(var_tensor: T.handle, idx: T.int64, var_slice: T.handle): |
| T.func_attr({"tir.noalias": True, "global_symbol": "main"}) |
| m, n = T.int64(), T.int64() |
| tensor = T.match_buffer(var_tensor, (m, n)) |
| slice = T.match_buffer(var_slice, (idx, n)) |
| # with T.block("root"): |
| for i, j in T.grid(idx, n): |
| with T.block("slice"): |
| v_i = T.axis.spatial(idx, i) |
| v_j = T.axis.spatial(n, j) |
| T.reads(tensor[v_i, v_j]) |
| T.writes(slice[v_i, v_j]) |
| slice[v_i, v_j] = tensor[v_i, v_j] |
| |
| |
| def test_with_var_input(): |
| _check_workload(te_slice_with_var_input, tir_slice_with_var_input, index_dtype_override="int64") |
| |
| |
| def test_loop_aware_initial_value(): |
    """Test that the reduction initial value can depend on the spatial iter position"""
| |
| @T.prim_func |
| def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): |
| T.func_attr({"tir.noalias": True, "global_symbol": "main"}) |
| a = T.match_buffer(var_a, (5, 5)) |
| b = T.match_buffer(var_b, (5,)) |
| sum_red = T.match_buffer(var_sum_red, (5,)) |
| for i, ax in T.grid(5, 5): |
| with T.block("sum_red"): |
| v_i, v_ax = T.axis.remap("SR", [i, ax]) |
| T.reads(b[v_i], a[v_i, v_ax]) |
| T.writes(sum_red[v_i]) |
| with T.init(): |
| sum_red[v_i] = b[v_i] |
| sum_red[v_i] = sum_red[v_i] + a[v_i, v_ax] |
| |
| def te_workload(): |
| data = te.placeholder((5, 5), "float32", "a") |
| init = te.placeholder((5,), "float32", "b") |
| ax = te.reduce_axis((0, 5), "ax") |
| sum_red = te.compute( |
| (5,), |
| lambda i: te.comm_reducer( |
| lambda x, y: x + y, |
| lambda t: init[i], |
| )(data[i, ax], axis=[ax]), |
| name="sum_red", |
| ) |
| return [data, init, sum_red] |
| |
| _check_workload(te_workload, tir_workload) |
| |
| |
| def test_loop_aware_reducer_combiner(): |
    """Test that the reduction combiner can depend on the spatial iter position"""
| |
| @T.prim_func |
| def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): |
| T.func_attr({"tir.noalias": True, "global_symbol": "main"}) |
| a = T.match_buffer(var_a, (5, 5)) |
| b = T.match_buffer(var_b, (5,)) |
| sum_red = T.match_buffer(var_sum_red, (5,)) |
| for i, ax in T.grid(5, 5): |
| with T.block("sum_red"): |
| v_i = T.axis.spatial(5, i) |
| v_ax = T.axis.reduce(5, ax) |
| T.reads(a[v_i, 0:5]) |
| T.writes(sum_red[v_i]) |
| with T.init(): |
| sum_red[v_i] = T.float32(0.0) |
| sum_red[v_i] = T.if_then_else( |
| a[v_i, sum_red[v_i]] < a[v_i, v_ax], sum_red[v_i], T.Cast("float32", v_ax) |
| ) |
| |
| def te_workload(): |
| data = te.placeholder((5, 5), "float32", "a") |
| init = te.placeholder((5,), "float32", "b") |
| ax = te.reduce_axis((0, 5), "ax") |
| sum_red = te.compute( |
| (5,), |
| lambda i: te.comm_reducer( |
| lambda x, y: te.if_then_else(data[i, x] < y, x, ax), |
| lambda _: te.const(0, "float32"), |
| )(data[i, ax], axis=[ax]), |
| name="sum_red", |
| ) |
| return [data, init, sum_red] |
| |
| _check_workload(te_workload, tir_workload) |
| |
| |
| def test_adaptive_pooling_window(): |
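    # Adaptive pooling has data-dependent window extents, so the reduction is
    # lowered into a nested block whose reduce loop extents depend on the
    # spatial indices.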
| @T.prim_func |
| def tir_workload( |
| x: T.Buffer((1, 1024, 16, 40), "float32"), |
| adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"), |
| ): |
| T.func_attr({"tir.noalias": True, "global_symbol": "main"}) |
| # fmt: off |
| adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30)) |
| for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): |
| with T.block("adaptive_pool_sum_1"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)]) |
| T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) |
| for rv0, rv1 in T.grid((v_ax2 % 3 * 4 + 16) // 12 + 1, (v_ax3 % 3 * 10 + 40) // 30 + 1): |
| with T.block("adaptive_pool_sum"): |
| v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0) |
| v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1) |
| v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2) |
| v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3) |
| v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1]) |
| T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]) |
| T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1]) |
| with T.init(): |
| adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0) |
| adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1] |
| for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): |
| with T.block("adaptive_pool_avg"): |
| v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) |
| T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) |
| T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) |
| adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1)) |
| # fmt: on |
| |
| def te_workload(): |
| x = te.placeholder([1, 1024, 16, 40], "float32", "x") |
| y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg") |
| return [x, y] |
| |
| _check_workload(te_workload, tir_workload) |
| |
| |
| def test_global_pool(): |
    # Regression test for issue #17938.
| data = te.placeholder((1, 1, 32, 32), dtype="int8", name="data") |
| op_output = topi.nn.global_pool(data=data, pool_type="avg", layout="NCHW") |
| f = te.create_prim_func([data, op_output]) |
| assert f |
| |
| |
| def test_nested_reduce_domain_dependency(): |
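    # Reduce axes whose extents depend on other iters (spatial or reduce) are
    # lowered into nested reduction blocks, one level per dependent reduce axis
    # (see tir_workload below).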
| @T.prim_func |
| def tir_workload( |
| x: T.Buffer((8, 8, 8, 8, 8), "float32"), compute: T.Buffer((8, 8, 8), "float32") |
| ): |
| T.func_attr({"tir.noalias": True, "global_symbol": "main"}) |
| for i0, i1, i2 in T.grid(8, 8, 8): |
| with T.block("compute_2"): |
| v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) |
| T.reads(x[v_i0, v_i1, v_i2, 0:v_i1, 0 : v_i1 - 1]) |
| T.writes(compute[v_i0, v_i1, v_i2]) |
| for rv in range(v_i1): |
| with T.block("compute_1"): |
| v_i0_1 = T.axis.spatial((v_i0, v_i0 + 1), v_i0) |
| v_i1_1 = T.axis.spatial((v_i1, v_i1 + 1), v_i1) |
| v_i2_1 = T.axis.spatial((v_i2, v_i2 + 1), v_i2) |
| v_rv = T.axis.reduce(v_i1, rv) |
| T.reads(x[v_i0_1, v_i1_1, v_i2_1, v_rv, 0:v_rv]) |
| T.writes(compute[v_i0_1, v_i1_1, v_i2_1]) |
| with T.init(): |
| compute[v_i0_1, v_i1_1, v_i2_1] = T.float32(0.0) |
| for rv_1 in range(v_rv): |
| with T.block("compute"): |
| v_i0_2 = T.axis.spatial((v_i0_1, v_i0_1 + 1), v_i0_1) |
| v_i1_2 = T.axis.spatial((v_i1_1, v_i1_1 + 1), v_i1_1) |
| v_i2_2 = T.axis.spatial((v_i2_1, v_i2_1 + 1), v_i2_1) |
| v_rv_1 = T.axis.reduce((v_rv, v_rv + 1), v_rv) |
| v_rv_2 = T.axis.reduce(v_rv, rv_1) |
| T.reads(x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2]) |
| T.writes(compute[v_i0_2, v_i1_2, v_i2_2]) |
| with T.init(): |
| compute[v_i0_2, v_i1_2, v_i2_2] = T.float32(0.0) |
| compute[v_i0_2, v_i1_2, v_i2_2] = ( |
| compute[v_i0_2, v_i1_2, v_i2_2] |
| + x[v_i0_2, v_i1_2, v_i2_2, v_rv_1, v_rv_2] |
| ) |
| |
| def te_workload(): |
| x = te.placeholder([8, 8, 8, 8, 8], "float32", "x") |
| |
| def fcompute(*axes): |
| r1 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, axes[1])) |
| r2 = te.reduce_axis(tvm.ir.Range.from_min_extent(0, r1)) |
| all_axes = [*axes, r1, r2] |
| return te.sum(x(*all_axes), [r1, r2]) |
| |
| y = te.compute([8, 8, 8], fcompute) |
| return [x, y] |
| |
| _check_workload(te_workload, tir_workload) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |