blob: ac95acce63f2d2384529e325e2a4c0139a2d913e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relax/transform/convert_dataflow.cc
* \brief Pass for extracting groups of pure operations without
* dataflow into dataflow blocks.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/utils.h>
#include <optional>
namespace tvm {
namespace relax {
class DataflowBlockExtractor : public ExprMutator {
public:
explicit DataflowBlockExtractor(size_t min_size) : ExprMutator(), min_size_(min_size) {}
Expr VisitExpr_(const SeqExprNode* seq) override {
ffi::Array<BindingBlock> new_blocks;
Expr new_body = VisitExpr(seq->body);
bool changed = !new_body.same_as(seq->body);
// Accumulated bindings that are not going to be added to a
// DataflowBlock, either because they would be illegal within a
// DataflowBlock, or because there were insufficient bindings to
// make a dataflowblock. Because these bindings occur prior to
// `dataflow_bindings`, this array may only be accumulated into
// when `dataflow_bindings` is empty.
ffi::Array<Binding> non_dataflow_bindings;
// Current bindings that may legally be added to a DataflowBlock.
ffi::Array<Binding> dataflow_bindings;
// If present, a DataflowBlock whose bindings are currently in
// `dataflow_bindings`. Used to propagate DataflowBlock to the
// output, even if it doesn't meet the minimum size.
ffi::Optional<DataflowBlock> input_dataflow_block;
// Handle any bindings currently in `dataflow_bindings`. These
// are either pushed to their own block, or to the end of
// `non_dataflow_bindings`, depending on whether the bindings meet
// the minimum size requirement.
auto push_dataflow_bindings = [&]() {
if (dataflow_bindings.empty()) {
// No Dataflow bindings, so no action required.
return;
}
if (dataflow_bindings.size() < min_size_ && !input_dataflow_block) {
// The df block is below the minimum length, and no input
// DataflowBlock needs to be preserved. Combine the blocks
// and reset the dataflow collection.
non_dataflow_bindings.insert(non_dataflow_bindings.end(), dataflow_bindings.begin(),
dataflow_bindings.end());
} else {
// A new DataflowBlock can be generated, with bindings that
// occur after the non-dataflow bindings.
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
new_blocks.push_back(DataflowBlock(dataflow_bindings));
non_dataflow_bindings = {};
// Making a dataflow block doesn't imply that the function was
// changed. A change requires that this either be a new
// dataflow block, or have additional dataflow bindings in the
// current block.
changed = changed || !input_dataflow_block.defined() ||
input_dataflow_block.value()->bindings.size() != dataflow_bindings.size();
}
dataflow_bindings = {};
input_dataflow_block = std::nullopt;
};
for (auto block : seq->blocks) {
BindingBlock new_block = this->VisitBindingBlock(block);
changed = changed || !new_block.same_as(block);
// For an existing dataflow block, we add to the current streak
// or start a new streak in case there will be more dataflow operations
// coming up
if (auto dataflow_block = new_block.as<DataflowBlock>()) {
dataflow_bindings.insert(dataflow_bindings.end(), new_block->bindings.begin(),
new_block->bindings.end());
input_dataflow_block = dataflow_block;
continue;
}
// for a binding block, attempt to extract dataflow blocks inside
auto binding_block = Downcast<BindingBlock>(new_block);
for (const auto& binding : binding_block->bindings) {
Expr value = GetBoundValue(binding);
// dataflow values: not an if node and not an impure call
bool is_dataflow = (!value.as<IfNode>()) &&
(!(value.as<CallNode>() && IsImpureCall(Downcast<Call>(value))));
if (is_dataflow) {
// extend the streak
dataflow_bindings.push_back(binding);
} else {
// End the streak, if one currently exists.
push_dataflow_bindings();
non_dataflow_bindings.push_back(binding);
}
}
}
// handle any remaining bindings
push_dataflow_bindings();
new_blocks.push_back(BindingBlock(non_dataflow_bindings));
if (changed) {
return SeqExpr(new_blocks, new_body);
} else {
return ffi::GetRef<SeqExpr>(seq);
}
}
private:
size_t min_size_;
};
Expr ConvertToDataflow(const Expr& input, size_t min_size) {
DataflowBlockExtractor extractor(min_size);
return extractor.VisitExpr(input);
}
namespace transform {
Pass ConvertToDataflow(int min_size) {
auto pass_func = [=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(ConvertToDataflow(f, min_size));
};
auto pass = CreateFunctionPass(pass_func, 0, "ConvertToDataflow", {});
// Canonicalize bindings is included afterwards in order to transform any
// normal vars in DF blocks that are not used outside the DF block into
// dataflow vars. This allows us to avoid reimplementing that functionality.
return tvm::transform::Sequential({pass, CanonicalizeBindings()});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("relax.transform.ConvertToDataflow", ConvertToDataflow);
}
} // namespace transform
} // namespace relax
} // namespace tvm