blob: f0239e424f3004a58928afb36e4d19fa022e576e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "transform/utils.h"
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/index.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/stmt_functor.h>
namespace tvm {
namespace relax {
/*!
 * \brief Mutator that implements parameter binding.
 *
 * Relax variables that are keys of `args_map` are replaced by the mapped
 * expression, and TIR variables that are keys of `symbolic_var_map` are
 * substituted inside every visited PrimExpr.
 */
class ExprBinder : public ExprMutator {
 public:
  /*!
   * \param args_map Map from relax variable to the expression replacing it.
   * \param symbolic_var_map Map from TIR variable to the PrimExpr replacing it.
   */
  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map,
                      const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map)
      : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {}

 private:
  using ExprMutator::VisitExpr_;

  /*!
   * \brief Rebuild a function, dropping every parameter that is being bound.
   * \return The original function when neither params nor body changed,
   *         otherwise a new Function with the remaining params.
   */
  Expr VisitExpr_(const FunctionNode* op) final {
    tvm::Array<Var> params;
    bool all_params_unchanged = true;
    for (const Var& param : op->params) {
      if (args_map_.count(param)) {
        // A bound parameter is removed from the signature.
        all_params_unchanged = false;
      } else {
        Var new_param = this->VisitVarDef(param);
        params.push_back(new_param);
        if (!param.same_as(new_param)) {
          // Remember the replacement so uses in the body are redirected.
          this->var_remap_[param->vid] = new_param;
          all_params_unchanged = false;
        }
      }
    }

    Expr body = this->VisitWithNewScope(op->body, params);

    // FuncStructInfo does not depend on Expr
    if (all_params_unchanged && body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      // purity won't be affected, no need to update annotation
      return Function(params, body, VisitExprDepStructInfoField(op->ret_struct_info), op->is_pure,
                      op->attrs);
    }
  }

  /*! \brief Replace a variable by its bound expression, if it has one. */
  Expr VisitExpr_(const VarNode* op) final {
    auto id = GetRef<Var>(op);
    auto it = args_map_.find(id);
    if (it != args_map_.end()) {
      return (*it).second;
    } else {
      return ExprMutator::VisitExpr_(op);
    }
  }

  /*!
   * \brief Substitute symbolic variables in a PrimExpr.
   *
   * The result is simplified only when the substitution changed something,
   * so untouched expressions are returned as-is.
   */
  PrimExpr VisitPrimExpr(const PrimExpr& expr) final {
    auto new_expr = tir::Substitute(expr, symbolic_var_map_);
    if (!expr.same_as(new_expr)) {
      arith::Analyzer analyzer;
      new_expr = analyzer.Simplify(new_expr);
    }
    return new_expr;
  }

  // The maps are held by value rather than by reference: a binder constructed
  // from a temporary map (e.g. `ExprBinder({}, m)`) must not be left with a
  // dangling reference once that temporary is destroyed.
  /*! \brief Map from relax variable to the expression replacing it. */
  tvm::Map<Var, Expr> args_map_;
  /*! \brief Map from TIR variable to the PrimExpr replacing it. */
  tvm::Map<tir::Var, PrimExpr> symbolic_var_map_;
};
/*!
 * \brief Bind relax variables and symbolic variables inside an expression.
 * \param expr The expression to rewrite.
 * \param binds Map from relax variable to the expression that replaces it.
 * \param symbolic_var_map Map from symbolic (TIR) variable to the PrimExpr that replaces it.
 * \return The expression obtained by running ExprBinder over `expr`.
 */
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
          const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
  ExprBinder binder(binds, symbolic_var_map);
  return binder.VisitExpr(expr);
}
/*!
 * \brief Bind symbolic variables inside a StructInfo.
 * \param sinfo The struct info to rewrite.
 * \param symbolic_var_map Map from symbolic (TIR) variable to the PrimExpr that replaces it.
 * \return The struct info produced by ExprBinder::VisitExprDepStructInfoField.
 */
StructInfo Bind(const StructInfo& sinfo, const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map) {
  // No relax variables are bound here; only symbolic variables are substituted.
  const tvm::Map<Var, Expr> no_var_bindings;
  ExprBinder binder(no_var_bindings, symbolic_var_map);
  return binder.VisitExprDepStructInfoField(sinfo);
}
/*!
 * \brief Derive symbolic-variable bindings implied by a relax variable remap.
 *
 * For every (variable, expression) pair, the struct info of the variable is
 * matched against the struct info of the expression.  Wherever the variable
 * side holds a bare TIR variable (a tensor/shape dimension or a prim value),
 * that TIR variable is mapped to the corresponding PrimExpr of the expression
 * side.  A later match for the same TIR variable overwrites an earlier one.
 *
 * \param relax_var_remap Map from relax variable to the expression replacing it.
 * \param analyzer Not used by this implementation.
 * \return Map from TIR variable to the PrimExpr inferred for it.
 */
tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
    const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
  tvm::Map<tir::Var, PrimExpr> inferred;

  // Record `param_side -> arg_side` when the parameter side is a plain TIR variable.
  auto match_prim_expr = [&inferred](const PrimExpr& param_side, const PrimExpr& arg_side) {
    if (auto tir_var = param_side.as<tir::Var>()) {
      inferred.Set(tir_var.value(), arg_side);
    }
  };

  // Match R.Prim struct info: dtypes must agree, values are matched when both are known.
  auto match_prim_value = [&match_prim_expr](const StructInfo& param_sinfo,
                                             const StructInfo& arg_sinfo) {
    auto param_prim = param_sinfo.as<PrimStructInfoNode>();
    if (!param_prim) return;

    auto arg_prim = arg_sinfo.as<PrimStructInfoNode>();
    CHECK(arg_prim) << "Cannot bind expression with struct type " << arg_sinfo
                    << " to variable with struct type " << param_sinfo;
    CHECK_EQ(param_prim->dtype, arg_prim->dtype)
        << "Cannot bind expression with struct type " << arg_sinfo
        << " to variable with struct type " << param_sinfo
        << ", due to conflicting PrimExpr DataType";

    if (param_prim->value.defined() && arg_prim->value.defined()) {
      match_prim_expr(param_prim->value.value(), arg_prim->value.value());
    }
  };

  // Match R.Shape struct info dimension by dimension.
  auto match_shape = [&match_prim_expr](const StructInfo& param_sinfo,
                                        const StructInfo& arg_sinfo) {
    auto param_shape = param_sinfo.as<ShapeStructInfoNode>();
    if (!param_shape || !param_shape->values.defined()) return;

    auto arg_shape = arg_sinfo.as<ShapeStructInfoNode>();
    CHECK(arg_shape) << "Cannot bind expression with struct type " << arg_sinfo
                     << " to variable with struct type " << param_sinfo;
    if (!arg_shape->values.defined()) return;

    auto param_dims = param_shape->values.value();
    auto arg_dims = arg_shape->values.value();
    CHECK_EQ(param_dims.size(), arg_dims.size())
        << "Cannot bind shape " << arg_dims << " of dimension " << arg_dims.size()
        << " to variable with shape " << param_dims << " of dimension " << param_dims.size();

    const size_t num_dims = param_dims.size();
    for (size_t dim = 0; dim < num_dims; ++dim) {
      match_prim_expr(param_dims[dim], arg_dims[dim]);
    }
  };

  // Match R.Tensor struct info by delegating to the struct info of the two shape expressions.
  auto match_tensor = [&match_shape](const StructInfo& param_sinfo, const StructInfo& arg_sinfo) {
    auto param_tensor = param_sinfo.as<TensorStructInfoNode>();
    if (!param_tensor || !param_tensor->shape.defined()) return;

    auto arg_tensor = arg_sinfo.as<TensorStructInfoNode>();
    CHECK(arg_tensor) << "Cannot bind expression with struct type " << arg_sinfo
                      << " to variable with struct type " << param_sinfo;
    if (!arg_tensor->shape.defined()) return;

    match_shape(GetStructInfo(param_tensor->shape.value()),
                GetStructInfo(arg_tensor->shape.value()));
  };

  for (const auto& [target_var, bound_expr] : relax_var_remap) {
    auto param_sinfo = GetStructInfo(target_var);
    auto arg_sinfo = GetStructInfo(bound_expr);

    // Each matcher ignores struct info of a kind it does not handle.
    match_tensor(param_sinfo, arg_sinfo);
    match_shape(param_sinfo, arg_sinfo);
    match_prim_value(param_sinfo, arg_sinfo);
  }

  return inferred;
}
bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
bool permit_unknown_dtype) {
DataType dtype;
int ndim;
if (const auto* tensor = sinfo.as<TensorStructInfoNode>()) {
dtype = tensor->dtype;
ndim = tensor->ndim;
} else if (const auto* prim = sinfo.as<PrimStructInfoNode>()) {
dtype = prim->dtype;
ndim = 0;
} else {
return false;
}
bool correct_dtype = dtype.is_bool() || (permit_unknown_dtype && dtype.is_void());
bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1);
return correct_dtype && correct_rank;
}
/*!
 * \brief Check whether an expression is a leaf expression or a tuple.
 * \param expr The expression to inspect.
 * \return True if `expr` is a LeafExprNode, GlobalVarNode, ExternFuncNode,
 *         OpNode or TupleNode; false otherwise.
 */
bool IsLeafOrTuple(const Expr& expr) {
  if (expr.as<TupleNode>()) return true;
  if (expr.as<LeafExprNode>()) return true;
  if (expr.as<GlobalVarNode>()) return true;
  if (expr.as<ExternFuncNode>()) return true;
  return expr.as<OpNode>() != nullptr;
}
/*!
 * \brief Check whether a call is impure.
 *
 * When the callee is an Op, the answer comes from the op's registered
 * "FPurity" attribute.  Otherwise it comes from the `purity` field of the
 * callee's FuncStructInfo.
 *
 * \param call The call to inspect.
 * \return True if the callee is not pure.
 */
bool IsImpureCall(const Call& call) {
  if (auto op_ptr = call->op.as<OpNode>()) {
    auto op = GetRef<Op>(op_ptr);
    static auto purity_map = Op::GetAttrMap<Bool>("FPurity");
    ICHECK(purity_map.count(op)) << "Cannot find the registered purity of this op: " << op->name;
    return !(purity_map[op]->value);
  }

  // The StructInfo must be FuncStructInfo.  Verify it instead of dereferencing
  // blindly, so a callee annotated with any other struct info produces a
  // diagnostic rather than a null-pointer dereference.
  auto func_struct_info = GetStructInfoAs<FuncStructInfoNode>(call->op);
  ICHECK(func_struct_info) << "Cannot determine the purity of a call whose callee " << call->op
                           << " does not have FuncStructInfo";
  return !func_struct_info->purity;
}
/*!
 * \brief Get the value bound by a binding.
 * \param b The binding; either a VarBinding or a MatchCast.
 * \return The `value` field of the binding.
 * \note Fails a CHECK if the binding is of neither kind.
 */
Expr GetBoundValue(const Binding& b) {
  if (auto* var_binding = b.as<VarBindingNode>()) {
    return var_binding->value;
  }
  // Checking the cast result directly (instead of `CHECK(false)` in a trailing
  // else branch) means no path can fall off the end of this non-void function.
  auto* match_binding = b.as<MatchCastNode>();
  CHECK(match_binding) << "Invalid binding (should never happen)";
  return match_binding->value;
}
/*!
 * \brief Copy a Relax function, remapping its vars and symbolic vars to new ones.
 * The copy is delegated to FunctionCopier; use that class directly when the
 * mapping from old vars to new vars is needed.
 * \param func The Relax function to copy.
 * \return The copied function.
 */
Function CopyWithNewVars(Function func) {
  FunctionCopier copier;
  return copier.Copy(func);
}
// Register CopyWithNewVars as a typed global function under the name "relax.CopyWithNewVars".
TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars);
} // namespace relax
} // namespace tvm