blob: 1b7bf14c38ae9e7586311d419d9d9577d405e584 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_async_dma.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/arith/bound.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <optional>
#include <set>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
/*!
 * \brief Rewrites async-queue annotated regions into DMA builtin calls.
 *
 * - Inside an `async_commit_queue_scope` attribute, every `For` loop that
 *   `IdentifyMemCpy` recognises as a copy with exactly one source region and
 *   one destination region is replaced by a `builtin::dma_copy` call.
 * - An `async_wait_queue_scope` attribute whose queue id was used by a copy
 *   lowered earlier in the traversal, and whose body is an
 *   `async_wait_inflight_count` attribute, is replaced by a
 *   `builtin::dma_wait` call followed by the visited inner body.
 * - Everything else is handled by `arith::IRMutatorWithAnalyzer`.
 */
class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer {
 public:
  /*!
   * \param dma_bypass_cache Flag forwarded as the last argument of every emitted `dma_copy`.
   * \param analyzer Analyzer handed to the base mutator and to `IdentifyMemCpy`.
   */
  explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer)
      : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {}

  // TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend
  /*!
   * \brief Replace a copy loop inside a commit scope with a `dma_copy` call.
   *
   * Convert this, for example:
   *   attr [0] "async_commit_queue_scope" = 0;
   *   attr [0] "async_scope" = 1;
   *   for (ax0: int32, 0, 128) {
   *     A_global[ax0] = A[ax0]
   *   }
   *
   * To this:
   *   @tir.dma_copy(
   *     0, // queue id
   *     @tir.address_of(A_global[0], dtype=handle),
   *     @tir.address_of(A[0], dtype=handle),
   *     128 * bytes-per-element, // size
   *     dma_bypass_cache,
   *     dtype=int32
   *   )
   *
   * \param loop The loop being visited.
   * \return The `dma_copy` evaluation, or the default mutation of the loop when it is
   *         outside a commit scope or is not a single-region copy.
   */
  Stmt VisitStmt_(const ForNode* loop) final {
    // if for loop is not within async_commit_queue_scope
    if (!async_queue_id_.has_value()) {
      return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
    }

    // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior
    std::optional<tvm::tir::MemCpyDetails> mem_copy =
        IdentifyMemCpy(ffi::GetRef<For>(loop), analyzer_);
    if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 ||
        mem_copy->source->region.size() != 1) {
      return arith::IRMutatorWithAnalyzer::VisitStmt_(loop);
    }

    // now that we are about to perform the `copy` transform
    // save queue ID for inspection in `wait` transform
    // and, increment the number of DMA copies in the group
    queue_ids_.insert(async_queue_id_.value());
    dmas_in_group_++;

    tvm::PrimExpr src_min = mem_copy->source->region[0]->min;
    tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min;
    tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent;

    // Loads of the first element of each region; only their addresses are used.
    auto src = BufferLoad(mem_copy->source->buffer, {src_min});
    auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min});
    return Evaluate(
        Call(DataType::Int(32), builtin::dma_copy(),
             {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}),
              Call(DataType::Handle(), builtin::address_of(), {src}),
              // transfer size in bytes: element count times element width
              dst_extent * src->dtype.bytes(), dma_bypass_cache_}));
  }

  /*!
   * \brief Lower queue wait/commit attributes; pass every other attribute through.
   * \param op The attribute statement being visited.
   * \return The lowered statement, or the default mutation of `op`.
   */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    // populate analyzer knowledge of loop iterators
    Stmt previsit = arith::IRMutatorWithAnalyzer::VisitStmt_(op);

    // Convert this, for example:
    //   attr [0] "async_wait_queue_scope" = 0;
    //   attr [0] "async_wait_inflight_count" = 0;
    //
    // To this:
    //   @tir.dma_wait(
    //     0, /* queue id */
    //     0, /* in flight count */
    //     dtype=int32
    //   )
    if (op->attr_key == tir::attr::async_wait_queue_scope) {
      // get queue ID
      auto queue_id_node = op->value.as<IntImmNode>();
      ICHECK(queue_id_node) << "`async_wait_queue_scope` expects a constant integer queue ID";
      const int queue_id = static_cast<int>(queue_id_node->value);

      // abort if we have not seen this queue ID in `copy` transform
      if (queue_ids_.find(queue_id) == queue_ids_.end()) {
        DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
                      "`async_wait_queue_scope` transform has not been previously observed in the "
                      "`async_commit_queue_scope` transform";
        return previsit;
      }

      auto async_wait = op->body.as<AttrStmtNode>();
      if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) {
        DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
                      "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
                      "`async_wait_inflight_count`";
        return previsit;
      }

      auto call_dma_wait =
          Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));

      // concatenate the call with the body and return
      return SeqStmt({call_dma_wait, arith::IRMutatorWithAnalyzer::VisitStmt(async_wait->body)});
    }

    // Mark the subtree as belonging to a queue so that the `For` visitor lowers the
    // copies it contains; wrap them in a DMA group when there is more than one.
    if (op->attr_key == tir::attr::async_commit_queue_scope) {
      // get queue ID
      auto queue_id_node = op->value.as<IntImmNode>();
      ICHECK(queue_id_node) << "`async_commit_queue_scope` expects a constant integer queue ID";
      async_queue_id_ = static_cast<int>(queue_id_node->value);

      // Second traversal, this time with the queue ID set so copies get lowered.
      auto result = arith::IRMutatorWithAnalyzer::VisitStmt_(op);
      if (dmas_in_group_ > 1) {
        auto call_dma_start_group = Evaluate(
            Call(DataType::Int(32), builtin::dma_start_group(), {async_queue_id_.value()}));
        auto call_dma_end_group =
            Evaluate(Call(DataType::Int(32), builtin::dma_end_group(), {async_queue_id_.value()}));
        result = SeqStmt({call_dma_start_group, result, call_dma_end_group});
      }

      async_queue_id_ = std::nullopt;
      dmas_in_group_ = 0;
      return result;
    }

    // Any other attribute: the subtree was already rewritten by the visit at the top.
    // Visiting it again would count every copy nested under this attribute twice in
    // `dmas_in_group_` and would double the traversal cost per nesting level.
    return previsit;
  }

 private:
  /*! \brief Number of copies lowered so far inside the commit scope being visited. */
  int dmas_in_group_ = 0;
  /*! \brief Queue IDs for which at least one copy has been lowered. */
  std::set<int> queue_ids_;
  /*! \brief Queue ID of the enclosing commit scope; empty when outside any commit scope. */
  std::optional<int> async_queue_id_ = std::nullopt;
  /*! \brief Value passed as the bypass-cache argument of every emitted `dma_copy`. */
  bool dma_bypass_cache_;
};
namespace transform {
/*!
 * \brief Create the PrimFunc pass that runs AsyncDMALowerer over a function body.
 * \return A pass named "tir.LowerAsyncDMA" created with opt level 0.
 */
Pass LowerAsyncDMA() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    // The bypass flag comes from the pass config and falls back to false when unset.
    const bool bypass_cache =
        ctx->GetConfig<Bool>("tir.experimental_dma_bypass_cache", Bool(false)).value();

    arith::Analyzer analyzer;
    AsyncDMALowerer lowerer(bypass_cache, &analyzer);

    // Obtain a mutable node for `f`, then swap its body for the lowered one.
    auto* func_node = f.CopyOnWrite();
    func_node->body = lowerer(std::move(func_node->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
}
// Registers the pass constructor under the global name "tir.transform.LowerAsyncDMA".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.LowerAsyncDMA", LowerAsyncDMA);
}
} // namespace transform
} // namespace tir
} // namespace tvm