blob: d3a2001c6ec58fe9dc8451761e62260fe55615f9 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_device_storage_access.cc
* \brief Lower the special device storage access.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target_info.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
namespace tvm {
namespace tir {
using runtime::StorageRank;
using runtime::StorageScope;
/*!
 * \brief Mutator that lowers allocations and `tvm_access_ptr` calls on buffers that live in a
 *        tagged storage scope.
 *
 * An Allocate whose storage scope has a non-empty tag other than ".dyn" is looked up through
 * GetMemoryInfo and the result is remembered for the buffer variable. The Allocate itself is
 * removed: it becomes a LetStmt binding the buffer variable to `info->head_address` when that
 * address is defined, and just the allocation body otherwise. DeclBuffer nodes of recorded
 * buffers without a head address are dropped too, and `tvm_access_ptr` calls are rewritten by
 * MakeAccessPtr.
 */
class StorageAccessInfoLower : public StmtExprMutator {
 public:
  /*!
   * \brief Lower an allocation placed in a tagged storage scope.
   * \param op The allocation being visited.
   * \return The default-mutated Allocate for untagged / ".dyn" scopes; otherwise a LetStmt
   *         binding the buffer variable to the scope's head address, or the bare body when the
   *         scope has no head address.
   */
  Stmt VisitStmt_(const AllocateNode* op) final {
    // Resolve the scope string once; it is needed for both the parsed scope and the info lookup.
    const auto scope_str = GetPtrStorageScope(op->buffer_var);
    auto scope = StorageScope::Create(scope_str);
    if (scope.tag.empty() || scope.tag == ".dyn") {
      return StmtExprMutator::VisitStmt_(op);
    }

    auto info = GetMemoryInfo(scope_str);
    ICHECK(info.defined()) << "Cannot find memory info of " << scope.to_string();
    // Record the info before visiting the body so nested tvm_access_ptr / DeclBuffer nodes see
    // it. try_emplace leaves an existing entry untouched, so a duplicate is reported, not
    // overwritten.
    const bool newly_recorded = storage_info_.try_emplace(op->buffer_var.get(), info).second;
    ICHECK(newly_recorded) << "Double allocation of " << scope.to_string();

    // Mutate the children, then replace the Allocate node itself.
    const Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
    if (info->head_address.defined()) {
      return LetStmt(alloc->buffer_var, info->head_address, alloc->body);
    }
    return alloc->body;
  }

  /*!
   * \brief Drop buffer declarations whose backing variable was lowered away.
   * \param op The buffer declaration being visited.
   * \return The body alone when the buffer's data variable is recorded with no head address
   *         (its Allocate was removed without a replacing LetStmt); otherwise the mutated
   *         DeclBuffer.
   */
  Stmt VisitStmt_(const DeclBufferNode* op) final {
    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
    if (auto it = storage_info_.find(node->buffer->data.get());
        it != storage_info_.end() && !it->second->head_address.defined()) {
      return node->body;
    }
    return node;
  }

  /*!
   * \brief Intercept `tvm_access_ptr` calls; every other call takes the default path.
   * \param op The call being visited.
   * \return The lowered pointer expression, or the default-mutated call.
   */
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      return MakeAccessPtr(op);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

 private:
  /*!
   * \brief Lower a `tvm_access_ptr` call.
   *
   * The call must have exactly five arguments; only the first three are read here: the dtype
   * of args[0], the buffer variable in args[1] and the offset in args[2].
   *
   * \param op The `tvm_access_ptr` call.
   * \return MakeTaggedAccessPtr's result when memory info is recorded for the buffer variable;
   *         otherwise an AddressOffset expression (the call must then have a handle dtype).
   */
  PrimExpr MakeAccessPtr(const CallNode* op) {
    // Mutate the arguments first, then inspect the resulting call.
    const Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
    ICHECK_EQ(call->args.size(), 5U);
    DataType dtype = call->args[0].dtype();
    Var buffer_var = Downcast<Var>(call->args[1]);
    PrimExpr offset = call->args[2];

    auto it = storage_info_.find(buffer_var.get());
    if (it != storage_info_.end() && it->second.defined()) {
      return MakeTaggedAccessPtr(call->dtype, buffer_var, dtype, offset, it->second);
    }
    ICHECK(call->dtype.is_handle());
    // No memory info recorded: change to address_of.
    return AddressOffset(buffer_var, dtype, offset);
  }

  /*!
   * \brief Build the access pointer for a buffer that has recorded memory info.
   * \param ptr_type Result dtype of the original `tvm_access_ptr` call.
   * \param buffer_var The buffer variable being accessed.
   * \param dtype Element dtype taken from the call's first argument.
   * \param offset Offset taken from the call's third argument.
   * \param info Memory info recorded for \p buffer_var.
   * \return For a handle \p ptr_type, an AddressOffset expression (the scope must have a head
   *         address). Otherwise \p offset divided by `unit_bits / (bits * lanes)`, simplified
   *         and cast to \p ptr_type.
   */
  PrimExpr MakeTaggedAccessPtr(DataType ptr_type, Var buffer_var, DataType dtype, PrimExpr offset,
                               const MemoryInfo& info) {
    if (ptr_type.is_handle()) {
      ICHECK(info->head_address.defined()) << buffer_var << " is not addressable.";
      return AddressOffset(buffer_var, dtype, offset);
    }
    // NOTE(review): presumably `offset` counts `dtype` elements and the result counts units of
    // `info->unit_bits` bits; the check requires a unit to hold a whole number of elements.
    int dtype_bits = dtype.bits() * dtype.lanes();
    ICHECK_EQ(info->unit_bits % dtype_bits, 0);
    return cast(ptr_type, analyzer_.Simplify(
                              offset / make_const(offset.dtype(), info->unit_bits / dtype_bits)));
  }

  /*! \brief Memory info recorded for each lowered allocation, keyed by its buffer variable. */
  std::unordered_map<const VarNode*, MemoryInfo> storage_info_;
  /*! \brief Analyzer used to simplify the computed access offsets. */
  arith::Analyzer analyzer_;
};
/*!
 * \brief Run StorageAccessInfoLower over a statement.
 * \param stmt The statement to rewrite.
 * \return The statement produced by a fresh StorageAccessInfoLower instance.
 */
Stmt LowerStorageAccessInfo(Stmt stmt) {
  StorageAccessInfoLower lowerer;
  return lowerer(std::move(stmt));
}
namespace transform {
/*!
 * \brief Create the PrimFunc pass named "tir.LowerDeviceStorageAccessInfo".
 *
 * The pass replaces each function body with the result of running a fresh
 * StorageAccessInfoLower over it.
 *
 * \return The pass object built by CreatePrimFuncPass.
 */
Pass LowerDeviceStorageAccessInfo() {
  auto lower_func = [](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    // Obtain a writable node before replacing the body.
    auto* func_node = func.CopyOnWrite();
    StorageAccessInfoLower lowerer;
    func_node->body = lowerer(std::move(func_node->body));
    return func;
  };
  return CreatePrimFuncPass(lower_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
}
// Register the pass constructor under the global name
// "tir.transform.LowerDeviceStorageAccessInfo".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.LowerDeviceStorageAccessInfo",
                                        LowerDeviceStorageAccessInfo);
}
} // namespace transform
} // namespace tir
} // namespace tvm