blob: c5785369f8d553f38f078ba965f494d9dc21a6f7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/op_strategy.h
* \brief The Relay operator strategy and related data structures.
*/
#ifndef TVM_RELAY_OP_STRATEGY_H_
#define TVM_RELAY_OP_STRATEGY_H_
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/target.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Operator implementation that includes compute and schedule function.
*/
class OpImplementationNode : public Object {
public:
/*! \brief Compute function */
FTVMCompute fcompute;
/*! \brief Schedule function */
FTVMSchedule fschedule;
/*! \brief Name of the implementation */
String name;
/*! \brief Priority level */
int plevel;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("plevel", &plevel);
}
static constexpr const char* _type_key = "relay.OpImplementation";
TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
};
/*!
 * \brief Reference to an OpImplementationNode; provides entry points that run
 *  the implementation's compute and schedule functions.
 */
class OpImplementation : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);

  /*!
   * \brief Run the compute function of this implementation.
   * \param attrs Attributes of the primitive operator.
   * \param inputs Tensors fed into the operator.
   * \param out_type Type information describing the operator's output.
   * \return Tensors describing the operator's output computation.
   */
  TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
                                    const Array<te::Tensor>& inputs,
                                    const Type& out_type);

  /*!
   * \brief Run the schedule function of this implementation.
   * \param attrs Attributes of the operator node.
   * \param outs Output tensors to be scheduled.
   * \param target The target the schedule is built for.
   * \return The resulting computation schedule.
   */
  TVM_DLL te::Schedule Schedule(const Attrs& attrs,
                                const Array<te::Tensor>& outs,
                                const Target& target);
};
/*!
* \brief Specialized implementations for operators under certain conditions.
*/
class OpSpecializationNode : public Object {
public:
/*! \brief List of implementations. */
Array<OpImplementation> implementations;
/*! \brief Condition to enable the specialization.
* Could be undefined to represent generic case. */
te::SpecializedCondition condition;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("implementations", &implementations);
}
static constexpr const char* _type_key = "relay.OpSpecialization";
TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
};
/*!
 * \brief Reference to an OpSpecializationNode, i.e. a set of implementations
 *  that apply under one specialization condition.
 */
class OpSpecialization : public ObjectRef {
 public:
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);

  /*!
   * \brief Register one more implementation in this specialization.
   * \param fcompute The compute function of the implementation.
   * \param fschedule The schedule function of the implementation.
   * \param name The name given to the implementation.
   * \param plevel The priority level of the implementation.
   */
  TVM_DLL void AddImplementation(FTVMCompute fcompute,
                                 FTVMSchedule fschedule,
                                 String name,
                                 int plevel);
};
/*!
* \brief Operator strategy to choose implementation.
*/
class OpStrategyNode : public Object {
public:
/*! \brief List of operator specializations. */
Array<OpSpecialization> specializations;
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("specializations", &specializations); }
static constexpr const char* _type_key = "relay.OpStrategy";
TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
};
/*!
 * \brief Reference to an OpStrategyNode, the collection of specializations
 *  from which an operator implementation is chosen.
 */
class OpStrategy : public ObjectRef {
 public:
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);

  /*!
   * \brief Register one more implementation in this strategy.
   * \param fcompute The compute function of the implementation.
   * \param fschedule The schedule function of the implementation.
   * \param name The name given to the implementation.
   * \param plevel The priority level of the implementation.
   */
  TVM_DLL void AddImplementation(FTVMCompute fcompute,
                                 FTVMSchedule fschedule,
                                 String name,
                                 int plevel);
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_STRATEGY_H_