/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file op_module.cc
 * \brief Invoke registered TVM operators.
 * \author Yizhi Liu
 */
#if MXNET_USE_TVM_OP
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/c_runtime_api.h>
#include <string>
#include <vector>
#include "op_module.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::tvm::runtime::TVMOpConfig);
}  // namespace dmlc

namespace tvm {
namespace runtime {
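
/*!
 * \brief Load a compiled TVM module (e.g. a shared library built ahead of
 *        time) from the given file path and make it the module backing
 *        this TVMOpModule.
 */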
void TVMOpModule::Load(const std::string& filepath) {
  static const PackedFunc* f_load = Registry::Get("runtime.ModuleLoadFromFile");
  std::lock_guard<std::mutex> lock(mutex_);
  Module module = (*f_load)(filepath, "");
  module_ptr_ = std::make_shared<Module>();
  *module_ptr_ = module;
}
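
/*!
 * \brief Import the functions of another TVMOpModule into this module.
 *        Requires that Load has already been called on this module.
 */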
void TVMOpModule::Import(const TVMOpModule& module) {
  CHECK(module_ptr_ != nullptr) << "module_ptr_ is not initialized.";
  std::lock_guard<std::mutex> lock(mutex_);
  module_ptr_->Import(*(module.module_ptr_));
}
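
/*!
 * \brief Resolve the PackedFunc for an operator call. The lookup name is the
 *        operator name followed, for each argument, by its dtype string and
 *        "_" plus its number of dimensions; e.g. an op named "vadd" with one
 *        2-D float32 argument resolves to "vaddfloat32_2".
 */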
PackedFunc GetFunction(const std::shared_ptr<Module>& module,
                       const std::string& op_name,
                       const std::vector<mxnet::TBlob>& args) {
  std::ostringstream func_name;
  func_name << op_name;
  for (const auto& arg : args) {
    switch (arg.type_flag_) {
      case mshadow::kFloat32:
        func_name << "float32";
        break;
      case mshadow::kFloat64:
        func_name << "float64";
        break;
      case mshadow::kFloat16:
        func_name << "float16";
        break;
      case mshadow::kUint8:
        func_name << "uint8";
        break;
      case mshadow::kUint16:
        func_name << "uint16";
        break;
      case mshadow::kUint32:
        func_name << "uint32";
        break;
      case mshadow::kUint64:
        func_name << "uint64";
        break;
      case mshadow::kInt16:
        func_name << "int16";
        break;
      case mshadow::kInt32:
        func_name << "int32";
        break;
      case mshadow::kInt8:
        func_name << "int8";
        break;
      case mshadow::kInt64:
        func_name << "int64";
        break;
      case mshadow::kBool:
        func_name << "bool";
        break;
      default:
        LOG(FATAL) << "Unknown dtype " << arg.type_flag_;
    }
    func_name << "_" << arg.shape_.ndim();
  }
  return module->GetFunction(func_name.str(), false);
}
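
/*!
 * \brief Invoke an operator function on the given TBlob arguments, packing
 *        each one as a DLTensor handle. On GPU contexts the current CUDA
 *        stream is handed to TVM via TVMSetStream for the duration of the
 *        call and reset afterwards.
 */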
void TVMOpModule::Call(const std::string& func_name,
                       const mxnet::OpContext& ctx,
                       const std::vector<mxnet::TBlob>& args) const {
  std::vector<int> type_codes;
  std::vector<TVMValue> values;

  type_codes.resize(args.size());
  values.resize(args.size());
  for (size_t i = 0; i < args.size(); ++i) {
    type_codes[i] = kTVMDLTensorHandle;
    values[i].v_handle = const_cast<DLTensor*>(&(args[i].dltensor()));
  }

  TVMArgs tvm_args(&values[0], &type_codes[0], args.size());
  TVMRetValue rv;

#if MXNET_USE_CUDA
  int dev_type = (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kGPU) ? kDLGPU : kDLCPU;
  int dev_id = ctx.run_ctx.ctx.dev_id;
  if (dev_type == kDLGPU) {
    void* stream = static_cast<void*>(ctx.run_ctx.get_stream<mxnet::gpu>()->stream_);
    TVMSetStream(dev_type, dev_id, stream);
  }
#endif
  GetFunction(module_ptr_, func_name, args).CallPacked(tvm_args, &rv);
#if MXNET_USE_CUDA
  if (dev_type == kDLGPU) {
    TVMSetStream(dev_type, dev_id, nullptr);
  }
#endif
}
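
/*!
 * \brief Like Call, but uses caller-provided, pre-packed TVMArgs; tblobs are
 *        used only to resolve the function name.
 */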
void TVMOpModule::CallEx(const std::string& func_name,
                         const mxnet::OpContext& ctx,
                         const std::vector<mxnet::TBlob>& tblobs,
                         TVMArgs tvm_args) const {
  TVMRetValue rv;

#if MXNET_USE_CUDA
  int dev_type = (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kGPU) ? kDLGPU : kDLCPU;
  int dev_id = ctx.run_ctx.ctx.dev_id;
  if (dev_type == kDLGPU) {
    void* stream = static_cast<void*>(ctx.run_ctx.get_stream<mxnet::gpu>()->stream_);
    TVMSetStream(dev_type, dev_id, stream);
  }
#endif
  GetFunction(module_ptr_, func_name, tblobs).CallPacked(tvm_args, &rv);
#if MXNET_USE_CUDA
  if (dev_type == kDLGPU) {
    TVMSetStream(dev_type, dev_id, nullptr);
  }
#endif
}
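
/*!
 * \brief Look up the registered TVMOpConfig for an operator by name.
 *        Fails with a fatal check if no such config has been registered.
 */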
const TVMOpConfig& GetOpConfig(const std::string& name) {
  const TVMOpConfig* ret = ::dmlc::Registry<TVMOpConfig>::Get()->Find(name);
  CHECK(ret != nullptr) << "op " << name << " does not exist.";
  return *ret;
}
} // namespace runtime
} // namespace tvm
#endif // MXNET_USE_TVM_OP