blob: bc75d555c9af9fb6ae413ba3ecab97b985528594 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include <cstring>
#include <iomanip>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
namespace tvm {
namespace relay {
namespace transform {
namespace {
/*! \brief Collect all attributes whose name contains "layout".
struct CollectAttrs : public AttrVisitor {
void Visit(const char* key, std::string* value) final {
if (std::string(key).find("layout") != std::string::npos) {
attrs[key] = String(*value);
void Visit(const char* key, double* value) final {}
void Visit(const char* key, uint64_t* value) final {}
void Visit(const char* key, int* value) final {}
void Visit(const char* key, int64_t* value) final {}
void Visit(const char* key, bool* value) final {}
void Visit(const char* key, runtime::NDArray* value) final {}
void Visit(const char* key, ObjectRef* value) final {
if (std::string(key).find("layout") != std::string::npos) {
attrs[key] = *value;
void Visit(const char* key, DataType* value) final {}
void Visit(const char* key, void** value) final {}
std::unordered_map<std::string, ObjectRef> attrs;
} // namespace
/*! \brief Visitor to add structural hash and layout information to `Function`
* nodes. Sets the "hash" field on the attr to the structural hash of the
* function. Propogates any attributes with "layout" in their name from call
* nodes in the Function to the Function's attrs.
class LabelOpsMutator : public MixedModeMutator {
using MixedModeMutator::VisitExpr_;
std::unordered_map<std::string, ObjectRef> body_attrs;
Expr VisitExpr_(const FunctionNode* op) final {
if (op->GetAttr<String>("hash").defined()) {
// Already labelled.
return ExprMutator::VisitExpr_(op);
// body_attrs collects attrs from Calls in the body of this Function. Reset
// it so we only get attrs from this Function.
body_attrs = {};
auto updated = ExprMutator::VisitExpr_(op);
size_t hash = StructuralHash()(updated);
// format hash as fixed length hex string so it is easier to read
std::stringstream s;
s << std::setfill('0') << std::setw(sizeof(size_t) * 2) << std::hex << hash;
Function f = WithAttr(Downcast<Function>(updated), "hash", String(s.str()));
for (auto p : body_attrs) {
f = WithAttr(f, p.first, p.second);
return std::move(f);
Expr VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
auto post_visit = [this](const LetNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
auto value = this->Mutate(op->value);
auto body = this->Mutate(op->body);
auto expr = GetRef<Expr>(op);
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
this->memo_[expr] = expr;
} else {
this->memo_[expr] = Let(var, value, body);
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
Expr Rewrite_(const CallNode* op, const Expr& post) final {
auto updated = MixedModeMutator::Rewrite_(op, post);
if (op->attrs.defined()) {
CollectAttrs collect;
for (auto p : collect.attrs) {
if (body_attrs.find(p.first) != body_attrs.end() && p.second == body_attrs[p.first]) {
LOG(WARNING) << "LabelOps found two call sites with different values for " << p.first
<< " (" << p.second << " vs " << body_attrs[p.first]
<< "). Only the first will be recorded.";
body_attrs[p.first] = p.second;
return updated;
/*! \brief Add structural hash and layout information to Function nodes. This
* information is used later by the profiler.
* The hash and layout information is added to the attrs field of the Function.
* The key "hash" contains the structural hash of the node. Any attributes with
* "layout" in their name are also added to attrs (for example,
* `attrs["src_layout"]` contains the `src_layout` attribute of the TVM op
* corresponding to this function call).
Pass LabelOps() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(LabelOpsMutator().Mutate(f));
return CreateFunctionPass(pass_func, 1, "LabelOps", {});
} // namespace transform
} // namespace relay
} // namespace tvm