blob: f72b4e105749aac0ac076c9bb6d98bd510364071 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file feature.cc
* \brief Detect features used in Expr/Module
*/
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include "../transforms/pass_utils.h"
namespace tvm {
namespace relay {
FeatureSet DetectFeature(const Expr& expr) {
if (!expr.defined()) {
return FeatureSet::No();
}
struct FeatureDetector : ExprVisitor {
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited_;
FeatureSet fs = FeatureSet::No();
void VisitExpr(const Expr& expr) final {
if (visited_.count(expr) == 0) {
visited_.insert(expr);
ExprVisitor::VisitExpr(expr);
} else {
if (!IsAtomic(expr)) {
fs += fGraph;
}
}
}
#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \
void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { STMT fs += f##CONSTRUCT_NAME; }
#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) \
DETECT_CONSTRUCT(CONSTRUCT_NAME, { ExprVisitor::VisitExpr_(op); })
DETECT_DEFAULT_CONSTRUCT(Var)
DETECT_DEFAULT_CONSTRUCT(GlobalVar)
DETECT_DEFAULT_CONSTRUCT(Constant)
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_CONSTRUCT(Function, {
if (!op->HasNonzeroAttr(attr::kPrimitive)) {
ExprVisitor::VisitExpr_(op);
}
})
DETECT_DEFAULT_CONSTRUCT(Op)
DETECT_DEFAULT_CONSTRUCT(Call)
DETECT_CONSTRUCT(Let, {
for (const Var& v : FreeVars(op->value)) {
if (op->var == v) {
fs += fLetRec;
}
}
ExprVisitor::VisitExpr_(op);
})
DETECT_DEFAULT_CONSTRUCT(If)
DETECT_DEFAULT_CONSTRUCT(RefCreate)
DETECT_DEFAULT_CONSTRUCT(RefRead)
DETECT_DEFAULT_CONSTRUCT(RefWrite)
DETECT_DEFAULT_CONSTRUCT(Constructor)
DETECT_DEFAULT_CONSTRUCT(Match)
#undef DETECT_DEFAULT_CONSTRUCT
} fd;
fd(expr);
return fd.fs;
}
std::string FeatureSet::ToString() const {
std::string ret;
ret += "[";
size_t detected = 0;
#define DETECT_FEATURE(FEATURE_NAME) \
++detected; \
if (bs_[FEATURE_NAME]) { \
ret += #FEATURE_NAME; \
ret += ", "; \
}
DETECT_FEATURE(fVar);
DETECT_FEATURE(fGlobalVar);
DETECT_FEATURE(fConstant);
DETECT_FEATURE(fTuple);
DETECT_FEATURE(fTupleGetItem);
DETECT_FEATURE(fFunction);
DETECT_FEATURE(fOp);
DETECT_FEATURE(fCall);
DETECT_FEATURE(fLet);
DETECT_FEATURE(fIf);
DETECT_FEATURE(fRefCreate);
DETECT_FEATURE(fRefRead);
DETECT_FEATURE(fRefWrite);
DETECT_FEATURE(fConstructor);
DETECT_FEATURE(fMatch);
DETECT_FEATURE(fGraph);
DETECT_FEATURE(fLetRec);
#undef DETECT_FEATURE
ICHECK(detected == feature_count) << "some feature not printed";
ret += "]";
return ret;
}
FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No();
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
return fs;
}
Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod) {
FeatureSet fs = DetectFeature(expr);
if (mod.defined()) {
fs = fs + DetectFeature(mod.value());
}
return static_cast<Array<Integer>>(fs);
}
TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);
void CheckFeature(const Expr& expr, const FeatureSet& fs) {
auto dfs = DetectFeature(expr);
ICHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
<< "\nhas unsupported feature: " << (dfs - fs).ToString();
}
void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
for (const auto& f : mod->functions) {
CheckFeature(f.second, fs);
}
}
} // namespace relay
} // namespace tvm