blob: b0e96cc47514827d849b604008cda95ff41440cb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file canonicalize_cast.cc
* \brief Canonicalize cast expressions to make operator fusion more efficient.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include "pass_utils.h"
#include "pattern_utils.h"
namespace tvm {
namespace relay {
// This pass finds upcasts that are referenced by multiple elementwise/broadcast operators and
// creates a copy of the cast in each branch, so that after fusion the producing function outputs
// a tensor with fewer bits (the pre-cast type) instead of the wider upcast result.
//
// Consider the following example:
// \code
// def @main(%x: int8) {
//   %1 = cast(%x, f32)
//   %2 = exp(%1)
//   %3 = log(%1)
//   (%2, %3)
// }
// \endcode
//
// We would like to prevent sharing of the cast expression so that operator fusion can produce
// the more efficient result below.
// \code
// def @main(%x: int8) {
//   %1 = fn (%p1: int8) {
//     exp(cast(%p1, f32))
//   }
//   %3 = %1(%x)
//   %2 = fn (%p1: int8) {
//     log(cast(%p1, f32))
//   }
//   %4 = %2(%x)
//   (%3, %4)
// }
// \endcode
/*!
 * \brief Mutator that gives each elementwise/broadcast consumer of a shared upcast its own
 *        copy of the cast.
 *
 * The first consumer keeps the (memoized) shared cast. Every later consumer gets a freshly
 * constructed cast Call with the same op, arguments, attributes and type arguments.
 */
class CastCanonicalizer : public ExprMutator {
 public:
  CastCanonicalizer() : cast_op_(Op::Get("cast")) {}

  /*!
   * \brief Rewrite the arguments of calls to primitive ops whose TOpPattern is <= kBroadcast.
   *
   * Calls whose callee is not an Op, or whose pattern is above kBroadcast, use the default
   * ExprMutator traversal. An Op with no registered TOpPattern is treated as kOpaque, so it
   * also takes the default path instead of aborting the pass.
   * \param call The call node being visited.
   * \return The original call if no argument changed, otherwise a rebuilt call.
   */
  Expr VisitExpr_(const CallNode* call) override {
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    if (const OpNode* opnode = call->op.as<OpNode>()) {
      // get() with a default instead of operator[]: operator[] ICHECK-fails for ops that
      // have no TOpPattern registered.
      auto pattern = fpattern.get(GetRef<Op>(opnode), kOpaque);
      if (pattern <= kBroadcast) {
        Array<Expr> call_args = call->args;
        bool unchanged = true;
        for (size_t i = 0; i < call_args.size(); ++i) {
          Expr arg = call_args[i];
          Expr new_arg = GetNewCallArg(arg);
          if (!arg.same_as(new_arg)) {
            call_args.Set(i, new_arg);
            unchanged = false;
          }
        }
        if (unchanged) {
          return GetRef<Expr>(call);
        }
        // Preserve the source span so diagnostics still point at the original location.
        return Call(call->op, call_args, call->attrs, call->type_args, call->span);
      }
    }
    return ExprMutator::VisitExpr_(call);
  }

 private:
  /*!
   * \brief Per upcast node, the number of elementwise/broadcast call arguments that have
   *        referenced it so far.
   *
   * Keyed by the original (pre-mutation) CallNode.
   */
  std::unordered_map<const Object*, size_t> ref_counter_;
  /*!
   * \brief The "cast" op, cached because it is compared against on every argument visit.
   *
   * Held by value (a cheap reference-counted handle) so that it does not depend on the
   * lifetime of whatever Op::Get returns.
   */
  const Op cast_op_;

  /*!
   * \brief Visit one argument of an elementwise/broadcast call, duplicating it if it is an
   *        upcast that an earlier such call already consumed.
   * \param e The original (un-mutated) argument expression.
   * \return A new cast Call for the second and later references to a shared upcast; otherwise
   *         the (memoized) result of VisitExpr(e).
   */
  Expr GetNewCallArg(const Expr& e) {
    Expr new_expr = this->VisitExpr(e);
    if (const CallNode* call = e.as<CallNode>()) {
      if (call->op == cast_op_) {
        const auto* attrs = call->attrs.as<CastAttrs>();
        ICHECK(attrs) << "cast call is expected to carry CastAttrs";
        // type_as requires the argument to be type-checked (the pass depends on InferType).
        const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
        ICHECK(from_type);
        // Only widening casts are duplicated. Keeping the narrower input as the shared value
        // lets the fused producer emit fewer bits.
        if (from_type->dtype.bits() < attrs->dtype.bits()) {
          // The first consumer keeps the shared cast; later consumers get a private copy.
          if (++ref_counter_[call] > 1) {
            const CallNode* new_call = new_expr.as<CallNode>();
            ICHECK(new_call);
            ICHECK(new_call->op == cast_op_);
            return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args,
                        new_call->span);
          }
        }
      }
    }
    return new_expr;
  }
};
/*!
 * \brief Run CastCanonicalizer over an expression.
 * \param e The expression to rewrite. Cast arguments must already be type-checked, because
 *          the mutator reads their checked type via type_as.
 * \return The rewritten expression.
 */
Expr CanonicalizeCast(const Expr& e) {
  CastCanonicalizer canonicalizer;
  return canonicalizer.Mutate(e);
}
namespace transform {
/*!
 * \brief Build the function-level pass that applies relay::CanonicalizeCast.
 * \return A function pass named "CanonicalizeCast" at opt level 3 that requires "InferType".
 */
Pass CanonicalizeCast() {
  // Explicitly qualified: inside this namespace the unqualified name refers to this
  // zero-argument overload.
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule /*m*/, PassContext /*pc*/) -> Function {
        Expr rewritten = relay::CanonicalizeCast(f);
        return Downcast<Function>(rewritten);
      };
  constexpr int kOptLevel = 3;
  return CreateFunctionPass(pass_func, kOptLevel, "CanonicalizeCast", {"InferType"});
}
// Register the pass constructor as the global function "relay._transform.CanonicalizeCast".
// Inside namespace transform, the unqualified name CanonicalizeCast resolves to the
// zero-argument transform::CanonicalizeCast(), so the registered function takes no arguments
// and returns a Pass.
TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeCast").set_body_typed(CanonicalizeCast);
} // namespace transform
} // namespace relay
} // namespace tvm