blob: a82de34716c88873ee1971eb725fc4f541bef5f3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file verify_memory.cc
* \brief Pass to check if memory accesses are legal.
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/transform.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace tir {
namespace {
/*!
* \brief Verify if memory accesses are legal.
*
* When the target is a GPU (e.g. CUDA) and the workload is not bound
* to threads, CPU code is generated that tries to access GPU memory,
* which is illegal.
*
* This pass performs such verification by checking if all
* memory accesses are bound with threads when device type is GPU.
*/
class MemoryAccessVerifier final : protected StmtExprVisitor {
public:
/// Special member functions
//@{
explicit MemoryAccessVerifier(PrimFunc f, int device_type) : func_(f), dev_type_(device_type) {}
virtual ~MemoryAccessVerifier() = default;
MemoryAccessVerifier(const MemoryAccessVerifier&) = delete;
MemoryAccessVerifier(MemoryAccessVerifier&&) = delete;
MemoryAccessVerifier& operator=(const MemoryAccessVerifier&) = delete;
MemoryAccessVerifier& operator=(MemoryAccessVerifier&&) = delete;
//@}
/// Interface to perform memory access verification
void Run() {
if (!IsGPUDevice(dev_type_)) return;
StmtExprVisitor::VisitStmt(func_->body);
}
/// Verification result
std::vector<ffi::String> Errors() const { return errs_; }
protected:
/// Visitor implementation
//@{
void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); }
void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); }
void VisitStmt_(const LetStmtNode* op) final {
// Book keep definitions
defs_[op->var.get()] = op->value;
return StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) final {
if (!InThreadEnv() &&
(op->attr_key == attr::thread_extent || op->attr_key == attr::pipeline_exec_scope)) {
EnterThreadEnv();
StmtExprVisitor::VisitStmt_(op);
ExitThreadEnv();
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void VisitExpr_(const BufferLoadNode* op) final {
HandleLoadStoreToVariable(op->buffer->data);
return StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
HandleLoadStoreToVariable(op->buffer->data);
return StmtExprVisitor::VisitStmt_(op);
}
//@}
/// Check if the value of a Variable comes from function argument.
bool IsFromFunctionArgs(const VarNode* var) const {
const VarNode* V = var;
for (auto kv : func_->buffer_map) {
if (V == kv.second->data.get()) return true;
}
while (true) {
// Variable is from function args. Return true.
if (V == func_->params[0].get()) return true;
// The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue.
const auto& iter = defs_.find(V);
if (iter == defs_.end()) return false;
const CallNode* C = iter->second.as<const CallNode>();
if (!C || !C->op.same_as(builtin::tvm_struct_get())) return false;
V = C->args[0].as<VarNode>();
}
return false;
}
/// Handle memory access to a Variable
void HandleLoadStoreToVariable(const Var& var) {
// We skip the access within thread env.
if (InThreadEnv()) return;
// We only handle the variable from function argument.
// If it does not come from args, then it could be allocated internally,
// it may possibly be in host or device address space.
// We do not handle this case, and skip it conservatively.
if (!IsFromFunctionArgs(var.get())) return;
// The verification fails in this case.
std::stringstream s;
s << "Variable `" << var
<< "` is directly accessed by host memory (it is not contained in a thread environment or in "
"the function arguments.";
errs_.push_back(s.str());
}
/// Status getter/setter
//@{
bool InThreadEnv() const { return in_thread_env_; }
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
//@}
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
static bool IsGPUDevice(int dev_type) {
return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
kDLMetal == dev_type || kDLROCM == dev_type;
}
private:
/// Status of visitor
//@{
bool in_thread_env_{false};
std::vector<ffi::String> errs_;
//@}
tir::PrimFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
std::unordered_map<const VarNode*, PrimExpr> defs_; ///< Variable definitions
};
} // namespace
/// Interface of VerifyMemory pass
std::vector<ffi::String> VerifyMemory_(const PrimFunc& func) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "VerifyMemory: Require the target attribute";
VLOG(1) << "verifying memory for target '" << target.value()->str()
<< "' for primitive:" << std::endl
<< func;
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType());
v.Run();
return v.Errors();
} else {
return {};
}
}
bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; }
// Registers the boolean verifier as the global function "tir.analysis.verify_memory".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.analysis.verify_memory", VerifyMemory);
}
namespace transform {
/*!
 * \brief Create the module pass "tir.VerifyMemory".
 *
 * The pass runs VerifyMemory_ on every PrimFunc in the module and raises a
 * fatal error listing all reported problems for the first function that
 * fails. The module itself is returned unchanged.
 *
 * \return The created pass.
 */
Pass VerifyMemory() {
  auto pass_func = [=](IRModule mod, PassContext ctx) {
    for (const auto& entry : mod->functions) {
      auto prim_func = entry.second.as<PrimFunc>();
      // Entries that are not PrimFuncs are not verified.
      if (!prim_func) continue;

      const std::vector<ffi::String> errors = VerifyMemory_(prim_func.value());
      if (errors.empty()) continue;

      // Put one reported error per line before aborting.
      std::stringstream listing;
      for (const auto& message : errors) {
        listing << " " << message << "\n";
      }
      LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
                 << listing.str() << " Did you forget to bind?\n"
                 << prim_func.value();
    }
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {});
}
// Registers the pass factory as the global function "tir.transform.VerifyMemory".
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.VerifyMemory", VerifyMemory);
}
} // namespace transform
} // namespace tir
} // namespace tvm