/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/cast.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/function.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
namespace tvm {
namespace relax {
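/*!
 * \brief Mutator that folds constant-computable expressions in Relax functions.
 *
 * Illustrative sketch of the effect (not taken from a test case):
 *   lv: R.Tensor((), "int32") = R.add(R.const(1), R.const(2))
 * becomes
 *   lv: R.Tensor((), "int32") = R.const(3)
 *
 * Constants are produced by legalizing ops to call_tir, building the TIR
 * functions for the host LLVM target, and executing them at compile time.
 */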
class ConstantFolder : public ExprMutator {
public:
static Function Fold(Function func, IRModule ctx_module) {
ConstantFolder folder(std::move(ctx_module));
func = Downcast<Function>(RemoveAllUnused(folder(func)));
return func;
}
private:
explicit ConstantFolder(IRModule ctx_module) : ExprMutator(ctx_module) {}
/*!
   * \brief Pattern match the shape inside the given struct info to a
   * constant shape and get a runtime shape tuple from it.
   * \param struct_info The given struct info whose inner shape is to be cast.
   * \return The runtime shape tuple, or std::nullopt if it is not a constant shape.
   * \note Only TensorStructInfo is supported. Returns std::nullopt
   * if the input struct info is not TensorStructInfo.
*/
static ffi::Optional<ffi::Shape> MatchConstShape(const StructInfo& struct_info) {
const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
if (tensor_sinfo == nullptr) {
return std::nullopt;
}
const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
TVM_FFI_ICHECK(shape != nullptr) << "struct info given by call_tir should have ShapeExpr shape";
std::vector<int64_t> shape_values;
for (const auto v : shape->values) {
auto* ptr = v.as<IntImmNode>();
if (!ptr) return std::nullopt;
shape_values.push_back(ptr->value);
}
return ffi::Shape(shape_values.begin(), shape_values.end());
}
/*!
* \brief Pattern match op to constant array arguments.
* \return The constant array arguments, or nullopt if match fails.
*/
static ffi::Optional<ffi::Array<runtime::Tensor>> MatchConstArrayArgs(
const ffi::Array<Expr>& args) {
ffi::Array<runtime::Tensor> res;
for (auto arg : args) {
auto* ptr = arg.as<relax::ConstantNode>();
if (!ptr) return std::nullopt;
res.push_back(ptr->data);
}
return res;
}
/*!
* \brief Pattern match op to a TIR function and look it up.
* \return The TIR function, or nullopt if pattern match fails.
*/
  ffi::Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
    const GlobalVar& global_var = Downcast<GlobalVar>(op);
    // NOTE: the `as` check also handles nullptr (it returns null in that case)
    ffi::Optional<BaseFunc> base_func = builder_->GetContextIRModule()->functions.Get(global_var);
    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
      return ffi::GetRef<tir::PrimFunc>(pfunc);
}
return std::nullopt;
}
/*!
* \brief Get a cached build version of func
* \return The cached func, nullopt if func cannot be built.
*/
  ffi::Optional<ffi::Function> GetCachedBuild(tir::PrimFunc func) {
    // TODO(tvm-team): consider another way to bulk-extract and build PrimFuncs once;
    // this would be helpful for future cases where PrimFuncs recursively call into each other
Target eval_cpu_target{"llvm"};
auto it = func_build_cache_.find(func);
if (it != func_build_cache_.end()) {
return it->second;
}
ffi::Optional<ffi::Function> build_func = std::nullopt;
try {
      // Not every PrimFunc can be built directly via LLVM. For example, if a function is
      // already scheduled to run only on GPU, we need to skip it in the constant folder
      // for now.
      // TODO(Hongyi): further check and narrow the scope of foldable functions
      const auto pf = tvm::ffi::Function::GetGlobalRequired("tir.build");
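      // Attach a global symbol so the built runtime module exposes an entry point
      // that can be looked up by name below.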
func = WithAttr(func, tvm::attr::kGlobalSymbol, ffi::String("tir_function"));
ffi::Module rt_module = pf(func, eval_cpu_target).cast<ffi::Module>();
build_func = rt_module->GetFunction("tir_function");
} catch (const tvm::ffi::Error& err) {
// build failure may happen in which case we skip
DLOG(WARNING) << "Build failure for function " << func << ", Error message: " << err.what();
}
func_build_cache_[func] = build_func;
return build_func;
}
  /*!
   * \brief Check whether \p expr (or any field of a tuple expr) carries tensor struct info.
   * \details Used by ShouldBeFolded below to decide whether folding is worthwhile.
   * Folding an expr is a trade-off: we materialize a constant in the IRModule and pay
   * compile-time cost to avoid executing the expr at runtime. For example, folding iota-like
   * creation ops could materialize large constants, increasing the size of the program.
   */
static bool ExprContainsTensor(const Expr& expr) {
if (GetStructInfo(expr).as<TensorStructInfoNode>()) {
return true;
}
if (const auto* tuple = expr.as<TupleNode>()) {
for (const auto& field : tuple->fields) {
if (ExprContainsTensor(field)) {
return true;
}
}
}
return false;
}
bool ShouldBeFolded(Expr expr) {
// Skip folding for creation ops (no tensor inputs) that produce large outputs.
// These ops (e.g., zeros, ones, full, arange) are cheap to compute at runtime,
// and folding them would materialize large constants in the binary.
static constexpr int64_t kMaxFoldElements = 1024;
const auto* call = expr.as<CallNode>();
if (!call) return true;
const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
if (!tensor_sinfo) return true;
auto opt_shape = tensor_sinfo->GetShape();
if (!opt_shape) return true;
int64_t num_elements = 1;
for (const auto& dim : opt_shape.value()) {
const auto* int_dim = dim.as<IntImmNode>();
if (!int_dim) return true;
int64_t d = int_dim->value;
if (d <= 0) return true;
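      // Guard against int64 overflow while accumulating the element count: if the next
      // multiplication would exceed kMaxFoldElements, clamp past the limit and stop.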
if (num_elements > kMaxFoldElements / d) {
num_elements = kMaxFoldElements + 1;
break;
}
num_elements *= d;
}
if (num_elements <= kMaxFoldElements) return true;
// Large output. Only skip if there are no tensor inputs,
// i.e., this is a pure creation op.
for (const auto& arg : call->args) {
if (ExprContainsTensor(arg)) {
return true;
}
}
return false;
}
// Try constant evaluate a call_tir with a single tensor output.
// Returns std::nullopt on failure.
  ffi::Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func,
ffi::Array<runtime::Tensor> arr_args, ffi::Shape shape,
DataType ret_type) {
// obtain function from the cache.
ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
if (!func) return std::nullopt;
// here the vector size has an additional + 1 because we need to put ret_tensor at the end
std::vector<AnyView> packed_args(arr_args.size() + 1);
DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
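    // Allocate the output tensor on CPU; following the destination-passing style of
    // call_tir, the built PrimFunc writes its result into this tensor.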
runtime::Tensor ret_tensor = runtime::Tensor::Empty(shape, ret_type, cpu_dev);
    // Avoid passing rvalue refs that get deallocated later; store args in a vector
    // so that each temp_args[i] is a stable lvalue ref
std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
size_t arg_offset = 0;
for (; arg_offset < arr_args.size(); ++arg_offset) {
packed_args[arg_offset] = temp_args[arg_offset];
}
// set return value
packed_args[arg_offset++] = ret_tensor;
ffi::Any ret;
// invoke
func.value().CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &ret);
return Constant(ret_tensor);
}
// Try constant evaluate a call_tir with tuple outputs (multiple output tensors).
// Returns std::nullopt on failure.
  ffi::Optional<Expr> ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func,
ffi::Array<runtime::Tensor> arr_args,
const TupleStructInfoNode* tuple_sinfo) {
ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
if (!func) return std::nullopt;
DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
size_t num_outputs = tuple_sinfo->fields.size();
// Match shapes and dtypes for all output fields.
std::vector<runtime::Tensor> ret_tensors;
for (size_t i = 0; i < num_outputs; ++i) {
ffi::Optional<ffi::Shape> shape = MatchConstShape(tuple_sinfo->fields[i]);
if (!shape) return std::nullopt;
auto tensor_sinfo = Downcast<TensorStructInfo>(tuple_sinfo->fields[i]);
if (tensor_sinfo->IsUnknownDtype()) return std::nullopt;
ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_sinfo->dtype, cpu_dev));
}
// Pack input args + all output tensors.
std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
std::vector<AnyView> packed_args;
packed_args.reserve(temp_args.size() + num_outputs);
for (const auto& arg : temp_args) {
packed_args.push_back(arg);
}
for (const auto& out_tensor : ret_tensors) {
packed_args.push_back(out_tensor);
}
ffi::Any ret;
func.value().CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &ret);
ffi::Array<Expr> fields;
for (size_t i = 0; i < num_outputs; ++i) {
fields.push_back(Constant(ret_tensors[i]));
}
return Tuple(fields);
}
// Returns the folded expr if the call is successfully folded to constant, otherwise null.
ffi::Optional<Expr> VisitCallTIR(Call call) {
// call_tir needs to have at least two arguments
TVM_FFI_ICHECK_GE(call->args.size(), 2);
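    // call_tir convention: args[0] is the GlobalVar of the PrimFunc, args[1] is a Tuple
    // of input arguments, and sinfo_args[0] describes the output struct info.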
    ffi::Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
TVM_FFI_ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
ffi::Optional<ffi::Array<runtime::Tensor>> arr_args =
MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
TVM_FFI_ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg";
if (!func || !arr_args) return {};
// Handle tuple output: sinfo_args[0] is a TupleStructInfo.
if (const auto* tuple_sinfo = call->sinfo_args[0].as<TupleStructInfoNode>()) {
return ConstEvaluateCallTIRTuple(func.value(), arr_args.value(), tuple_sinfo);
}
// Handle single tensor output.
ffi::Optional<ffi::Shape> shape = MatchConstShape(call->sinfo_args[0]);
    if (shape) {
      TensorStructInfo ret_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
      return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_sinfo->dtype);
    }
return {};
}
using ExprMutator::VisitExpr_;
// TODO(@sunggg):
// Next PR will support fold with ffi::Function and MatchCast
// Until then, DecomposeOps() should be applied after
// this pass to fold `tensor_to_shape` op.
Expr VisitExpr_(const CallNode* call) final {
// post-order mutation
Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
// Check if it is useful to fold this call
if (!ShouldBeFolded(post_call)) return post_call;
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
auto* op_node = post_call->op.as<OpNode>();
// Not an OpNode
if (op_node == nullptr) {
return post_call;
}
auto op = ffi::GetRef<Op>(op_node);
if (op.same_as(call_tir_op)) {
return VisitCallTIR(post_call).value_or(post_call);
}
// Special logic to fold ShapeExpr between operators
// e.g.,
// <Before>
// lv: R.Shape([16, 16]) = R.shape([16, 16])
// gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, lv)
// <After>
// gv: R.Tensor(lv2, dtype="float32") = R.reshape(data, R.shape([16, 16]))
//
ffi::Array<Expr> new_args;
for (auto arg : post_call->args) {
if (arg->IsInstance<VarNode>()) {
ffi::Optional<Expr> val = LookupBinding(Downcast<Var>(arg));
if (val.defined() && val.value()->IsInstance<ShapeExprNode>()) {
new_args.push_back(val.value());
continue;
}
}
new_args.push_back(arg);
}
post_call =
Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, post_call->span);
// If we are in a dataflow block, we can fold ops.
if (builder_->CurrentBlockIsDataFlow()) {
      // Check if we can legalize the op to call_tir
if (legalize_map.count(op)) {
// Get the legalized expression
Call post_call_normalized = Downcast<Call>(builder_->Normalize(post_call));
Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, post_call_normalized));
// If the legalized expression is call_tir, try to fold it.
const CallNode* call = legalized_expr.as<CallNode>();
if (call && call->op.same_as(call_tir_op)) {
return VisitCallTIR(ffi::GetRef<Call>(call)).value_or(post_call);
}
} else if (op->name == "relax.tensor_to_shape") {
// Special handling for composite op "relax.tensor_to_shape"
// If its input is constant, we can access its value and create ShapeExpr
// TODO(@sunggg):
// currently, we do not have a info map about decomposition.
// Thus, this is a temporary solution until we have a consensus about
// how to deal with composite ops. One possibility is we register the
// decomposition map for each op in a similar way we do for legalization.
TVM_FFI_ICHECK_EQ(post_call->args.size(), 1);
Expr arg = post_call->args[0];
if (arg->IsInstance<ConstantNode>()) {
Constant constant = Downcast<Constant>(arg);
runtime::Tensor ndarray = constant->data;
TVM_FFI_ICHECK_EQ(ndarray->device.device_type, kDLCPU);
TVM_FFI_ICHECK(ndarray.IsContiguous());
TVM_FFI_ICHECK_EQ(ndarray->byte_offset, 0);
TVM_FFI_ICHECK_EQ(ndarray->ndim, 1);
const int64_t* data = static_cast<const int64_t*>(ndarray->data);
int64_t num_elems = ndarray->shape[0];
ffi::Array<PrimExpr> shape_values;
for (int64_t i = 0; i < num_elems; i++) {
shape_values.push_back(IntImm(DataType::Int(64), data[i]));
}
return ShapeExpr(shape_values);
}
} else if (op->name == "relax.shape_to_tensor") {
// Special handling for "relax.shape_to_tensor" since it is implemented in ffi::Function.
// TODO(sunggg): revisit this when we extend ConstantFolding to fold ffi::Function.
Expr arg = post_call->args[0];
ShapeExpr shape = Downcast<ShapeExpr>(arg);
ffi::Array<PrimExpr> values = shape->values;
        ffi::Array<Integer> arr;
        bool is_known = true;
        for (size_t i = 0; i < values.size(); i++) {
          PrimExpr val = values[i];
          const auto* int_val = val.as<IntImmNode>();
          if (int_val == nullptr || val.dtype() != DataType::Int(64)) {
            is_known = false;
            break;
          }
          arr.push_back(Integer(int_val->value));
        }
if (is_known) {
const auto func = tvm::ffi::Function::GetGlobalRequired("relax.run.shape_to_tensor");
runtime::Tensor vals = func(arr).cast<runtime::Tensor>();
return Constant(vals);
}
}
}
return post_call;
}
Expr VisitExpr_(const VarNode* op) final {
ffi::Optional<Expr> opt = LookupBinding(ffi::GetRef<Var>(op));
    // The `as` check verifies that opt is non-null and is an instance of Constant
if (opt.as<relax::ConstantNode>()) {
return opt.value();
}
return ExprMutator::VisitExpr_(op);
}
// cache for function build, via structural equality
  std::unordered_map<tir::PrimFunc, ffi::Optional<ffi::Function>, ffi::StructuralHash,
ffi::StructuralEqual>
func_build_cache_;
};
namespace transform {
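/*!
 * \brief Create a function-level pass that folds constants inside dataflow blocks.
 */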
Pass FoldConstant() {
auto pass_func = [=](Function f, IRModule m, PassContext pc) {
return ConstantFolder::Fold(f, m);
};
return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.transform.FoldConstant", FoldConstant);
}
} // namespace transform
} // namespace relax
} // namespace tvm