blob: d2c5d79a0960623ab5c5195c9863fe077d729e06 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file bound_deducer.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include "ir_visitor_with_analyzer.h"
namespace tvm {
namespace arith {
using namespace tir;
namespace {
using BufferTouches = std::vector<std::vector<IntSet>>;
struct LoadAccess {
BufferTouches set;
};
struct StoreAccess {
BufferTouches set;
};
struct CombinedAccess {
BufferTouches set;
};
using BufferDomainAccess = std::tuple<LoadAccess, StoreAccess, CombinedAccess>;
} // namespace
// Find Read region of the tensor in the stmt.
class BufferTouchedDomain final : public IRVisitorWithAnalyzer {
public:
BufferTouchedDomain(const Stmt& stmt) { operator()(stmt); }
std::unordered_map<const BufferNode*, BufferDomainAccess>& GetAccessedBufferRegions() {
return buffer_access_map_;
}
Region FindUnion(const Buffer& buffer, bool consider_loads, bool consider_stores) {
Region ret;
auto kv = buffer_access_map_.find(buffer.get());
if (kv == buffer_access_map_.end()) {
LOG(WARNING) << "[arith::BufferDomainTouched] "
<< "The requested buffer is not contained in the provided stmt body: " << buffer;
return ret;
}
Range none;
BufferTouches bounds;
if (consider_loads && consider_stores) {
bounds = std::get<CombinedAccess>(kv->second).set;
} else if (consider_loads) {
bounds = std::get<LoadAccess>(kv->second).set;
} else if (consider_stores) {
bounds = std::get<StoreAccess>(kv->second).set;
} else {
CHECK(false) << "Must consider at least on of either loads and stores, but both are false";
}
for (size_t i = 0; i < bounds.size(); ++i) {
ret.push_back(arith::Union(bounds[i]).CoverRange(none));
}
return ret;
}
private:
using Parent = IRVisitorWithAnalyzer;
using Parent::VisitExpr_;
using Parent::VisitStmt_;
void VisitExpr_(const BufferLoadNode* op) final {
// Record load-exclusive buffer access
Touch(&std::get<LoadAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
Parent::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode* op) final {
// Record store-exclusive buffer access
Touch(&std::get<StoreAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
// Record load-store inclusive buffer access
Touch(&std::get<CombinedAccess>(buffer_access_map_[op->buffer.get()]).set, op->indices);
Parent::VisitStmt_(op);
}
private:
void Touch(BufferTouches* bounds, const Array<PrimExpr>& args) {
if (args.size() > bounds->size()) {
bounds->resize(args.size());
}
for (size_t i = 0; i < args.size(); ++i) {
if (args[i].as<RampNode>()) {
(*bounds)[i].emplace_back(IntSet::Vector(args[i]));
} else {
(*bounds)[i].emplace_back(analyzer_.int_set(args[i]));
}
}
}
std::unordered_map<const BufferNode*, BufferDomainAccess> buffer_access_map_;
};
Region DomainTouched(const Stmt& stmt, const Buffer& buffer, bool consider_loads,
bool consider_stores) {
return BufferTouchedDomain(stmt).FindUnion(buffer, consider_loads, consider_stores);
}
Map<Buffer, runtime::ADT> DomainTouchedAccessMap(const PrimFunc& func) {
auto buffer_access_map = BufferTouchedDomain(func->body).GetAccessedBufferRegions();
Map<Buffer, runtime::ADT> ret;
auto& buffer_map = func->buffer_map;
for (auto& var : func->params) {
auto& buffer = buffer_map[var];
auto& access = buffer_access_map[buffer.get()];
Array<Array<IntSet>> loads, stores, combined;
for (std::vector<IntSet>& touch : std::get<LoadAccess>(access).set) {
loads.push_back(Array<IntSet>(touch));
}
for (std::vector<IntSet>& touch : std::get<StoreAccess>(access).set) {
stores.push_back(Array<IntSet>(touch));
}
for (std::vector<IntSet>& touch : std::get<CombinedAccess>(access).set) {
combined.push_back(Array<IntSet>(touch));
}
std::vector<ObjectRef> fields;
fields.push_back(loads);
fields.push_back(stores);
fields.push_back(combined);
ret.Set(buffer, runtime::ADT::Tuple(fields));
}
return ret;
}
TVM_REGISTER_GLOBAL("arith.DomainTouched").set_body_typed(DomainTouched);
TVM_REGISTER_GLOBAL("arith.DomainTouchedAccessMap").set_body_typed(DomainTouchedAccessMap);
} // namespace arith
} // namespace tvm