/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_ptx_async_copy.cc
 * \brief Replace copies from global to shared memory with asynchronous (cp.async) copies.
 */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../ir/buffer_common.h"
#include "storage_access.h"

namespace tvm {
namespace tir {

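/*!
 * \brief Rewrite global-to-shared memory copies inside an "async_scope" attribute into calls
 *        to the ptx_cp_async builtin, which codegen lowers to the PTX cp.async instruction.
 *
 * For example (a sketch, assuming an 8-lane float16 copy, i.e. 16 bytes per request):
 *
 *   A_shared[ramp(dst_base, 1, 8)] = A_global[ramp(src_base, 1, 8)]
 *
 * roughly becomes
 *
 *   ptx_cp_async(A_shared_data, dst_base, A_global_data, src_base, 16)
 *
 * Stores of the form shared[...] = if_then_else(pred, global[...], 0) are rewritten into a
 * predicated copy, with the predicate appended as a sixth argument.
 */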
class PTXAsyncCopyInjector : public StmtMutator {
 public:
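  // An "async_scope" attribute marks its body as eligible for rewriting; the attribute node
  // itself is dropped and only the (possibly rewritten) body is kept.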
  Stmt VisitStmt_(const AttrStmtNode* attr) final {
    if (attr->attr_key == tir::attr::async_scope) {
      ICHECK(!in_async) << "Nested async scopes not supported";
      in_async = true;
      auto body = this->VisitStmt(attr->body);
      in_async = false;
      return body;
    }
    return StmtMutator::VisitStmt_(attr);
  }

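  /*!
   * \brief Rewrite a single global -> shared copy into a call to the ptx_cp_async builtin.
   * \param load The load from the global buffer.
   * \param store The store into the shared (or shared.dyn) buffer.
   * \param predicated Whether the copy is guarded by a predicate.
   * \param predicate_value The predicate expression; appended as a sixth call argument when
   *        predicated is true.
   * \return The rewritten statement, or the result of the default visitor when the access
   *         pattern is not supported.
   */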
  Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false,
                 PrimExpr predicate_value = PrimExpr()) {
    if (load->buffer.scope() == "global") {
      ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
      ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());

      const int indices_lanes = load->indices[0]->dtype.lanes();
      const int bytes = indices_lanes * load->buffer->dtype.bytes();

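      // cp.async only supports copies of 4, 8, or 16 bytes per request.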
      if (bytes == 4 || bytes == 8 || bytes == 16) {
        auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
        auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
        ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
            << "Both the store and the load buffer should have a pointer type annotation.";

        int index_factor = 1;
        if (dst_elem_type.value() != src_elem_type.value()) {
          // The only case where src and dst have different dtypes is when the dst shared memory
          // is a byte buffer generated by merging dynamic shared memory.
          ICHECK(store->buffer.scope() == "shared.dyn");
          ICHECK(dst_elem_type.value() == DataType::UInt(8));
          // BufferStore/Load have "pointer reinterpret" semantics according to their "value"
          // dtype: their indices are applied after such a pointer cast, for example
          //   ((float16*)byte_buffer)[buffer->indices] = fp16_value;
          // To replace BufferStore/Load with cp.async, we therefore multiply the store index by
          // the byte size of the "value" dtype to get the correct offset into the byte buffer.
          index_factor = src_elem_type->bytes();
        }

        if (indices_lanes == 1) {
          auto src_offset = load->indices[0];
          auto dst_offset = store->indices[0];
          ffi::Array<PrimExpr> args = {store->buffer->data,
                                       mul(dst_offset, PrimExpr(index_factor)),
                                       load->buffer->data, src_offset, PrimExpr(bytes)};
          // The argument count tells codegen whether to emit a predicated cp.async.
          if (predicated) {
            args.push_back(predicate_value);
          }
          return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args));
        }

        // Vectorized copy. Only some vectorized indexing patterns are supported for now; the
        // same patterns are handled whether or not the copy is predicated.
        auto src_offset = [=]() -> PrimExpr {
          if (load->indices[0]->IsInstance<RampNode>()) {
            return load->indices[0].as<RampNode>()->base;
          }
          return PrimExpr();
        }();

        auto dst_offset = [=]() -> PrimExpr {
          if (store->indices[0].as<RampNode>()) {
            return store->indices[0].as<RampNode>()->base;
          } else if (store->indices[0].as<AddNode>()) {
            // The case where the dst buffer is a byte buffer generated by merging dynamic
            // shared memory, e.g.
            //   A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
            auto* add = store->indices[0].as<AddNode>();
            if (!add->a->IsInstance<RampNode>()) return PrimExpr();
            if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
            return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
          }
          return PrimExpr();
        }();

        if (src_offset.defined() && dst_offset.defined()) {
          ffi::Array<PrimExpr> args = {store->buffer->data,
                                       mul(dst_offset, PrimExpr(index_factor)),
                                       load->buffer->data, src_offset, PrimExpr(bytes)};
          // A sixth argument marks the copy as predicated.
          if (predicated) {
            args.push_back(predicate_value);
          }
          return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args));
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }

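  /*!
   * \brief Rewrite stores into shared memory that appear inside an async scope.
   *
   * Two value patterns are recognized: a direct load from a global buffer, and a call to
   * tir::builtin::if_then_else whose then-branch is such a load and whose else-value is zero.
   */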
  Stmt VisitStmt_(const BufferStoreNode* store) final {
    if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
      if (auto* load = store->value.as<BufferLoadNode>()) {
        return InjectPTX(load, store);
      } else if (auto* call = store->value.as<CallNode>()) {
        // tir.if_then_else is a call to tir::builtin::if_then_else()
        if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) {
          if (auto* load = call->args[1].as<BufferLoadNode>()) {
            // Only an else value of 0 is supported, since 0 is the fill value used by cp.async
            // (see section 9.7.8.22.3 of
            // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations).
            bool else_value_is_zero = false;
            if (auto* b = call->args[2].as<BroadcastNode>()) {
              if (auto* f = b->value.as<FloatImmNode>()) {
                else_value_is_zero = f->value == 0.0f;
              }
            }
            if (auto* f = call->args[2].as<FloatImmNode>()) {
              else_value_is_zero = f->value == 0.0f;
            }
            if (else_value_is_zero) {
              return InjectPTX(load, store, true, call->args[0]);
            }
          }
        }
      }
    }
    return StmtMutator::VisitStmt_(store);
  }

 private:
  bool in_async{false};
};

namespace transform {

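/*!
 * \brief Create the pass that replaces eligible global-to-shared copies with ptx_cp_async calls.
 */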
Pass InjectPTXAsyncCopy() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    n->body = PTXAsyncCopyInjector()(n->body);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tir.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
}

}  // namespace transform

}  // namespace tir
}  // namespace tvm