blob: 926c5162005a7b257655577734987fa5246226b5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/contrib/msc/plugin/tvm_codegen.h
* \brief Codegen for tvm plugin.
*/
#ifndef TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_
#define TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_
#include <set>
#include <string>
#include "base_codegen.h"
#include "codegen_utils.h"
namespace tvm {
namespace contrib {
namespace msc {
/*!
 * \brief Options controlling code generation for the tvm plugin.
 *
 * Declares the tvm-specific fields `as_relay` and `tvm_root`; the remaining
 * fields come from PLUGIN_CODEGEN_CONFIG_MEMBERS (defined outside this file).
 */
struct TVMPluginCodeGenConfig {
  bool as_relay{false};
  std::string tvm_root{"tvm"};
  PLUGIN_CODEGEN_CONFIG_MEMBERS

  /*!
   * \brief Fill the config from a JSON object.
   * \param reader the JSON reader to pull the object items from.
   */
  void Load(dmlc::JSONReader* reader) {
    // NOTE: the variable must stay named `key`; PLUGIN_CODEGEN_CONFIG_PARSE is
    // expanded in this scope and presumably refers to it by name.
    std::string key;
    reader->BeginObject();
    while (reader->NextObjectItem(&key)) {
      // tvm-specific options are consumed here ...
      if (key == "as_relay") {
        reader->Read(&as_relay);
        continue;
      }
      if (key == "tvm_root") {
        reader->Read(&tvm_root);
        continue;
      }
      // ... every other key is handed to the shared parsing macro.
      PLUGIN_CODEGEN_CONFIG_PARSE
    }
  }
};
/*!
 * \brief Plugin codegen for tvm, configured by TVMPluginCodeGenConfig.
 */
class TVMPluginCodeGen : public BasePluginCodeGen<TVMPluginCodeGenConfig> {
 public:
  /*!
   * \brief Build a TVMPluginCodeGen.
   * \param config the codegen options, forwarded unchanged to the base class.
   */
  explicit TVMPluginCodeGen(const std::string& config = "")
      : BasePluginCodeGen<TVMPluginCodeGenConfig>(config) {}

 protected:
  /*! \brief Generate the attr declaration of a plugin */
  void CodeGenAttrDeclare(const Plugin& plugin) final;

  /*! \brief Generate the attr definition of a plugin */
  void CodeGenAttrDefine(const Plugin& plugin) final;

  /*! \brief Generate the op declaration of a plugin */
  void CodeGenOpDeclare(const Plugin& plugin) final;

  /*! \brief Generate the op definition of a plugin */
  void CodeGenOpDefine(const Plugin& plugin) final;

  /*! \brief Generate the runtime part of a plugin */
  void CodeGenOpRuntime(const Plugin& plugin) final;

  /*! \brief Generate the cmake file for the given devices */
  void CodeGenCmake(const std::set<ffi::String>& devices) final;

  /*! \brief Generate the dependencies of the manager */
  void CodeGenManagerDepends() final;

  /*! \brief Generate the methods of the manager */
  void CodeGenManagerMethods() final;

  /*! \brief Generate the manager member for a plugin */
  void CodeGenOpBuilder(const Plugin& plugin) final;

 private:
  /*!
   * \brief Name of the compute function of a plugin.
   * \return the plugin name followed by "_compute".
   */
  const ffi::String ComputeName(const Plugin& plugin) { return plugin->name + "_compute"; }

  /*! \brief Generate the compute function of a plugin for one device */
  void CodeGenCompute(const Plugin& plugin, const ffi::String& device);

  /*!
   * \brief Map a plugin type name to the type name used on the tvm side.
   * \param type the plugin type name.
   * \return "StringImm" for "string", "FloatImm" for names starting with
   *  "float", "IntImm" for "bool" or names starting with "int", "Tuple" for
   *  list types, otherwise whatever the base class ToCppType gives.
   */
  const ffi::String ToTVMType(const ffi::String& type) {
    using Base = BasePluginCodeGen<TVMPluginCodeGenConfig>;
    // The checks run in this exact order; the first match wins.
    if (type == "string") {
      return "StringImm";
    } else if (StringUtils::StartsWith(type, "float")) {
      return "FloatImm";
    } else if (type == "bool" || StringUtils::StartsWith(type, "int")) {
      return "IntImm";
    } else if (IsListType(type)) {
      return "Tuple";
    }
    // Nothing tvm-specific applies: use the base class C++ type mapping.
    return Base::ToCppType(type);
  }
};
} // namespace msc
} // namespace contrib
} // namespace tvm
#endif // TVM_CONTRIB_MSC_PLUGIN_TVM_CODEGEN_H_