blob: 5b2b5704c5c9dcd329a285587253b94ce4bdffe2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file remove_weight_layout_rewrite_block.cc
* \brief Remove weight layout rewrite block before benchmark
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
namespace tvm {
namespace tir {
/*!
 * \brief Mutator that neutralizes weight layout rewrite blocks.
 *
 * A block whose `attr::meta_schedule_layout_rewrite_preproc` annotation is one is expected to
 * have a body of the form `store_buffer[...] = load_buffer[vars...]`. The body of such a block
 * is replaced with `Evaluate(0)`, its read / write regions are cleared, and the relation between
 * the two buffers is recorded. Every other block has the allocations of the recorded store
 * buffers removed from its `alloc_buffers`.
 */
class RemoveLayoutRewriteBlock : public StmtMutator {
 public:
  /*!
   * \brief Strip the layout rewrite blocks from the body of a PrimFunc.
   * \param f The function to process.
   * \return A tuple holding, in order:
   *  - the function with the mutated body,
   *  - the map from each loaded (original layout) buffer to the stored (rewritten) buffer,
   *  - the map from the data var of each loaded buffer to the IndexMap built from the load
   *    indices and the store indices,
   *  - the map from the data var of each loaded buffer to the shape of the stored buffer.
   */
  static std::tuple<PrimFunc, ffi::Map<Buffer, Buffer>,
                    std::unordered_map<const VarNode*, IndexMap>,
                    std::unordered_map<const VarNode*, ffi::Array<PrimExpr>>>
  Rewrite(PrimFunc f) {
    RemoveLayoutRewriteBlock rewriter;
    PrimFuncNode* n = f.CopyOnWrite();
    n->body = rewriter(std::move(n->body));
    // The rewriter dies at the end of this call, so its state is handed over rather than copied.
    return std::make_tuple(std::move(f), std::move(rewriter.buf_map_),
                           std::move(rewriter.buffer_var_to_index_map_),
                           std::move(rewriter.buffer_var_to_rewritten_shape_));
  }

 private:
  Stmt VisitStmt_(const BlockNode* op) final {
    // Children are mutated first, so buffers recorded by nested layout rewrite blocks are
    // already known when the allocations of this block are filtered below.
    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));

    auto it = block->annotations.find(attr::meta_schedule_layout_rewrite_preproc);
    if (it == block->annotations.end() || !is_one(Downcast<PrimExpr>((*it).second))) {
      // The block is not a weight layout block: only drop allocations of rewritten buffers.
      ffi::Array<Buffer> alloc_buffers;
      for (const Buffer& buffer : block->alloc_buffers) {
        if (!rewritten_buffers_.count(buffer)) {
          alloc_buffers.push_back(buffer);
        }
      }
      if (alloc_buffers.size() < block->alloc_buffers.size()) {
        auto n = CopyOnWrite(block.get());
        n->alloc_buffers = std::move(alloc_buffers);
        return Stmt(n);
      }
      return block;
    }

    // Step 0. Checking block attrs
    ICHECK(block->alloc_buffers.empty()) << "A layout rewrite block must not allocate buffers";
    ICHECK(block->match_buffers.empty()) << "A layout rewrite block must not match buffers";
    // Step 1. Checking the body is a BufferStore
    const auto* store = block->body.as<BufferStoreNode>();
    ICHECK(store) << "The body of a layout rewrite block must be a BufferStore";
    // Step 2. Checking the rhs of buffer store is a BufferLoad
    const auto* load = store->value.as<BufferLoadNode>();
    ICHECK(load) << "The value stored by a layout rewrite block must be a BufferLoad";
    // Step 3. Update Buffer
    buf_map_.Set(load->buffer, store->buffer);
    rewritten_buffers_.insert(store->buffer);
    // Step 4. Set block body as no_op
    auto n = CopyOnWrite(block.get());
    n->body = Evaluate(0);
    n->reads = {};
    n->writes = {};
    // Step 5. Record the index map (load indices -> store indices) and the rewritten shape,
    // both keyed by the data var of the loaded buffer.
    ffi::Array<Var> load_indices;
    for (const PrimExpr& ind : load->indices) {
      ICHECK(ind->IsInstance<VarNode>())
          << "The load indices of a layout rewrite block must be plain variables";
      load_indices.push_back(Downcast<Var>(ind));
    }
    const VarNode* load_var = load->buffer->data.get();
    buffer_var_to_index_map_[load_var] = IndexMap(load_indices, store->indices);
    buffer_var_to_rewritten_shape_[load_var] = store->buffer->shape;
    return Stmt(n);
  }

 private:
  /*! \brief The buffer map from original layout buffer to rewritten buffer */
  ffi::Map<Buffer, Buffer> buf_map_;
  /*! \brief The buffers stored to by layout rewrite blocks; their allocations are removed
      from the blocks that are mutated afterwards. */
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> rewritten_buffers_;
  /*! \brief Maps a buffer load to an index map associated with the load / store
      in a layout rewrite block. */
  std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
  /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
  std::unordered_map<const VarNode*, ffi::Array<PrimExpr>> buffer_var_to_rewritten_shape_;
};
// After RemoveLayoutRewriteBlock, the body of a compute update block references a
// non-existent buffer. For example, fused_constant_2_global below is originally a
// cache_read buffer, whose allocation is removed by RemoveLayoutRewriteBlock:
//
// constant fused_constant_2[float32 * 3 * 3 * 64 * 64]
// conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2_global[ry,
// floordiv(rc, 32),
// floordiv(ff, 16),
// rx,
// floormod(rc, 32),
// floormod(ff, 16)]))
//
// When cache_read is reading from AllocateConstant, we need to replace the reference
// to fused_constant_2_global with the corresponding transformed AllocateConstant.
// To do that, we manually rewrite the original constant using the associated index map,
// and let the body of the compute block load from the rewritten constant.
//
// After this transformation, the example above looks like:
//
// constant fused_constant_2[float32 * 3 * 2 * 4 * 3 * 32 * 16]
// conv2d_nhwc[nn, yy, xx, ff] += ... * fused_constant_2[ry,
// floordiv(rc, 32),
// floordiv(ff, 16),
// rx,
// floormod(rc, 32),
// floormod(ff, 16)]))
/*! \brief Mapping from one buffer data variable to another, both held as raw VarNode pointers. */
using BufferVarMap = std::unordered_map<const tir::VarNode*, const tir::VarNode*>;
/*!
 * \brief Mutator that rewrites AllocateConst nodes with a recorded index map and redirects
 *        buffer loads to a different data variable.
 *
 * - An AllocateConst whose buffer var has an entry in the index-map table is re-created with a
 *   rewritten tensor and the extents of that tensor.
 * - A BufferLoad whose buffer data var has an entry in the buffer-var map is replaced by a load
 *   from a new buffer that keeps every property of the old buffer except the data var and name.
 * - The read regions of each block are updated to refer to the newly created buffers.
 */
class AllocateConstRewrite : public StmtExprMutator {
 public:
  /*!
   * \brief Constructor.
   * \param buffer_var_map Maps the data var of a loaded buffer to the data var that should be
   *        loaded from instead.
   * \param buffer_var_to_index_map Maps the buffer var of an AllocateConst to the index map
   *        used to rewrite its tensor.
   * \param buffer_var_to_rewritten_shape Maps the buffer var of an AllocateConst to the shape
   *        of the rewritten tensor.
   * \param skip_tensor_rewrite When true, the tensor is only viewed with the rewritten shape
   *        and its contents are not passed through the index map.
   */
  AllocateConstRewrite(
      const BufferVarMap& buffer_var_map,
      const std::unordered_map<const VarNode*, IndexMap>& buffer_var_to_index_map,
      const std::unordered_map<const VarNode*, ffi::Array<PrimExpr>>& buffer_var_to_rewritten_shape,
      bool skip_tensor_rewrite)
      : buffer_var_map_(buffer_var_map),
        buffer_var_to_index_map_(buffer_var_to_index_map),
        buffer_var_to_rewritten_shape_(buffer_var_to_rewritten_shape),
        skip_tensor_rewrite_(skip_tensor_rewrite) {}

 private:
  Stmt VisitStmt_(const BlockNode* op) final {
    // The body is mutated first so that new_load_buf_ is filled by the loads inside the block.
    Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
    auto n = CopyOnWrite(block.get());
    ffi::Array<BufferRegion> new_reads;
    for (const BufferRegion& read_region : op->reads) {
      auto it = new_load_buf_.find(read_region->buffer->data.get());
      if (it != new_load_buf_.end()) {
        // Same region, but on the buffer created for the redirected load.
        new_reads.push_back(BufferRegion(it->second, read_region->region));
      } else {
        new_reads.push_back(read_region);
      }
    }
    n->reads = std::move(new_reads);
    return Stmt(n);
  }

  Stmt VisitStmt_(const AllocateConstNode* alloc) final {
    const VarNode* buffer_var = alloc->buffer_var.get();
    auto index_map_it = buffer_var_to_index_map_.find(buffer_var);
    if (index_map_it == buffer_var_to_index_map_.end()) {
      // This constant is not the source of a layout rewrite block.
      return StmtMutator::VisitStmt_(alloc);
    }
    auto shape_it = buffer_var_to_rewritten_shape_.find(buffer_var);
    ICHECK(shape_it != buffer_var_to_rewritten_shape_.end())
        << "No rewritten shape was recorded for a constant that has an index map";
    ICHECK(alloc->data.has_value())
        << "Cannot rewrite the layout of an AllocateConst that carries no tensor data";

    auto new_body = StmtMutator::VisitStmt(alloc->body);
    auto rewritten_tensor =
        RewriteTensor(alloc->data.value(), index_map_it->second, shape_it->second);
    ffi::Array<PrimExpr> rewritten_extents;
    for (auto s : rewritten_tensor.Shape()) {
      // NOTE(review): each dimension is narrowed to int here; presumably tensor dimensions
      // always fit in that type - confirm.
      rewritten_extents.push_back(PrimExpr(static_cast<int>(s)));
    }
    return AllocateConst(alloc->buffer_var, alloc->dtype, rewritten_extents, rewritten_tensor,
                         new_body, alloc->annotations, alloc->span);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto it = buffer_var_map_.find(op->buffer->data.get());
    if (it == buffer_var_map_.end()) {
      return ExprMutator::VisitExpr_(op);
    }
    // Clone the buffer with the mapped data var and its name; everything else is kept.
    auto new_buffer =
        Buffer(ffi::GetRef<Var>(it->second), op->buffer->dtype, op->buffer->shape,
               op->buffer->strides, op->buffer->elem_offset, it->second->name_hint,
               op->buffer->data_alignment, op->buffer->offset_factor, op->buffer->buffer_type);
    // Remembered so that the enclosing block can update its read regions.
    new_load_buf_[op->buffer->data.get()] = new_buffer;
    // The indices and the predicate are reused as they are.
    return BufferLoad(new_buffer, op->indices, op->predicate);
  }

  /*!
   * \brief Produce the tensor of a rewritten constant.
   * \param src The tensor of the original constant.
   * \param index_map The index map to apply to the tensor.
   * \param dst_shape The shape of the rewritten constant; used only when the rewrite of the
   *        tensor contents is skipped, in which case every element must be an IntImm.
   * \return A view of `src` with shape `dst_shape` when skipping, otherwise the result of
   *         `index_map->MapTensor(src)`.
   */
  runtime::Tensor RewriteTensor(runtime::Tensor src, const IndexMap& index_map,
                                const ffi::Array<PrimExpr>& dst_shape) {
    if (!skip_tensor_rewrite_) {
      return index_map->MapTensor(src);
    }
    // Only the shape of the destination array needs to be correct.
    std::vector<int64_t> dst_shape_int;
    dst_shape_int.reserve(dst_shape.size());
    for (const PrimExpr& s : dst_shape) {
      const auto* imm = s.as<IntImmNode>();
      ICHECK(imm) << "The rewritten shape of a constant must consist of integer constants";
      dst_shape_int.push_back(imm->value);
    }
    return src.CreateView(dst_shape_int, src.DataType());
  }

  /*! \brief Maps a buffer store to a load in a layout rewrite block */
  BufferVarMap buffer_var_map_;
  /*! \brief Maps a buffer load to an index map associated with the load / store
      in a layout rewrite block. */
  std::unordered_map<const VarNode*, IndexMap> buffer_var_to_index_map_;
  /*! \brief Maps a buffer load to the shape of the corresponding rewritten buffer. */
  std::unordered_map<const VarNode*, ffi::Array<PrimExpr>> buffer_var_to_rewritten_shape_;
  /*! \brief Maps load buffer variables to newly created buffers */
  std::unordered_map<const VarNode*, Buffer> new_load_buf_;
  /*! \brief Whether or not to skip rewriting of Tensor contents */
  bool skip_tensor_rewrite_;
};
class CollectAllocateConstBufferVars : public StmtVisitor {
public:
void VisitStmt_(const AllocateConstNode* alloc) final {
StmtVisitor::VisitStmt_(alloc);
constant_buf_var.insert(alloc->buffer_var.get());
}
std::unordered_set<const VarNode*> constant_buf_var;
};
/*! \brief Entry point that chains the individual rewrites of this file on a PrimFunc. */
class WeightLayoutRewriteBlockRemover : public StmtMutator {
 public:
  /*!
   * \brief Remove the weight layout rewrite blocks of a function.
   * \param f The function to process.
   * \param skip_tensor_rewrite Forwarded to AllocateConstRewrite; when true the contents of
   *        rewritten constants are not remapped.
   * \return The function with the layout rewrite blocks neutralized, the affected constants
   *         rewritten, and the buffer map pointing at the rewritten buffers.
   */
  static PrimFunc Remove(PrimFunc f, bool skip_tensor_rewrite) {
    // Collected on the original body, before any rewriting takes place.
    CollectAllocateConstBufferVars collector;
    collector(f->body);

    // `f` is not needed afterwards; moving it avoids keeping an extra reference alive, which
    // would force the copy-on-write inside Rewrite to duplicate the function node.
    auto [f_, buf_map, buffer_var_to_index_map, buffer_var_to_rewritten_shape] =
        RemoveLayoutRewriteBlock::Rewrite(std::move(f));

    // Loads are redirected only when the source of the layout rewrite is a constant:
    // data var of the rewritten buffer -> data var of the constant.
    BufferVarMap buffer_var_map;
    for (const auto& [load_buf, store_buf] : buf_map) {
      if (collector.constant_buf_var.count(load_buf->data.get())) {
        buffer_var_map[store_buf->data.get()] = load_buf->data.get();
      }
    }

    PrimFuncNode* n = f_.CopyOnWrite();

    AllocateConstRewrite rewriter(buffer_var_map, buffer_var_to_index_map,
                                  buffer_var_to_rewritten_shape, skip_tensor_rewrite);
    n->body = rewriter(std::move(n->body));

    // Parameters bound to a buffer in its original layout are rebound to the rewritten buffer.
    ffi::Map<tir::Var, Buffer> buffer_map;
    for (const auto& [param, buffer] : f_->buffer_map) {
      auto it = buf_map.find(buffer);
      if (it != buf_map.end()) {
        buffer_map.Set(param, (*it).second);
      } else {
        buffer_map.Set(param, buffer);
      }
    }
    n->buffer_map = std::move(buffer_map);
    return f_;
  }
};
namespace transform {
/*!
 * \brief Create the pass that removes weight layout rewrite blocks from each PrimFunc.
 * \param skip_tensor_rewrite Forwarded to WeightLayoutRewriteBlockRemover::Remove.
 * \return The pass.
 */
Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite) {
  // The module and the pass context are accepted but not consulted by the rewrite.
  auto remove_blocks = [skip_tensor_rewrite](PrimFunc func, IRModule mod, PassContext pass_ctx) {
    return WeightLayoutRewriteBlockRemover::Remove(std::move(func), skip_tensor_rewrite);
  };
  return CreatePrimFuncPass(remove_blocks, 0, "tir.RemoveWeightLayoutRewriteBlock", {});
}
// Defines the global "tir.transform.RemoveWeightLayoutRewriteBlock", bound to the pass
// constructor declared in this namespace.
TVM_FFI_STATIC_INIT_BLOCK() {
  tvm::ffi::reflection::GlobalDef().def("tir.transform.RemoveWeightLayoutRewriteBlock",
                                        RemoveWeightLayoutRewriteBlock);
}
} // namespace transform
} // namespace tir
} // namespace tvm