/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file driver_api.cc
 * \brief Compile executable modules.
 */
#include <dmlc/thread_local.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/codegen.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <mutex>
#include <stack>

namespace tvm {

// Register pass configuration options related to the build pipeline.
TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);

using runtime::PackedFunc;
using runtime::TVMArgs;
using runtime::TVMRetValue;

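/*! \brief Check whether the LLVM codegen backend is available, i.e. whether
 *  this build of TVM registered "target.build.llvm". */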
bool LLVMEnabled() {
  const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
  return pf != nullptr;
}

/*!
 * \brief Get the default host target for a given device target.
 * \param target The device target.
 * \return The target itself if it already runs on the CPU; otherwise "llvm"
 *         when the LLVM backend is available, falling back to "stackvm".
 */
Target DefaultTargetHost(Target target) {
  if (target.defined() && target->kind->device_type == kDLCPU) {
    return target;
  } else {
    if (LLVMEnabled()) {
      return Target("llvm");
    } else {
      return Target("stackvm");
    }
  }
}

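/*!
 * \brief Create the buffer that a tensor argument will be bound to.
 * \param shape The buffer shape.
 * \param dtype The element data type; boolean buffers are stored as Int(8).
 * \param name The buffer name.
 * \param data_alignment The alignment requirement of the data pointer in bytes.
 * \param offset_factor If non-zero, a symbolic elem_offset variable is created
 *        so the buffer can start at an arbitrary (factor-aligned) offset.
 * \param compact Whether the buffer is known to be compact; if not, a shape
 *        containing variables marks the buffer as auto-broadcast.
 * \return The constructed tir::Buffer.
 */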
tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std::string name,
                                      int data_alignment, int offset_factor, bool compact) {
  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
  auto data = tir::Var(name, PointerType(PrimType(storage_dtype)));
  bool has_any = false;
  if (!compact) {
    for (const auto& it : shape) {
      if (it.as<tir::VarNode>()) {
        has_any = true;
        break;
      }
    }
  }
  tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault;

  PrimExpr elem_offset;
  if (offset_factor != 0) {
    elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype());
  } else {
    elem_offset = PrimExpr();
  }

  return tir::Buffer(data, dtype, shape, Array<PrimExpr>(), elem_offset, name, "", data_alignment,
                     offset_factor, buffer_type);
}

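/*!
 * \brief Complete the tensor-to-buffer binding map and collect the flattened
 *        argument list.
 *
 * Tensors already present in `binds` keep their user-provided buffers; every
 * other argument gets a fresh buffer from BufferWithOffsetAlignment.
 * \param args The tensor arguments of the function.
 * \param compact Whether the input buffers are known to be compact.
 * \param binds The user-provided tensor-to-buffer bindings.
 * \param out_binds The resulting complete binding map.
 * \param out_arg_list The resulting argument list, one buffer per entry in args.
 */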
void GetBinds(const Array<te::Tensor>& args, bool compact,
              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
  *out_binds = binds;

  for (const auto& x : args) {
    if (out_binds->find(x) == out_binds->end()) {
      auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, -1, 0, compact);
      out_binds->Set(x, buf);
      out_arg_list->push_back(buf);
    } else {
      out_arg_list->push_back((*out_binds)[x]);
    }
  }
}

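/*! \brief Create a pass that attaches `target` as the tvm::attr::kTarget
 *  attribute of every PrimFunc in the module. */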
transform::Pass BindTarget(Target target) {
  auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
    return WithAttr(std::move(f), tvm::attr::kTarget, target);
  };
  return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
}

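/*!
 * \brief Create a pass that keeps only the PrimFuncs satisfying `fcond`.
 *        Functions that fail the predicate are returned as undefined, which
 *        the pass infrastructure then removes from the module.
 */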
template <typename FCond>
transform::Pass Filter(FCond fcond) {
  auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
    if (fcond(f)) {
      return f;
    } else {
      return tir::PrimFunc(nullptr);
    }
  };
  return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}

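/*!
 * \brief Lower a schedule into an optimized TIR module ready for `build`.
 * \param sch The schedule to lower.
 * \param args The tensor arguments of the resulting function.
 * \param name The name (global symbol) of the resulting function.
 * \param binds User-provided tensor-to-buffer bindings.
 * \return An IRModule containing the lowered PrimFunc.
 *
 * A minimal usage sketch (assumes a one-dimensional elementwise workload):
 * \code
 *   int n = 1024;
 *   te::Tensor A = te::placeholder({n}, DataType::Float(32), "A");
 *   te::Tensor B = te::compute(
 *       {n}, [&](tir::Var i) { return A(i) + 1.0f; }, "B");
 *   te::Schedule s = te::create_schedule({B->op});
 *   IRModule m = lower(s, {A, B}, "add_one", {});
 * \endcode
 */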
IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
  Array<ObjectRef> out_arg_list;
  auto pass_ctx = transform::PassContext::Current();

  sch = sch.normalize();

  // TE-level lowering before any TIR transformation: infer loop bounds, emit
  // the initial statement, and check whether all buffers are compact.
  auto bounds = te::InferBound(sch);
  auto stmt = te::ScheduleOps(sch, bounds, false);
  bool compact = te::VerifyCompactBuffer(stmt);

  Map<te::Tensor, tir::Buffer> out_binds;
  GetBinds(args, compact, binds, &out_binds, &out_arg_list);

  // Build the PrimFunc and attach its global symbol.
  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));

  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
  bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
  bool instrument_bound_checkers =
      pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();

  if (noalias) {
    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
  }

  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
  auto pass_list = Array<tvm::transform::Pass>();

  // Phase 0: inject prefetch hints and flatten multi-dimensional buffer
  // accesses into one-dimensional accesses.
  pass_list.push_back(tir::transform::InjectPrefetch());
  pass_list.push_back(
      tir::transform::StorageFlatten(/*cache_line_size=*/64, instrument_bound_checkers));
  // Phase 1: legalize data types and apply loop-level optimizations.
  pass_list.push_back(tir::transform::BF16Legalize());
  pass_list.push_back(tir::transform::NarrowDataType(/*target_bits=*/32));
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::LoopPartition());
  pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
  pass_list.push_back(tir::transform::InjectVirtualThread());
  pass_list.push_back(tir::transform::InjectDoubleBuffer());
  pass_list.push_back(tir::transform::StorageRewrite());
  pass_list.push_back(tir::transform::UnrollLoop());
  // Phase 2: final simplification and safety rewrites.
  pass_list.push_back(tir::transform::Simplify());
  pass_list.push_back(tir::transform::RemoveNoOp());
  pass_list.push_back(tir::transform::RewriteUnsafeSelect());
  if (instrument_bound_checkers) {
    pass_list.push_back(tir::transform::InstrumentBoundCheckers());
  }
  // Run the optimization pipeline.
  auto optimize = transform::Sequential(pass_list);
  mod = optimize(std::move(mod));
  return mod;
}

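/*!
 * \brief Split a mixed module into a host module and a device module.
 *
 * First runs the mixed pipeline (thread synchronization, packed-API
 * conversion, SplitHostDevice), then filters by calling convention:
 * functions marked kDeviceKernelLaunch go to the device module and are
 * lowered for `target`; everything else goes to the host module and is
 * lowered for `target_host`.
 * \param mod_mixed The mixed module containing both host and device code.
 * \param target The device target.
 * \param target_host The host target.
 * \param pass_ctx The pass context carrying the configuration options.
 * \return A {host module, device module} pair.
 */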
std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target,
                                                const Target& target_host,
                                                const transform::PassContext& pass_ctx) {
  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
                                                 tir::transform::VerifyMemory()};

  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
  }
  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
  mixed_pass_list.push_back(tir::transform::InferFragment());
  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
  mixed_pass_list.push_back(tir::transform::MakePackedAPI(/*num_unpacked_args=*/0));
  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
  auto opt_mixed = transform::Sequential(mixed_pass_list);
  mod_mixed = opt_mixed(std::move(mod_mixed));

  // Host pipeline: keep only non-kernel functions and lower them for the host.
  auto host_pass_list = {
      Filter([](const tir::PrimFunc& f) {
        return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
               CallingConv::kDeviceKernelLaunch;
      }),
      BindTarget(target_host),
      tir::transform::LowerTVMBuiltin(),
      tir::transform::LowerCustomDatatypes(),
      tir::transform::LowerIntrin(),
      tir::transform::LowerDeviceStorageAccessInfo(),
      tir::transform::CombineContextCall(),
  };
  auto opt_host = transform::Sequential(host_pass_list);
  ICHECK(mod_mixed.defined()) << "The mixed module must be defined";
  auto mhost = opt_host(mod_mixed);

  // Device pipeline: keep only device kernels and lower them for the device.
  auto device_pass_list = {
      Filter([](const tir::PrimFunc& f) {
        return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
               CallingConv::kDeviceKernelLaunch;
      }),
      BindTarget(target),
      tir::transform::LowerWarpMemory(),
      tir::transform::Simplify(),
      tir::transform::LowerCustomDatatypes(),
      tir::transform::LowerIntrin(),
      tir::transform::LowerDeviceStorageAccessInfo(),
  };
  auto opt_device = transform::Sequential(device_pass_list);
  auto mdevice = opt_device(mod_mixed);

  // Final sanity checks.
  auto keys = target->GetKeys();
  bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
  if (target_is_gpu && mdevice->functions.empty()) {
    LOG(WARNING) << "Specified target " << target->str()
                 << " but cannot find device code. Did you forget to bind?";
  }

  if (target->kind->device_type == kDLCPU && target_host == target) {
    ICHECK(mdevice->functions.empty()) << "No device code should be generated when target "
                                       << "and target_host are both CPU targets.\n";
  }

  return {mhost, mdevice};
}

// Build for heterogeneous execution: compile one IRModule per target and
// import the resulting device modules into the host module.
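//
// A minimal usage sketch (assumes `lowered` is an IRModule produced by
// `lower` and that the LLVM backend is enabled):
//
//   Map<Target, IRModule> inputs = {{Target("llvm"), lowered}};
//   runtime::Module m = build(inputs, Target("llvm"));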
runtime::Module build(const Map<Target, IRModule>& inputs, const Target& target_host) {
  auto pass_ctx = transform::PassContext::Current();

  std::vector<runtime::Module> device_modules;
  Target target_host_val = target_host;
  if (!target_host.defined()) {
    for (const auto& it : inputs) {
      if (it.first->kind->device_type == kDLCPU || it.first->kind->device_type == kDLMicroDev) {
        target_host_val = it.first;
        break;
      }
    }
  }

  // No CPU-like target among the inputs: fall back to the default host target.
  if (!target_host_val.defined()) {
    target_host_val = DefaultTargetHost(target_host_val);
  }

  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());

  ICHECK(mhost_all.defined()) << "The host module must be defined";

  for (const auto& it : inputs) {
    if (it.second.defined()) {
      auto pair = SplitDevHostFuncs(it.second, it.first, target_host_val, pass_ctx);
      auto& mhost = pair.first;
      auto& mdevice = pair.second;

      ICHECK(mhost.defined()) << "The split host module must be defined";

      ICHECK(mhost_all.defined()) << "The host module must be defined";

      mhost_all->Update(mhost);

      if (!mdevice->functions.empty()) {
        device_modules.push_back(codegen::Build(mdevice, it.first));
      }
    }
  }

  runtime::Module mhost = codegen::Build(mhost_all, target_host_val);
  // Import all device modules into the host module.
  for (const auto& it : device_modules) {
    if (it.operator->()) {
      mhost.Import(it);
    }
  }
  return mhost;
}

// Build for heterogeneous execution when each target is given as a string.
// Targets whose "device" attribute is "vta" are remapped to the external
// device target "ext_dev".
runtime::Module build(const Map<String, IRModule>& inputs, const Target& target_host) {
  Map<Target, IRModule> updated_input;
  for (const auto& it : inputs) {
    auto target = Target(it.first);
    Optional<String> device = target->GetAttr<String>("device");
    if (device.defined() && device.value() == "vta") {
      target = Target("ext_dev");
    }
    updated_input.Set(target, it.second);
  }
  return build(updated_input, target_host);
}

// Build for homogeneous execution: a single IRModule compiled for one device
// target and one host target.
runtime::Module build(const IRModule& funcs, const Target& target, const Target& target_host) {
  Map<Target, IRModule> inputs = {{target, funcs}};
  return build(inputs, target_host);
}

}  // namespace tvm