blob: 3c35b73bc8d7c02bb6108d72be397c619accc3d3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file texture_flatten.cc
* \brief Flattens texture storage from multi-dimensional array
* to 2D (width, height) buffer access
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/texture.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace tir {
using arith::IRVisitorWithAnalyzer;
using runtime::ApplyTexture2DFlattening;
using runtime::DefaultTextureLayoutSeparator;
using runtime::IsTextureStorage;
// Shared base for the texture lowering mutators in this file. It remembers
// which buffers were supplied through the function's buffer map and offers
// helpers for offset computation and storage-scope lookup.
class TextureLoweringBase : public StmtExprMutator {
 public:
  // Records every buffer found in extern_buffer_map and stores the analyzer
  // pointer that SimplifyOffset uses later on.
  explicit TextureLoweringBase(const Map<Var, Buffer>& extern_buffer_map,
                               IRVisitorWithAnalyzer* bound_analyzer)
      : bound_analyzer_{bound_analyzer} {
    for (const auto& entry : extern_buffer_map) {
      extern_buf_.insert(entry.second);
    }
  }

  // Folds a multi-dimensional index into a single offset:
  //   ((index[0] * shape[1] + index[1]) * shape[2] + index[2]) ...
  // simplifying after every step. shape and index must have equal length;
  // an empty index produces the int32 constant 0.
  inline PrimExpr SimplifyOffset(const Array<PrimExpr>& shape, const Array<PrimExpr>& index) const {
    ICHECK_EQ(shape.size(), index.size());
    PrimExpr zero = make_const(DataType::Int(32), 0);
    if (index.size() == 0) {
      return zero;
    }
    PrimExpr flat = index[0];
    for (size_t dim = 1; dim < index.size(); ++dim) {
      flat = bound_analyzer_->Simplify(flat * shape[dim] + index[dim]);
    }
    return zero + flat;
  }

 protected:
  // Returns the storage scope recorded on the pointer type annotation of the
  // buffer's data var; fails if the annotation is not a PointerType.
  std::string GetStorageScope(const Buffer& buffer) {
    const auto* pointer_type = buffer->data->type_annotation.as<PointerTypeNode>();
    ICHECK(pointer_type) << "Buffer Var's type annotation must be of PointerType";
    return pointer_type->storage_scope;
  }

  // Set of all external input and output buffers
  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buf_;
  // Bound analyzer
  IRVisitorWithAnalyzer* bound_analyzer_;
};
// Lowers N-d storage accesses to 2d texture accesses using the lowering
// convention specified by the buffer's storage scope.
class TextureFlattener : public TextureLoweringBase {
 public:
  using StmtExprMutator::VisitStmt_;

  explicit TextureFlattener(const Map<Var, Buffer>& extern_buffer_map,
                            IRVisitorWithAnalyzer* bound_analyzer)
      : TextureLoweringBase(extern_buffer_map, bound_analyzer) {}

  // Rewrites a buffer realization.
  //  - Realizations of external buffers are dropped; only their visited body
  //    is returned.
  //  - Realizations whose storage scope is a texture scope become a LetStmt
  //    that binds a fresh pointer var to a 2d nd_mem_alloc_with_scope call.
  //  - Any other realization is returned as produced by the base mutator.
  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    if (extern_buf_.count(op->buffer)) {
      return this->VisitStmt(op->body);
    }

    std::string storage_scope = GetStorageScope(op->buffer);
    Var buffer_var(op->buffer->data->name_hint,
                   PointerType(PrimType(op->buffer->dtype), String(storage_scope)));
    // Record the binding before the traversal below so that accesses rewritten
    // by GetTextureAccessArgs can refer to the new var.
    let_binding_.insert({op->buffer->data, buffer_var});

    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferRealizeNode>();

    // Rewrite any buffer realizations with storage scope to 2d texture allocations
    if (IsTextureStorage(storage_scope)) {
      // NOTE(review): the body appears to have been visited already by the
      // base-class call above; this second traversal is kept as-is - confirm
      // whether it is still required.
      Stmt body = this->VisitStmt(op->body);
      ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA texture is currently supported";

      // The innermost extent is read as an integer immediate. Check the cast
      // result instead of dereferencing it blindly, so a non-constant extent
      // produces a diagnostic rather than a null-pointer dereference.
      const auto* inner_extent = op->bounds.back()->extent.as<IntImmNode>();
      ICHECK(inner_extent != nullptr)
          << "Inner dimension of texture must have a constant integer extent, but got "
          << op->bounds.back()->extent;
      int vec_length = static_cast<int>(inner_extent->value);
      ICHECK(vec_length == 4 || vec_length == 1)
          << "Inner dimension of texture must be vector of length 1 or 4 (RGBA)";

      // Adapter exposing the extents of the realize bounds through operator[].
      struct ShapeFromRange {
        const Array<Range>& bounds;
        PrimExpr operator[](size_t i) const { return bounds[i]->extent; }
      };
      size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope);
      auto texture =
          ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds}, op->bounds.size(), axis);

      Array<PrimExpr> args;
      args.push_back(StringImm(storage_scope));
      args.push_back(IntImm(DataType::Int(64), 2));  // 2d
      args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(),
                          {texture.width, texture.height}));
      stmt = LetStmt(buffer_var, Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args),
                     body);
    }
    return stmt;
  }

  // Rewrites a store into a texture-scoped buffer as an Evaluate of a
  // texture2d_store call; stores to other scopes are left as produced by the
  // base mutator.
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<BufferStoreNode>();
    std::string storage_scope = GetStorageScope(op->buffer);
    // Lower to two dimensional access
    if (IsTextureStorage(storage_scope)) {
      Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer);
      args.push_back(op->value);
      stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args));
    }
    return stmt;
  }

  // Rewrites a load from a texture-scoped buffer as a texture2d_load call whose
  // last argument is the innermost index; loads from other scopes are left as
  // produced by the base mutator.
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<BufferLoadNode>();
    // Lower to two dimensional access
    std::string storage_scope = GetStorageScope(op->buffer);
    if (IsTextureStorage(storage_scope)) {
      Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer);
      args.push_back(op->indices.back());
      expr = Call(op->buffer->dtype, builtin::texture2d_load(), args);
    }
    return expr;
  }

 protected:
  // Builds the leading arguments shared by texture2d_load / texture2d_store:
  //   [texture var, row offset, column offset].
  // The texture var is the rebound var from let_binding_ when one exists,
  // otherwise the buffer's own data var. All dimensions except the innermost
  // one are split at the layout separator: those before it feed the column
  // offset, the remaining ones feed the row offset.
  template <typename T>
  Array<PrimExpr> GetTextureAccessArgs(const T* op, const Buffer& buffer) {
    Array<PrimExpr> args;
    auto binding = let_binding_.find(op->buffer->data);
    if (binding != let_binding_.end()) {
      args.push_back(binding->second);
    } else {
      args.push_back(buffer->data);
    }

    const size_t rank = op->buffer->shape.size();
    // Guards the `rank - 1` below against unsigned wrap-around.
    ICHECK_GE(rank, 1U) << "Texture access requires a buffer with at least one dimension";
    const size_t num_outer_dims = rank - 1;
    // The separator does not depend on the loop variable, so compute it once.
    // It is only evaluated when the loop below runs at least one iteration.
    const size_t separator =
        num_outer_dims > 0 ? DefaultTextureLayoutSeparator(rank, GetStorageScope(buffer)) : 0;

    Array<PrimExpr> row_dims, row_indices, col_dims, col_indices;
    for (size_t i = 0; i < num_outer_dims; i++) {
      if (i < separator) {
        col_dims.push_back(op->buffer->shape[i]);
        col_indices.push_back(op->indices[i]);
      } else {
        row_dims.push_back(op->buffer->shape[i]);
        row_indices.push_back(op->indices[i]);
      }
    }
    PrimExpr row_offset = SimplifyOffset(row_dims, row_indices);
    PrimExpr col_offset = SimplifyOffset(col_dims, col_indices);
    args.push_back(row_offset);
    args.push_back(col_offset);
    return args;
  }

  // Bindings to new texture vars with texture pointer scope
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
};
// Applies TextureFlattener to the body of func and returns the updated
// function. The analyzer is run over the original body first and is then
// handed to the mutator, which uses it when simplifying offsets.
PrimFunc TextureFlatten(PrimFunc func) {
  auto* node = func.CopyOnWrite();
  IRVisitorWithAnalyzer analyzer;
  analyzer(node->body);
  TextureFlattener flattener(node->buffer_map, &analyzer);
  node->body = flattener(std::move(node->body));
  return func;
}
namespace transform {
// Creates the "tir.TextureFlatten" PrimFunc pass; the pass body forwards each
// function to TextureFlatten(PrimFunc) and ignores the module and context.
Pass TextureFlatten() {
  auto flatten_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    return TextureFlatten(std::move(f));
  };
  return CreatePrimFuncPass(flatten_func, 0, "tir.TextureFlatten", {});
}
// Registers the pass factory transform::TextureFlatten under the global name
// "tir.transform.TextureFlatten".
TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten").set_body_typed(TextureFlatten);
} // namespace transform
} // namespace tir
} // namespace tvm