blob: b54544b00082d5bd0172ae09deb8eabacd9731f6 [file]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relax/transform/rewrite_dataflow_reshape.cc
* \brief Rewrite every reshape-pattern call inside dataflow blocks into a relax.reshape operator
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/tirx/analysis.h>
#include <tvm/tirx/function.h>
#include <vector>
#include "../op/tensor/manipulate.h"
namespace tvm {
namespace relax {
/*!
 * \brief Collect the indices of the PrimFunc parameters whose mapped buffer's
 *        data variable is actually referenced in the function body.
 * \param fn The PrimFunc to inspect.
 * \param num_args The number of leading parameters to examine.
 * \return Indices (into fn->params) of the parameters whose buffer data is used.
 */
std::vector<size_t> GetUsedTensorArgIndices(const tirx::PrimFunc& fn, size_t num_args) {
  std::vector<size_t> used;
  for (size_t arg_idx = 0; arg_idx < num_args; ++arg_idx) {
    // Parameters without a buffer mapping are not tensor arguments; skip them.
    auto maybe_buffer = fn->buffer_map.Get(fn->params[arg_idx]);
    if (!maybe_buffer) {
      continue;
    }
    auto data_var = maybe_buffer.value()->data;
    // A tensor argument counts as "used" iff its backing data var appears in the body.
    auto references_data = [&data_var](const tirx::VarNode* v) { return v == data_var.get(); };
    if (tirx::UsesVar(fn->body, references_data)) {
      used.push_back(arg_idx);
    }
  }
  return used;
}
/*!
 * \brief ExprMutator that replaces calls to TIR PrimFuncs exhibiting a reshape
 *        pattern (inside dataflow blocks only) with the high-level relax.reshape op.
 */
class DataflowReshapeRewriter : public ExprMutator {
 public:
  /*! \brief Construct with the module used to look up the called PrimFuncs. */
  explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {}

 private:
  using ExprMutator::VisitExpr_;

  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
    // We only rewrite the bindings inside dataflow blocks; ordinary binding
    // blocks are returned untouched.
    if (const auto* dataflow_block = block.as<DataflowBlockNode>()) {
      return VisitBindingBlock_(dataflow_block);
    } else {
      return block;
    }
  }

  void VisitBinding_(const VarBindingNode* binding) final {
    // We only rewrite the bindings that are not dataflow output (which means they are not
    // externally referenced): bindings to a non-DataflowVar are outputs of the
    // block, so they are emitted unchanged.
    if (!binding->var->IsInstance<DataflowVarNode>()) {
      this->builder_->EmitNormalized(ffi::GetRef<VarBinding>(binding));
    } else {
      ExprMutator::VisitBinding_(binding);
    }
  }

  Expr VisitExpr_(const CallNode* call) final {
    static const Op& call_tir_op = Op::Get("relax.call_tir");
    // Only call_tir nodes can wrap a reshape PrimFunc; leave everything else alone.
    if (call->op != call_tir_op) {
      return ffi::GetRef<Call>(call);
    }
    // We bring the calls of reshape PrimFunc back to calls of high-level
    // relax.reshape op, which will be lowered to calls of the ExternFunc
    // vm.builtin.reshape in the VMBuiltinLower pass.
    auto prim_fn = Downcast<tirx::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
    auto arg_tuple = Downcast<Tuple>(call->args[1])->fields;
    auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size());
    // The number of inputs to call_tir(reshape, (...)) might not be one, since FuseOps
    // can generate a fused TupleGetItem + reshape function whose input is a tuple. FuseTIR
    // then flattens the tuple input so that the fused TIR reshape function ends up having
    // multiple input buffers. But only one of them should be accessed and reshaped.
    if (used_tensor_arg_indices.size() != 1) {
      return ffi::GetRef<Call>(call);
    }
    auto arg = arg_tuple[used_tensor_arg_indices[0]];
    if (!IsCallingTIRReshape(call, arg)) {
      return ffi::GetRef<Call>(call);
    }
    // Safe to access shape.value(): IsCallingTIRReshape has already verified
    // that both input and result struct infos carry a defined shape.
    TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_.value());
    return reshape(arg, res_sinfo->shape.value());
  }

  /*!
   * \brief Check whether \p call invokes a PrimFunc that is a genuine reshape of \p inp:
   *        the callee has a reshape pattern, dtypes match, and the element counts of
   *        input and result are provably equal.
   */
  bool IsCallingTIRReshape(const CallNode* call, Expr inp) {
    const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
    const auto* func = mod_->functions.Get(global_var).value().as<tirx::PrimFuncNode>();
    TVM_FFI_ICHECK_NOTNULL(func);
    if (!HasReshapePattern(ffi::GetRef<tirx::PrimFunc>(func))) {
      return false;
    }
    // The reshape operator expects that the number of elements in the source is the same
    // as the number of elements in the result. There are operators that could have a reshape
    // pattern that don't meet this requirement (e.g. strided_slice), and they should not be
    // converted to reshape.
    TVM_FFI_ICHECK(inp->struct_info_.defined() && call->struct_info_.defined());
    TensorStructInfo inp_sinfo = Downcast<TensorStructInfo>(inp->struct_info_.value());
    TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_.value());
    // Reject dtype mismatches (or unknown input dtype): reshape must preserve dtype.
    if (inp_sinfo->IsUnknownDtype() || inp_sinfo->dtype != res_sinfo->dtype) {
      return false;
    }
    TVM_FFI_ICHECK(inp_sinfo->shape.defined() && res_sinfo->shape.defined());
    if (inp_sinfo->IsUnknownNdim() || res_sinfo->IsUnknownNdim()) {
      return false;
    }
    // Symbolic product of all dimensions; an empty shape denotes a scalar
    // tensor with a single element, hence product 1.
    auto product = [](ffi::Array<PrimExpr> args) -> PrimExpr {
      PrimExpr p;
      if (args.empty()) {
        // Scalar tensors may be empty indicating a single element.
        p = 1;
      } else {
        p = args[0];
      }
      for (int i = 1, e = args.size(); i < e; ++i) p *= args[i];
      return p;
    };
    auto inp_count = product(inp_sinfo->GetShape().value());
    auto res_count = product(res_sinfo->GetShape().value());
    // Use the arithmetic analyzer so symbolic shapes (e.g. n*m vs m*n) also qualify.
    if (!arith::Analyzer().CanProveEqual(inp_count, res_count)) {
      return false;
    }
    return true;
  }

  // Module used to resolve the GlobalVars of called PrimFuncs. Non-owning
  // reference; must outlive this mutator (guaranteed by RewriteDataflowReshape).
  const IRModule& mod_;
};
/*!
 * \brief Apply the dataflow-reshape rewriting to a single Relax function.
 * \param f The function to rewrite.
 * \param mod The module containing the PrimFuncs referenced by f.
 * \return The rewritten function expression.
 */
Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) {
  DataflowReshapeRewriter rewriter(mod);
  return rewriter(f);
}
namespace transform {
/*!
 * \brief Create the RewriteDataflowReshape pass, which turns reshape-pattern
 *        call_tir bindings inside dataflow blocks into relax.reshape calls.
 * \return The function-level pass.
 */
Pass RewriteDataflowReshape() {
  // Per-function transformation driving the rewriter above.
  auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function {
    return Downcast<Function>(RewriteDataflowReshape(f, m));
  };
  return CreateFunctionPass(/*pass_func=*/pass_func, /*opt_level=*/0,
                            /*pass_name=*/"RewriteDataflowReshape", /*required=*/{});
}
// Register the pass constructor with the FFI so it is reachable from Python as
// tvm.relax.transform.RewriteDataflowReshape.
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("relax.transform.RewriteDataflowReshape", RewriteDataflowReshape);
}
} // namespace transform
} // namespace relax
} // namespace tvm