# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te

# Register a "TGlobalSymbol" attribute for each tir.cop.* op that the tests
# below look up through tvm.ir.Op.get.  Pairs are (op name, attribute value).
for _op_name, _global_symbol in (
    ("tir.cop.coproc_sync", "coproc_sync"),
    ("tir.cop.coproc_read_barrier", "coproc_readb"),
    ("tir.cop.coproc_write_barrier", "coproc_writeb"),
    ("tir.cop.coproc_dep_push", "coproc_dep_push"),
    ("tir.cop.coproc_dep_pop", "coproc_dep_pop"),
):
    tvm.ir.register_op_attr(_op_name, "TGlobalSymbol", _global_symbol)
| |
| |
def test_coproc_sync():
    """Run CoProcSync on a loop nest with one coproc scope and check the barriers.

    The test builds a function whose outer loop updates a "global.cache"
    buffer and then runs a two-level inner loop tagged with ``coproc_scope``
    1.  After the pass, the statements of the outer loop body are inspected:
    index 1 must be a read barrier whose fourth argument is 80, and the last
    two must be a coproc sync followed by a write barrier whose fourth
    argument is 10.
    """

    @tvm.register_func("tvm.info.mem.global.cache")
    def meminfo_cache():
        # MemoryInfo node registered under the name matching the
        # "global.cache" storage scope allocated below.
        head = tvm.tir.call_extern("handle", "global_cache")
        return tvm.ir.make_node(
            "MemoryInfo",
            unit_bits=8,
            max_simd_bits=32,
            max_num_bits=128,
            head_address=head,
        )

    def calls_op(evaluate_stmt, op_name):
        # True when the statement's call targets the named registered op.
        return evaluate_stmt.value.op.same_as(tvm.ir.Op.get(op_name))

    builder = tvm.tir.ir_builder.create()
    n = te.size_var("n")
    cop_axis = te.thread_axis((0, 1), "cop")
    cache = builder.allocate("float32", 128, name="A", scope="global.cache")
    with builder.for_range(0, n, name="i") as outer:
        cache[outer] = cache[outer] + 1
        with builder.for_range(0, 8, name="k") as block:
            with builder.for_range(0, 10, name="j") as lane:
                builder.scope_attr(cop_axis, "coproc_scope", 1)
                cache[lane] = cache[lane + block * 10] + 2

    prim_func = tvm.tir.PrimFunc([n], builder.get())
    module = tvm.IRModule.from_expr(prim_func)
    synced = tvm.tir.transform.CoProcSync()(module)["main"].body

    # Flatten the body two levels below the function body into a list.
    stmts = tvm.tir.stmt_list(synced.body.body)
    read_barrier = stmts[1]
    sync = stmts[-2]
    write_barrier = stmts[-1]

    assert calls_op(read_barrier, "tir.cop.coproc_read_barrier")
    assert read_barrier.value.args[3].value == 80
    assert calls_op(sync, "tir.cop.coproc_sync")
    assert calls_op(write_barrier, "tir.cop.coproc_write_barrier")
    assert write_barrier.value.args[3].value == 10
| |
| |
def test_coproc_sync2():
    """Smoke-test CoProcSync on coproc scopes combined with a virtual thread.

    The function stores through a ``cthread`` axis tagged as a
    ``virtual_thread`` of 2, once in coproc scope 2 before the loop and then
    in scopes 1 and 2 inside the loop.  Nothing is asserted; the test only
    requires that the pass runs to completion.
    """
    builder = tvm.tir.ir_builder.create()
    n = te.size_var("n")
    cop_axis = te.thread_axis((0, 1), "cop")
    vthread = te.thread_axis("cthread")
    buf = builder.allocate("float32", 128, name="A")
    builder.scope_attr(vthread, "virtual_thread", 2)

    with builder.new_scope():
        builder.scope_attr(cop_axis, "coproc_scope", 2)
        buf[vthread] = 0.0

    with builder.for_range(0, n, name="i"):
        # Same store issued from scope 1 and then scope 2 on every iteration.
        with builder.new_scope():
            builder.scope_attr(cop_axis, "coproc_scope", 1)
            buf[vthread] = 1.0
        with builder.new_scope():
            builder.scope_attr(cop_axis, "coproc_scope", 2)
            buf[vthread] = 1.0

    prim_func = tvm.tir.PrimFunc([n], builder.get())
    module = tvm.IRModule.from_expr(prim_func)
    # The result is only accessed, never inspected.
    tvm.tir.transform.CoProcSync()(module)["main"].body
| |
| |
def test_coproc_sync3():
    """Check the dep_push / dep_pop calls CoProcSync inserts between scopes.

    A doubly nested loop stores to a "global.cache" buffer from coproc scopes
    1 and 2, followed by a store from scope 3.  After the pass, the test
    locates one push statement and one pop statement in the result and
    requires each to call the expected op with arguments exactly ``[2, 3]``.
    """

    def __check_list(tvm_array, py_list):
        """Return True when ``tvm_array`` holds exactly the values in ``py_list``.

        Each element of ``tvm_array`` is compared through its ``.value``.
        The lengths are compared first because ``zip`` stops at the shorter
        input, which would otherwise let extra or missing arguments pass.
        """
        if len(tvm_array) != len(py_list):
            return False
        return all(ti.value == li for ti, li in zip(tvm_array, py_list))

    ib = tvm.tir.ir_builder.create()
    n = te.size_var("n")
    cp = te.thread_axis((0, 1), "cop")
    A = ib.allocate("float32", 128, name="A", scope="global.cache")
    with ib.for_range(0, n, name="i") as i:
        # NOTE(review): the inner loop reuses the name hint "i"; it is kept
        # unchanged so the generated IR is the same.  Its loop variable is
        # not referenced, so it is not bound.
        with ib.for_range(0, n, name="i"):
            with ib.new_scope():
                ib.scope_attr(cp, "coproc_scope", 1)
                A[i] = 1.0
            with ib.new_scope():
                ib.scope_attr(cp, "coproc_scope", 2)
                A[i] = 1.0
    with ib.new_scope():
        ib.scope_attr(cp, "coproc_scope", 3)
        A[0] = 0.0

    stmt = ib.get()

    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt))
    stmt = tvm.tir.transform.CoProcSync()(mod)["main"].body

    # Third statement of the first element's body is taken as the push; the
    # pop is taken from inside the first statement of that body's last entry.
    slist = tvm.tir.stmt_list(stmt[0].body)
    push_st = slist[2]
    slist = tvm.tir.stmt_list(slist[-1])
    pop_st = slist[0].body[0]

    assert push_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_push"))
    assert __check_list(push_st.value.args, [2, 3])
    assert pop_st.value.op.same_as(tvm.ir.Op.get("tir.cop.coproc_dep_pop"))
    assert __check_list(pop_st.value.args, [2, 3])
| |
| |
if __name__ == "__main__":
    # Run every test in definition order when executed as a script.
    for run_test in (test_coproc_sync, test_coproc_sync2, test_coproc_sync3):
        run_test()