| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * Copyright (c) 2019 by Contributors |
| * \file op_module.cc |
| * \brief Invoke registered TVM operators. |
| * \author Yizhi Liu |
| */ |
| #if MXNET_USE_TVM_OP |
| #include <tvm/runtime/packed_func.h> |
| #include <tvm/runtime/registry.h> |
| #include <tvm/runtime/c_runtime_api.h> |
| #include <string> |
| #include <vector> |
| #include "op_module.h" |
| |
| using namespace tvm::runtime; |
| |
| namespace tvm { |
| namespace runtime { |
| |
| void TVMOpModule::Load(const std::string &filepath) { |
| static const PackedFunc *f_load = Registry::Get("module._LoadFromFile"); |
| std::lock_guard<std::mutex> lock(mutex_); |
| Module module = (*f_load)(filepath, ""); |
| module_ptr_ = std::make_shared<Module>(); |
| *module_ptr_ = module; |
| } |
| |
| PackedFunc GetFunction(const std::shared_ptr<Module> &module, |
| const std::string &op_name, |
| const std::vector<mxnet::TBlob> &args) { |
| std::ostringstream func_name; |
| func_name << op_name; |
| for (const auto &arg : args) { |
| switch (arg.type_flag_) { |
| case mshadow::kFloat32: |
| func_name << "float32"; |
| break; |
| case mshadow::kFloat64: |
| func_name << "float64"; |
| break; |
| case mshadow::kFloat16: |
| func_name << "float16"; |
| break; |
| case mshadow::kUint8: |
| func_name << "uint8"; |
| break; |
| case mshadow::kInt32: |
| func_name << "int32"; |
| break; |
| case mshadow::kInt8: |
| func_name << "int8"; |
| break; |
| case mshadow::kInt64: |
| func_name << "int64"; |
| break; |
| default: |
| LOG(FATAL) << "Unknown dtype " << arg.type_flag_; |
| } |
| func_name << "_" << arg.shape_.ndim(); |
| } |
| return module->GetFunction(func_name.str(), false); |
| } |
| |
| void TVMOpModule::Call(const std::string &func_name, |
| const mxnet::OpContext& ctx, |
| const std::vector<mxnet::TBlob> &args) { |
| std::vector<int> type_codes; |
| std::vector<TVMValue> values; |
| |
| type_codes.resize(args.size()); |
| values.resize(args.size()); |
| for (size_t i = 0; i < args.size(); ++i) { |
| type_codes[i] = kArrayHandle; |
| values[i].v_handle = const_cast<DLTensor *>(&(args[i].dltensor())); |
| } |
| |
| TVMArgs tvm_args(&values[0], &type_codes[0], args.size()); |
| TVMRetValue rv; |
| |
| #if MXNET_USE_CUDA |
| int dev_type = (ctx.run_ctx.ctx.dev_type == mxnet::Context::DeviceType::kGPU) ? kDLGPU : kDLCPU; |
| int dev_id = ctx.run_ctx.ctx.dev_id; |
| if (dev_type == kDLGPU) { |
| void *stream = static_cast<void *>(ctx.run_ctx.get_stream<mxnet::gpu>()->stream_); |
| TVMSetStream(dev_type, dev_id, stream); |
| } |
| #endif |
| GetFunction(module_ptr_, func_name, args).CallPacked(tvm_args, &rv); |
| #if MXNET_USE_CUDA |
| if (dev_type == kDLGPU) { |
| TVMSetStream(dev_type, dev_id, nullptr); |
| } |
| #endif |
| } |
| |
| } // namespace runtime |
| } // namespace tvm |
| #endif // MXNET_USE_TVM_OP |