| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import pytest |
| import tvm |
| import tvm.testing |
| from tvm import te |
| from tvm.ir.module import IRModule |
| from tvm.script import tir as T |
| import numpy |
| |
| |
| def collect_visit(stmt, f): |
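    """Post-order visit every node of stmt, returning the list of f(node) results."""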
| ret = [] |
| tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) |
| return ret |
| |
| |
| def test_basic(): |
| n = te.size_var("n") |
| A = te.placeholder((n,), name="A") |
| B = te.placeholder((n,), name="B") |
| |
| T = te.compute((n,), lambda i: A[i] + B[i]) |
| s = te.create_schedule(T.op) |
| xo, xi = s[T].split(T.op.axis[0], factor=4) |
| |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt).with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"] |
| |
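    # The partitioned main loop (stmt.body.body[0]) must be free of guards,
    # while the tail loop (stmt.body.body[1]) still carries an IfThenElse.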
| assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_const_loop(): |
| n = 21 |
| A = te.placeholder((n,), name="A") |
| B = te.placeholder((n,), name="B") |
| |
| T = te.compute((n,), lambda i: A[i] + B[i]) |
| s = te.create_schedule(T.op) |
| xo, xi = s[T].split(T.op.axis[0], factor=4) |
| |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_no_unroll_loop(): |
| n = 21 |
| A = te.placeholder((n,), name="A") |
| B = te.placeholder((n,), name="B") |
| |
| T = te.compute((n,), lambda i: A[i] + B[i]) |
| s = te.create_schedule(T.op) |
| xo, xi = s[T].split(T.op.axis[0], factor=4) |
| |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| with tvm.transform.PassContext( |
| config={ |
| "tir.LoopPartition": { |
| "partition_const_loop": True, |
| "no_unroll_loop_with_extent_one": True, |
| } |
| } |
| ): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| mod = tvm.tir.transform.Simplify()(mod) |
| stmt = tvm.tir.transform.RemoveNoOp()(mod)["main"].body |
| |
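    # With extent-one loops kept as For nodes (not unrolled), exactly four
    # loops remain after partitioning.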
| assert sum(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.For))) == 4 |
| |
| |
| def test_multi_loop(): |
| ib = tvm.tir.ir_builder.create() |
| m = te.size_var("m") |
| n = te.size_var("n") |
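    # ib.likely() tags a condition as "usually true"; LoopPartition uses such
    # likely-marked conditions as candidates for splitting the loop nest.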
| with ib.for_range(0, 4, "i") as i: |
| with ib.for_range(0, n, "j") as j: |
| with ib.for_range(0, m, "k") as k: |
| with ib.if_scope(ib.likely(i * m + j + k < n)): |
| ib.emit(tvm.tir.Evaluate(m)) |
| with ib.else_scope(): |
| ib.emit(tvm.tir.Evaluate(n)) |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n, m], stmt).with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_multi_if(): |
| ib = tvm.tir.ir_builder.create() |
| m = te.size_var("m") |
| n = te.size_var("n") |
| with ib.for_range(0, 4, "i") as i: |
| with ib.for_range(0, n, "j") as j: |
| with ib.for_range(0, m, "k") as k: |
| with ib.if_scope(ib.likely(i * m + j + k < n)): |
| ib.emit(tvm.tir.Evaluate(m)) |
| with ib.else_scope(): |
| ib.emit(tvm.tir.Evaluate(n)) |
| with ib.if_scope(ib.likely(i * m + j - k < n)): |
| ib.emit(tvm.tir.Evaluate(m)) |
| with ib.else_scope(): |
| ib.emit(tvm.tir.Evaluate(n)) |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_thread_axis(): |
| m = te.size_var("m") |
| l = te.size_var("l") |
| A = te.placeholder((m, l), name="A") |
| B = te.compute((m, l), lambda i, j: A[i, j] + 3, name="B") |
| s = te.create_schedule(B.op) |
| |
| s[B].set_scope("shared") |
| num_thread = 16 |
| xo, xi = s[B].split(B.op.axis[0], 32) |
| xi0, xi1 = s[B].split(xi, nparts=num_thread) |
| s[B].bind(xi0, te.thread_axis("threadIdx.x")) |
| |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"] |
| |
| assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_vectorize(): |
| n = te.size_var("n") |
| A = te.placeholder((n,), name="A") |
| B = te.placeholder((n,), name="B") |
| bias = te.size_var("bias", dtype="float32") |
| scale = te.size_var("scale", dtype="float32") |
| C = te.compute(A.shape, lambda *i: A(*i) + B(*i) * scale + bias, name="C") |
| # schedule |
| s = te.create_schedule(C.op) |
    # create iter vars and bind them to GPU thread axes.
| num_thread = 32 |
| bx, x = s[C].split(C.op.axis[0], factor=num_thread * 4) |
| tx, x = s[C].split(x, nparts=num_thread) |
| _, x = s[C].split(x, factor=4) |
| s[C].bind(bx, te.thread_axis("blockIdx.x")) |
| s[C].bind(tx, te.thread_axis("threadIdx.x")) |
| s[C].vectorize(x) |
| stmt = tvm.lower(s, [A, B], name="main")["main"] |
| body = stmt.body.body.body.body |
| assert x.var.name not in str(body.condition) |
| assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) |
| |
| |
| def test_condition(): |
| ib = tvm.tir.ir_builder.create() |
| m = te.size_var("m") |
| n = te.size_var("n") |
| with ib.for_range(0, tvm.tir.truncdiv(n + 3, 4), "i") as i: |
| with ib.for_range(0, 4, "j") as j: |
| ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(i * 4 + j < n), m, n))) |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt).with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))) |
| |
| |
| def test_condition_EQ(): |
| ib = tvm.tir.ir_builder.create() |
| m = te.size_var("m") |
| n = te.size_var("n") |
| with ib.for_range(0, 10, "i") as i: |
| ib.emit(tvm.tir.Evaluate(tvm.tir.Select(ib.likely(tvm.tir.EQ(i, 5)), m, n))) |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt).with_attr("global_symbol", "main")) |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))) |
| |
| |
| def test_thread_axis2(): |
| n = tvm.runtime.convert(4096) |
| m = te.size_var("m") |
| A = te.placeholder((n,), name="A") |
| B = te.placeholder((n,), name="B") |
| C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") |
| s = te.create_schedule(C.op) |
| num_thread = 32 |
| bx, x = s[C].split(C.op.axis[0], factor=32) |
| tx, x = s[C].split(x, nparts=num_thread) |
| _, x = s[C].split(x, factor=m) |
| s[C].bind(bx, te.thread_axis("blockIdx.x")) |
| s[C].bind(tx, te.thread_axis("threadIdx.x")) |
| stmt = tvm.lower(s, [A, B], name="main")["main"] |
| for_body = stmt.body.body.body.body[0] |
| assert "threadIdx" not in str(for_body.extent) |
| |
| |
| def test_everything_during_deduction(): |
| m = te.size_var("m") |
| n = te.size_var("n") |
| ib = tvm.tir.ir_builder.create() |
| with ib.for_range(0, n, "i") as i: |
| with ib.for_range(0, 32, "j") as j: |
| with ib.if_scope(ib.likely(tvm.tir.truncdiv(i, j) < m)): |
                # Bound deduction for this condition yields the whole iteration
                # domain ("everything"), so the guard cannot be partitioned away.
| ib.emit(tvm.tir.Evaluate(m)) |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([m, n], stmt).with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert isinstance(stmt.body.body, tvm.tir.IfThenElse) |
| |
| |
| def test_single_likely(): |
| n = 60 |
| A = te.placeholder((n,), name="A") |
| B = te.placeholder((n,), name="B") |
| |
| T = te.compute((n,), lambda i: A[i] + B[i]) |
| s = te.create_schedule(T.op) |
| x = T.op.axis[0] |
| xo, xi = s[T].split(x, factor=16) |
| |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_multi_likely(): |
| n = 94 |
| m = 62 |
| A = te.placeholder((n, m), name="A") |
| B = te.placeholder((n, m), name="B") |
| |
| T = te.compute((n, m), lambda i, j: A[i, j] + B[i, j]) |
| s = te.create_schedule(T.op) |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| x, y = T.op.axis |
| xo, xi = s[T].split(x, factor=16) |
| yo, yi = s[T].split(y, factor=16) |
| s[T].reorder(xo, yo, xi, yi) |
| |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_oneD_pool(): |
| m = te.size_var("m") |
| ib = tvm.tir.ir_builder.create() |
    data = ib.pointer("float32", name="data")
    out = ib.pointer("float32", name="out")
| with ib.for_range(0, 16, "ow") as ow: |
| with ib.for_range(0, 3, "kw") as kw: |
| with ib.if_scope(ib.likely(ow > 0)): |
| with ib.if_scope(ib.likely(ow < 15)): |
| out[ow] = tvm.te.max(out[ow], data[ow + kw - 1]) |
| with ib.for_range(0, 16, "ow") as ow: |
| with ib.for_range(0, 3, "kw") as kw: |
| with ib.if_scope(ib.likely(ow < 1)): |
| with ib.if_scope(ib.likely(kw > 0)): |
| out[ow] = tvm.te.max(out[ow], data[ow + kw - 1]) |
| with ib.for_range(0, 16, "ow") as ow: |
| with ib.for_range(0, 3, "kw") as kw: |
| with ib.if_scope(ib.likely(ow > 14)): |
| with ib.if_scope(ib.likely(kw < 2)): |
| out[ow] = tvm.te.max(out[ow], data[ow + kw - 1]) |
| |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr( |
| tvm.tir.PrimFunc([m, data, out], stmt).with_attr("global_symbol", "main") |
| ) |
| |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_cce_loop_1(): |
| ib = tvm.tir.ir_builder.create() |
| dtype = "float16" |
| n = 514 |
| m = 514 |
| _A = te.placeholder((n * m,), name="A") |
| Ab = tvm.tir.decl_buffer((n * m,), dtype, name="A") |
| A = ib.buffer_ptr(Ab) |
| _B = te.placeholder((n * m,), name="B") |
| Bb = tvm.tir.decl_buffer((n * m,), dtype, name="B") |
| B = ib.buffer_ptr(Bb) |
    # 11 x 160 tiled loops; the likely-guard keeps the flat index i * 160 + j below 1600
| with ib.for_range(0, 11, name="i") as i: |
| with ib.for_range(0, 160, name="j") as j: |
| with ib.if_scope(ib.likely(((i * 160) + j) < 1600)): |
| A[(i + 1) * m + j + 1] = ( |
| B[(i) * m + j + 1] + B[(i + 1) * m + j + 1] + B[(i + 2) * m + j + 1] |
| ) |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr( |
| tvm.tir.PrimFunc([Ab, Bb], stmt).with_attr("global_symbol", "main") |
| ) |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_cce_loop_2(): |
| ib = tvm.tir.ir_builder.create() |
    length = 112
    tile = 32
    loop = (length + tile - 1) // tile
    with ib.for_range(0, loop, "i") as i:
        head = i * tile
        with ib.if_scope(ib.likely(head + tile > length)):
            tail = length
| ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail)) |
| with ib.else_scope(): |
| tail = head + tile |
| ib.emit(tvm.tir.call_extern("float32", "cce_intrisic", head, tail)) |
| |
| stmt = ib.get() |
| |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_cce_loop_3(): |
| ib = tvm.tir.ir_builder.create() |
| loop1 = 4 |
| loop2 = 9998 |
| tile = 39991 |
| with ib.for_range(0, loop2, "i") as i: |
| with ib.for_range(0, loop1, "j") as j: |
| head1 = i |
| head2 = j |
| with ib.if_scope(ib.likely(head1 * loop1 + head2 < tile)): |
| ib.emit(tvm.tir.call_extern("float16", "cce_intrisic", head1)) |
| |
| stmt = ib.get() |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
| def test_conv_tiling(): |
| HSTR = WSTR = 1 |
| in_channel = 128 |
| kernel_height = kernel_width = 3 |
| out_channel = 64 |
| batch_size = 1 |
| in_height = in_width = 64 |
| out_height = out_width = in_height - kernel_height + 1 |
| data = te.placeholder((batch_size, in_channel, in_height, in_width), name="data") |
| kernel = te.placeholder((kernel_height, kernel_width, in_channel, out_channel), name="kernel") |
| ic = te.reduce_axis((0, in_channel), name="ic") |
| kh = te.reduce_axis((0, kernel_height), name="kh") |
| kw = te.reduce_axis((0, kernel_width), name="kw") |
| conv = te.compute( |
| (batch_size, out_channel, out_height, out_width), |
| lambda n, oc, oh, ow: te.sum( |
| data[n, ic, oh * HSTR + kh, ow * WSTR + kw] * kernel[kh, kw, ic, oc], axis=[ic, kh, kw] |
| ), |
| name="conv2d", |
| ) |
| s = te.create_schedule(conv.op) |
| |
| n, oc, oh, ow = conv.op.axis |
| oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt = tvm.te.schedule.ScheduleOps(s, bounds) |
| mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt).with_attr("global_symbol", "main")) |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| stmt = tvm.tir.transform.Simplify()(mod)["main"].body |
| |
| assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| |
def test_multilevel_splitting_with_indivisible_factors():
| from tvm import topi |
| |
| A = te.placeholder((130,), dtype="float32") |
| B = topi.nn.relu(A) |
| s = te.create_schedule(B.op) |
| (y,) = s[B].op.axis |
| (yo, yi) = s[B].split(y, factor=8) |
| (yoo, yoi) = s[B].split(yo, factor=16) |
| s[B].reorder(yoo, yoi, yi) |
| s[B].unroll(yi) |
| |
    # With const-loop partitioning enabled, the indivisible split factors are
    # handled without leaving guards; exactly 10 Max nodes remain (checked below).
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| lowered_body = tvm.lower(s, [A, B], name="x")["x"].body |
| |
| def visit_stmt(op): |
| return isinstance(op, tvm.tir.Max) |
| |
    max_nodes = collect_visit(lowered_body, visit_stmt)
    assert max_nodes.count(True) == 10
| |
| |
| def test_double_splitting_with_indivisible_factors(): |
| m = 48 |
| dtype = "float32" |
| A = te.placeholder((m,), name="A", dtype=dtype) |
| C = te.compute((m,), lambda i: A[i], name="C") |
| D = te.compute((m,), lambda i: C[i], name="D") |
| |
| s = te.create_schedule(D.op) |
| co, ci = s[C].split(C.op.axis[0], factor=10) |
| do, di = s[D].split(D.op.axis[0], 32) |
| s[C].compute_at(s[D], do) |
| |
| target = "llvm" |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False) |
| func = tvm.build(f, target=target) |
| |
| top_produce = f["fadd1"].body |
| assert not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))) |
| |
| # check functional correctness of generated code |
| dev = tvm.device(target, 0) |
| a = tvm.nd.array( |
| numpy.ones( |
| m, |
| ).astype(dtype), |
| dev, |
| ) |
| c = tvm.nd.array( |
| numpy.zeros( |
| m, |
| ).astype(dtype), |
| dev, |
| ) |
| d = tvm.nd.array( |
| numpy.zeros( |
| m, |
| ).astype(dtype), |
| dev, |
| ) |
| func(a, c, d) |
| tvm.testing.assert_allclose(c.numpy(), a.numpy(), rtol=1e-5) |
| tvm.testing.assert_allclose(d.numpy(), a.numpy(), rtol=1e-5) |
| |
| |
| def test_simple_rfactor(): |
| K = 16 * 4 + 4 |
| k = te.reduce_axis((0, K), "k") |
| |
| A = te.placeholder((1, K), name="A") |
| |
| B = te.compute((1,), lambda b: te.sum(A[b, k], axis=k), name="B") |
| |
| s = te.create_schedule(B.op) |
| ko, _ = s[B].split(s[B].op.reduce_axis[0], 16) |
| BF = s.rfactor(B, ko, 0) |
| |
| s.normalize() |
| bounds = tvm.te.schedule.InferBound(s) |
| stmt1 = tvm.te.schedule.ScheduleOps(s, bounds) |
| |
| mod1 = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt1).with_attr("global_symbol", "main")) |
| stmt1 = tvm.tir.transform.Simplify()(mod1)["main"].body |
| |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod2 = tvm.tir.transform.LoopPartition()(mod1) |
| stmt2 = tvm.tir.transform.Simplify()(mod2)["main"].body |
| |
| # make sure loop partition actually did something |
| assert not tvm.ir.structural_equal(stmt1.body, stmt2.body) |
| |
| |
| @T.prim_func |
| def partitioned_concat( |
| A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32"), C: T.Buffer((32,), "float32") |
| ) -> None: |
| T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) |
| for i in T.serial(0, 16): |
| C[i] = A[i] |
| for i in T.serial(0, 16): |
| C[i + 16] = B[i + 16] |
| |
| |
| def test_explicit_partition_hint(): |
| A = te.placeholder((16,), name="A") |
| B = te.placeholder((16,), name="B") |
| C = te.compute((32,), lambda i: te.if_then_else(i < 16, A[i], B[i]), name="C") |
| s = te.create_schedule(C.op) |
| s.normalize() |
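    # Explicitly mark this axis for partitioning via the loop_partition_hint pragma.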
| s[C].pragma(s[C].op.axis[0], "loop_partition_hint", True) |
| mod = tvm.driver.build_module.schedule_to_module(s, [A, B, C], "main", None) |
| with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): |
| mod = tvm.tir.transform.StorageFlatten(64)(mod) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| mod = tvm.tir.transform.Simplify()(mod) |
| tvm.ir.assert_structural_equal(mod["main"], partitioned_concat) |
| |
| |
| def partition_from_scheduled_tir(prim_func, pass_cfg): |
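    """Run LoopPartition (with the surrounding lowering/cleanup passes) on a scheduled PrimFunc.

    Wraps prim_func in an IRModule and applies LowerOpaqueBlock, FlattenBuffer,
    LoopPartition, Simplify and RemoveNoOp under the given pass config.
    """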
| with tvm.transform.PassContext(config=pass_cfg): |
| mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main")) |
| mod = tvm.tir.transform.LowerOpaqueBlock()(mod) |
| mod = tvm.tir.transform.FlattenBuffer()(mod) |
| mod = tvm.tir.transform.LoopPartition()(mod) |
| mod = tvm.tir.transform.Simplify()(mod) |
| mod = tvm.tir.transform.RemoveNoOp()(mod) |
| return mod |
| |
| |
| @T.prim_func |
| def partitioned_concat_3( |
| placeholder: T.Buffer((1, 64, 28, 28), "int8"), |
| placeholder_1: T.Buffer((1, 32, 28, 28), "int8"), |
| placeholder_2: T.Buffer((1, 32, 28, 28), "int8"), |
| T_concat: T.Buffer((1, 128, 28, 28), "int8"), |
| ) -> None: |
| placeholder_flat = T.Buffer([50176], "int8", data=placeholder.data) |
| placeholder_1_flat = T.Buffer([25088], "int8", data=placeholder_1.data) |
| placeholder_2_flat = T.Buffer([25088], "int8", data=placeholder_2.data) |
| T_concat_flat = T.Buffer([100352], "int8", data=T_concat.data) |
| for i1, i2, i3 in T.grid(64, 28, 28): |
| T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3] |
| for i1, i2, i3 in T.grid(32, 28, 28): |
| T_concat_flat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1_flat[i1 * 784 + i2 * 28 + i3] |
| for i1, i2, i3 in T.grid(32, 28, 28): |
| T_concat_flat[i1 * 784 + i2 * 28 + i3 + 75264] = placeholder_2_flat[i1 * 784 + i2 * 28 + i3] |
| |
| |
| @T.prim_func |
| def concat_func_3( |
| placeholder: T.Buffer((1, 64, 28, 28), "int8"), |
| placeholder_1: T.Buffer((1, 32, 28, 28), "int8"), |
| placeholder_2: T.Buffer((1, 32, 28, 28), "int8"), |
| T_concat: T.Buffer((1, 128, 28, 28), "int8"), |
| ) -> None: |
| placeholder_flat = T.Buffer([50176], "int8", data=placeholder.data) |
| placeholder_1_flat = T.Buffer([25088], "int8", data=placeholder_1.data) |
| placeholder_2_flat = T.Buffer([25088], "int8", data=placeholder_2.data) |
| T_concat_flat = T.Buffer([100352], "int8", data=T_concat.data) |
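    # The hint below asks LoopPartition to split this loop at the condition
    # boundaries (i1 = 64 and i1 = 96), yielding the three loops of
    # partitioned_concat_3 above.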
| for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): |
| for i2, i3 in T.grid(28, 28): |
| if 96 <= i1: |
| T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_2_flat[ |
| i1 * 784 + i2 * 28 + i3 - 75264 |
| ] |
| if 64 <= i1 and i1 < 96: |
| T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_1_flat[ |
| i1 * 784 + i2 * 28 + i3 - 50176 |
| ] |
| if i1 < 64: |
| T_concat_flat[i1 * 784 + i2 * 28 + i3] = placeholder_flat[i1 * 784 + i2 * 28 + i3] |
| |
| |
| def test_condition_mutually_exclusive(): |
| mod = partition_from_scheduled_tir( |
| concat_func_3, {"tir.LoopPartition": {"partition_const_loop": True}} |
| ) |
| tvm.ir.assert_structural_equal( |
| mod["main"], partitioned_concat_3.with_attr("global_symbol", "main") |
| ) |
| |
| |
| def test_loop_partition_unroll_hint(): |
| @T.prim_func |
| def main( |
| A_arg: T.Buffer((1, 3, 224, 224), "int8"), B_arg: T.Buffer((1, 224, 7, 16), "int8") |
| ) -> None: |
| A = T.Buffer(150528, "int8", data=A_arg.data) |
| B = T.Buffer(25088, "int8", data=B_arg.data) |
| for ax0 in T.serial( |
| 112, |
| annotations={"pragma_loop_partition_hint": True}, |
| ): |
| for ax1, ax2, ax3 in T.grid(224, 7, 16): |
| if 3 <= ax0 * 2 + ax2 and ax0 * 2 + ax2 < 227 and ax3 < 3: |
| B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 - 3] |
| |
| @T.prim_func |
| def partitioned_main( |
| A_arg: T.Buffer((1, 3, 224, 224), "int8"), B_arg: T.Buffer((1, 224, 7, 16), "int8") |
| ) -> None: |
| A = T.Buffer(150528, dtype="int8", data=A_arg.data) |
| B = T.Buffer(25088, dtype="int8", data=B_arg.data) |
| for ax1, ax2, ax3 in T.grid(224, 7, 16): |
| if 3 <= ax2 and ax3 < 3: |
| B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 - 3] |
| for ax1, ax2, ax3 in T.grid(224, 7, 16): |
| if 1 <= ax2 and ax3 < 3: |
| B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 - 1] |
| for ax0, ax1, ax2, ax3 in T.grid(109, 224, 7, 16): |
| if ax3 < 3: |
| B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax0 * 2 + ax2 + 1] |
| for ax1, ax2, ax3 in T.grid(224, 7, 16): |
| if ax2 < 5 and ax3 < 3: |
| B[ax1 * 112 + ax2 * 16 + ax3] = A[ax3 * 50176 + ax1 * 224 + ax2 + 219] |
| |
| mod = partition_from_scheduled_tir( |
| main, |
| { |
| "tir.LoopPartition": { |
| "partition_const_loop": True, |
| "unroll_loop_with_partition_hint_no_interval": True, |
| } |
| }, |
| ) |
| mod = tvm.tir.transform.UnrollLoop()(mod) |
| mod = tvm.tir.transform.RemoveNoOp()(mod) |
| mod = tvm.tir.transform.Simplify()(mod) |
| tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) |
| |
| |
| def test_loop_partition_recursive_unroll_hint(): |
| @T.prim_func |
| def main(): |
| placeholder_0_dm = T.decl_buffer([1, 32, 32, 16], dtype="int8") |
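        # Both loop levels carry partition hints; with
        # unroll_loop_with_partition_hint_no_interval enabled, the fully
        # partitioned iterations are emitted as T.unroll loops (see
        # partitioned_main below).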
| for i3_0 in T.serial(5, annotations={"pragma_loop_partition_hint": 1}): |
| for i2_0 in T.serial(2, annotations={"pragma_loop_partition_hint": 1}): |
| pad_temp = T.decl_buffer([1, 16, 16, 16], dtype="int8") |
| for ax0, ax1, ax2 in T.grid(16, 16, 16): |
| if ( |
| 6 <= i2_0 * 4 + ax0 |
| and i2_0 * 4 + ax0 < 26 |
| and 6 <= i3_0 * 4 + ax1 |
| and i3_0 * 4 + ax1 < 26 |
| ): |
| pad_temp[ |
| 0, |
| i2_0 * 4 + ax0 - 6 + 6 - i2_0 * 4, |
| i3_0 * 4 + ax1 - 6 + 6 - i3_0 * 4, |
| ax2, |
| ] = placeholder_0_dm[ |
| 0, |
| i2_0 * 4 + ax0 - 6 - -6, |
| i3_0 * 4 + ax1 - 6 - -6, |
| ax2, |
| ] |
| |
| @T.prim_func |
| def partitioned_main(): |
| placeholder_0_dm = T.allocate([16384], "int8", "global") |
| placeholder_0_dm_1 = T.Buffer([16384], dtype="int8", data=placeholder_0_dm) |
| for i3_0 in T.unroll(2): |
| for i2_0 in T.unroll(2): |
| pad_temp = T.allocate([4096], "int8", "global") |
| pad_temp_1 = T.Buffer([4096], dtype="int8", data=pad_temp) |
| for ax0, ax1, ax2 in T.grid(16, 16, 16): |
| if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1: |
| pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ |
| i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 |
| ] |
| for i2_0 in T.unroll(2): |
| pad_temp_2 = T.allocate([4096], "int8", "global") |
| pad_temp_3 = T.Buffer([4096], dtype="int8", data=pad_temp_2) |
| for ax0, ax1, ax2 in T.grid(16, 16, 16): |
| if 6 <= i2_0 * 4 + ax0: |
| pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ |
| i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128 |
| ] |
| for i3_0 in T.unroll(2): |
| for i2_0 in T.unroll(2): |
| pad_temp_4 = T.allocate([4096], "int8", "global") |
| pad_temp_5 = T.Buffer([4096], dtype="int8", data=pad_temp_4) |
| for ax0, ax1, ax2 in T.grid(16, 16, 16): |
| if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14: |
| pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ |
| i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192 |
| ] |
| |
| mod = partition_from_scheduled_tir( |
| main, |
| { |
| "tir.LoopPartition": { |
| "partition_const_loop": True, |
| "unroll_loop_with_partition_hint_no_interval": True, |
| } |
| }, |
| ) |
| tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) |
| |
| |
| def test_loop_partition_keep_loop_annotations(): |
| @T.prim_func |
| def before(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None: |
| for i in T.serial( |
| 160, |
| annotations={"pragma_loop_partition_hint": True, "key": "value"}, |
| ): |
| if i < 10: |
| B[i] = A[i] + 1 |
| elif 10 <= i and i < 150: |
| B[i] = A[i] + 2 |
| else: |
| B[i] = A[i] + 3 |
| |
| @T.prim_func |
| def after(A: T.Buffer(160, "int32"), B: T.Buffer(160, "int32")) -> None: |
| for i in T.serial(10, annotations={"key": "value"}): |
| B[i] = A[i] + 1 |
| for i in T.serial(140, annotations={"key": "value"}): |
| B[i + 10] = A[i + 10] + 2 |
| for i in T.serial(10, annotations={"key": "value"}): |
| B[i + 150] = A[i + 150] + 3 |
| |
| mod = partition_from_scheduled_tir( |
| before, |
| { |
| "tir.LoopPartition": { |
| "partition_const_loop": True, |
| } |
| }, |
| ) |
| tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) |
| |
| |
| def test_loop_partition_with_unit_loop_in_condition(): |
| @T.prim_func |
| def before( |
| placeholder: T.Buffer((50176,), "int8"), |
| placeholder_1: T.Buffer((25088,), "int8"), |
| placeholder_2: T.Buffer((25088,), "int8"), |
| T_concat: T.Buffer((100352,), "int8"), |
| ) -> None: |
| for k in range(1, annotations={"preserve_unit_loop": True}): |
| for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): |
| for i2, i3 in T.grid(28, 28): |
| if 96 <= k * 128 + i1: |
| T_concat[k * i1 * 784 + i2 * 28 + i3] = placeholder_2[ |
| i1 * 784 + i2 * 28 + i3 - 75264 |
| ] |
| if 64 <= k * 128 + i1 and k * 128 + i1 < 96: |
| T_concat[i1 * 784 + i2 * 28 + i3] = placeholder_1[ |
| i1 * 784 + i2 * 28 + i3 - 50176 |
| ] |
| if k * 128 + i1 < 64: |
| T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] |
| |
| @T.prim_func |
| def after( |
| placeholder: T.Buffer(50176, "int8"), |
| placeholder_1: T.Buffer(25088, "int8"), |
| placeholder_2: T.Buffer(25088, "int8"), |
| T_concat: T.Buffer(100352, "int8"), |
| ) -> None: |
| for _ in T.serial(1, annotations={"preserve_unit_loop": True}): |
| for i1, i2, i3 in T.grid(64, 28, 28): |
| T_concat[i1 * 784 + i2 * 28 + i3] = placeholder[i1 * 784 + i2 * 28 + i3] |
| for i1, i2, i3 in T.grid(32, 28, 28): |
| T_concat[i1 * 784 + i2 * 28 + i3 + 50176] = placeholder_1[i1 * 784 + i2 * 28 + i3] |
| for i1, i2, i3 in T.grid(32, 28, 28): |
| T_concat[i2 * 28 + i3] = placeholder_2[i1 * 784 + i2 * 28 + i3] |
| |
| mod = partition_from_scheduled_tir( |
| before, |
| { |
| "tir.LoopPartition": { |
| "partition_const_loop": True, |
| } |
| }, |
| ) |
| tvm.ir.assert_structural_equal(mod["main"], after.with_attr("global_symbol", "main")) |
| |
| |
| @T.prim_func |
| def concat_func_single_point( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 63), "int8"), |
| T_concat: T.Buffer((28, 128), "int8"), |
| ) -> None: |
| for i0 in range(28): |
| for i1 in T.serial(128, annotations={"pragma_loop_partition_hint": 1}): |
| if i1 > 63: |
| T_concat[i0, i1] = placeholder[i0, i1 - 64] |
| elif i1 == 63: |
| T_concat[i0, i1] = placeholder_1[i0, i1 - 63] |
| else: |
| T_concat[i0, i1] = placeholder_2[i0, i1] |
| |
| |
| @T.prim_func |
| def expected_partitioned_concat_single_point( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 63), "int8"), |
| T_concat: T.Buffer((28, 128), "int8"), |
| ): |
| for i0 in range(28): |
| T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data) |
| for i1 in range(63): |
| placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data) |
| T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1] |
| placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) |
| T_concat_1[i0 * 128 + 63] = placeholder_1_1[i0] |
| for i1 in range(64): |
| placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) |
| T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1] |
| |
| |
| @T.prim_func |
| def concat_func_start_point_equality( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 63), "int8"), |
| T_concat: T.Buffer((28, 128), "int8"), |
| ) -> None: |
| for i0 in range(28): |
| for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}): |
| if i1 == 0: |
| # Special case for i1 == 0 |
| T_concat[i0, i1] = placeholder_1[i0, 0] |
| elif i1 < 64: |
| # Normal case for i1 in [1, 63] |
| T_concat[i0, i1] = placeholder_2[i0, i1] |
| else: |
| # Case for i1 in [64, 127] |
| T_concat[i0, i1] = placeholder[i0, i1 - 64] |
| |
| |
| @T.prim_func |
| def concat_func_start_point_equality_expected( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 63), "int8"), |
| T_concat: T.Buffer((28, 128), "int8"), |
| ): |
| for i0 in range(28): |
| T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data) |
| placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) |
| T_concat_1[i0 * 128] = placeholder_1_1[i0] |
| for i1 in range(63): |
| placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data) |
| T_concat_1[i0 * 128 + i1 + 1] = placeholder_2_1[i0 * 63 + i1 + 1] |
| for i1 in range(64): |
| placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) |
| T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1] |
| |
| |
| @T.prim_func |
| def concat_func_end_point_equality( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 63), "int8"), |
| T_concat: T.Buffer((28, 128), "int8"), |
| ) -> None: |
| for i0 in range(28): |
| for i1 in range(128, annotations={"pragma_loop_partition_hint": 1}): |
| if i1 == 127: |
| # Explicit equality check for the end point i1 == 127 |
| T_concat[i0, i1] = placeholder_1[i0, 0] |
| elif i1 >= 64: |
| # Case for i1 in [64, 126] |
| T_concat[i0, i1] = placeholder[i0, i1 - 64] |
| else: |
| # Case for i1 in [0, 63] |
| T_concat[i0, i1] = placeholder_2[i0, i1] |
| |
| |
| @T.prim_func |
| def concat_func_end_point_equality_expected( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 63), "int8"), |
| T_concat: T.Buffer((28, 128), "int8"), |
| ): |
| for i0 in range(28): |
| T_concat_1 = T.Buffer((3584,), "int8", data=T_concat.data) |
| for i1 in range(64): |
| placeholder_2_1 = T.Buffer((1764,), "int8", data=placeholder_2.data) |
| T_concat_1[i0 * 128 + i1] = placeholder_2_1[i0 * 63 + i1] |
| for i1 in range(63): |
| placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) |
| T_concat_1[i0 * 128 + i1 + 64] = placeholder_3[i0 * 64 + i1] |
| placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) |
| T_concat_1[i0 * 128 + 127] = placeholder_1_1[i0] |
| |
| |
| @T.prim_func |
| def concat_func_edge_equalities( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 1), "int8"), |
| T_concat: T.Buffer((28, 66), "int8"), |
| ) -> None: |
| for i0 in range(28): |
| for i1 in range( |
| 66, annotations={"pragma_loop_partition_hint": 1} |
| ): # Loop from 0 to 65 inclusive |
| if i1 == 0: |
| # Handle equality at the start of the range: i1 == 0 |
| T_concat[i0, i1] = placeholder_2[i0, 0] |
| elif i1 == 65: |
| # Handle equality at the end of the range: i1 == 65 |
| T_concat[i0, i1] = placeholder_1[i0, 0] |
| else: |
| # Copying from placeholder (from 0 to 63) |
| T_concat[i0, i1] = placeholder[i0, i1 - 1] |
| |
| |
| @T.prim_func |
| def concat_func_edge_equalities_expected( |
| placeholder: T.Buffer((28, 64), "int8"), |
| placeholder_1: T.Buffer((28, 1), "int8"), |
| placeholder_2: T.Buffer((28, 1), "int8"), |
| T_concat: T.Buffer((28, 66), "int8"), |
| ): |
| for i0 in range(28): |
| T_concat_1 = T.Buffer((1848,), "int8", data=T_concat.data) |
| placeholder_2_1 = T.Buffer((28,), "int8", data=placeholder_2.data) |
| T_concat_1[i0 * 66] = placeholder_2_1[i0] |
| for i1 in range(64): |
| placeholder_3 = T.Buffer((1792,), "int8", data=placeholder.data) |
| T_concat_1[i0 * 66 + i1 + 1] = placeholder_3[i0 * 64 + i1] |
| placeholder_1_1 = T.Buffer((28,), "int8", data=placeholder_1.data) |
| T_concat_1[i0 * 66 + 65] = placeholder_1_1[i0] |
| |
| |
| @T.prim_func |
| def concat_five_buffers_with_equalities( |
| buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0 |
| buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63 |
| buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64 |
| buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128 |
| buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129 |
| T_concat: T.Buffer((28, 129), "int8"), |
| ) -> None: |
| for i0 in range(28): |
| for i1 in range(130, annotations={"pragma_loop_partition_hint": 1}): |
| if i1 == 0: |
| T_concat[i0, i1] = buffer_a[i0, 0] |
| elif i1 == 64: |
| T_concat[i0, i1] = buffer_c[i0, 0] |
| elif i1 == 129: |
| T_concat[i0, i1] = buffer_e[i0, 0] |
| elif i1 < 64: |
| T_concat[i0, i1] = buffer_b[i0, i1 - 1] |
            else:  # remaining values: 65 <= i1 <= 128
| T_concat[i0, i1] = buffer_d[i0, i1 - 65] |
| |
| |
| @T.prim_func |
| def concat_five_buffers_with_equalities_expected( |
| buffer_a: T.Buffer((28, 1), "int8"), # Used for i1 == 0 |
| buffer_b: T.Buffer((28, 63), "int8"), # Fills i1 from 1 to 63 |
| buffer_c: T.Buffer((28, 1), "int8"), # Used for i1 == 64 |
| buffer_d: T.Buffer((28, 63), "int8"), # Fills i1 from 65 to 128 |
| buffer_e: T.Buffer((28, 1), "int8"), # Used for i1 == 129 |
| T_concat: T.Buffer((28, 129), "int8"), |
| ): |
| for i0 in range(28): |
| T_concat_1 = T.Buffer((3612,), "int8", data=T_concat.data) |
| buffer_a_1 = T.Buffer((28,), "int8", data=buffer_a.data) |
| T_concat_1[i0 * 129] = buffer_a_1[i0] |
| for i1 in range(63): |
| buffer_b_1 = T.Buffer((1764,), "int8", data=buffer_b.data) |
| T_concat_1[i0 * 129 + i1 + 1] = buffer_b_1[i0 * 63 + i1] |
| buffer_c_1 = T.Buffer((28,), "int8", data=buffer_c.data) |
| T_concat_1[i0 * 129 + 64] = buffer_c_1[i0] |
| for i1 in range(64): |
| buffer_d_1 = T.Buffer((1764,), "int8", data=buffer_d.data) |
| T_concat_1[i0 * 129 + i1 + 65] = buffer_d_1[i0 * 63 + i1] |
| buffer_e_1 = T.Buffer((28,), "int8", data=buffer_e.data) |
| T_concat_1[i0 * 129 + 129] = buffer_e_1[i0] |
| |
| |
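# Each (origin, expected) pair below checks that equality ("single point")
# conditions are split out as standalone statements rather than
# single-iteration loops.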
| @pytest.mark.parametrize( |
| "origin,expected", |
| [ |
| (concat_func_single_point, expected_partitioned_concat_single_point), |
| (concat_func_start_point_equality, concat_func_start_point_equality_expected), |
| (concat_func_end_point_equality, concat_func_end_point_equality_expected), |
| (concat_func_edge_equalities, concat_func_edge_equalities_expected), |
| (concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected), |
| ], |
| ) |
| def test_single_point_partition(origin, expected): |
| origin = origin.with_attr({"global_symbol": "main"}) |
| expected = expected.with_attr({"global_symbol": "main"}) |
| mod = partition_from_scheduled_tir( |
| origin, |
| { |
| "tir.LoopPartition": { |
| "partition_const_loop": True, |
| "unroll_loop_with_partition_hint_no_interval": True, |
| } |
| }, |
| ) |
| tvm.ir.assert_structural_equal(mod["main"], expected) |
| |
| |
| if __name__ == "__main__": |
| tvm.testing.main() |