/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
 * \file inject_ptx_async_copy.cc
 * \brief Replace copies from global to shared memory with async copies.
 */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../ir/buffer_common.h"
#include "storage_access.h"
namespace tvm {
namespace tir {
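/*!
 * \brief Rewrite global-to-shared copies inside an "async_scope" attribute into calls to the
 * ptx_cp_async builtin.
 *
 * A rough sketch of the rewrite (TIR shown informally): inside an async scope,
 *   A_shared[dst] = A_global[src]
 * becomes
 *   ptx_cp_async(A_shared_data, dst, A_global_data, src, bytes)
 * and the predicated form
 *   A_shared[dst] = if_then_else(cond, A_global[src], 0)
 * additionally passes `cond` as a trailing argument.
 */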
class PTXAsyncCopyInjector : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* attr) {
if (attr->attr_key == tir::attr::async_scope) {
ICHECK(in_async == false) << "Nested async scopes not supported";
in_async = true;
auto body = this->VisitStmt(attr->body);
in_async = false;
return body;
}
return StmtMutator::VisitStmt_(attr);
}
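// Try to rewrite a BufferStore whose value is `load` (from a global buffer) into a ptx_cp_async
// call that copies into `store`'s shared-memory destination. When `predicated` is true,
// `predicate_value` is appended as an extra argument so that codegen emits a predicated copy.
// Falls back to the default mutation when the pattern is not supported.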
Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false,
PrimExpr predicate_value = PrimExpr()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());
const int indices_lanes = load->indices[0]->dtype.lanes();
const int bytes = indices_lanes * load->buffer->dtype.bytes();
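// cp.async only supports copies of 4, 8, or 16 bytes per instruction.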
if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
<< "Both store and load buffer should have a pointer type annotation.";
int index_factor = 1;
if (dst_elem_type.value() != src_elem_type.value()) {
// The only case where src and dst have different dtypes is when the dst shared memory
// is a byte buffer generated by merging dynamic shared memory.
ICHECK(store->buffer.scope() == "shared.dyn");
ICHECK(dst_elem_type.value() == DataType::UInt(8));
// BufferStore/Load have "pointer reinterpret" semantics according to their "value" dtype.
// Their "indices" are applied after such a pointer cast, for example:
// ((float16*)byte_buffer)[buffer->indices] = fp16_value;
// To replace BufferStore/Load with cp.async, we need to multiply the store index by the
// byte size of the "value" dtype to get the correct offset into the byte buffer.
index_factor = src_elem_type->bytes();
}
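// Scalar (single-lane) indices can be used as cp.async offsets directly.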
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
ffi::Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
// Use the number of arguments to signal whether codegen should emit a predicated cp.async.
if (predicated) {
args.push_back(predicate_value);
}
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args));
}
// Only some vectorized indexing patterns are supported for now. The predicated and
// non-predicated cases share the same pattern matching; the predicated case appends the
// predicate value to the cp.async arguments, as in the scalar path above.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();
auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory, for example:
// A_shared.dyn[ramp(..., 1, 8) + x8(17408)] = A_global[ramp(..., 1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();
if (src_offset.defined() && dst_offset.defined()) {
ffi::Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
if (predicated) {
args.push_back(predicate_value);
}
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args));
}
}
}
return StmtMutator::VisitStmt_(store);
}
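// Inside an async scope, rewrite stores to "shared" or "shared.dyn" buffers whose value is
// either a direct load from global memory or an if_then_else(cond, global_load, 0) into
// cp.async copies; everything else is left to the default mutator.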
Stmt VisitStmt_(const BufferStoreNode* store) {
if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
if (auto* load = store->value.as<BufferLoadNode>()) {
return InjectPTX(load, store);
} else if (auto* call = store->value.as<CallNode>()) {
// tir.if_then_else is a call to tir::builtin::if_then_else()
if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) {
if (auto* load = call->args[1].as<BufferLoadNode>()) {
// Only an else value of 0 is supported, since 0 is the default value used by PTX cp.async;
// see section 9.7.8.22.3 of
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
bool else_value_is_zero = false;
if (auto* b = call->args[2].as<BroadcastNode>()) {
if (auto* f = b->value.as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
}
}
if (auto* f = call->args[2].as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
}
if (else_value_is_zero) {
return InjectPTX(load, store, true, call->args[0]);
}
}
}
}
}
return StmtMutator::VisitStmt_(store);
}
private:
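// True while visiting the body of an "async_scope" attribute.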
bool in_async{false};
};
namespace transform {
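/*!
 * \brief Create a pass that rewrites eligible global-to-shared copies inside async scopes into
 * ptx_cp_async calls.
 */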
Pass InjectPTXAsyncCopy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = PTXAsyncCopyInjector()(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
}
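// Register the pass with the FFI registry; assuming TVM's standard binding path, this makes it
// available from Python as tvm.tir.transform.InjectPTXAsyncCopy().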
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tir.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
}
} // namespace transform
} // namespace tir
} // namespace tvm