// blob: 5e76af12e5834dfc38f7e9ef8b758ac4c9031d96 [file] [log] [blame]
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_common.h
* \brief Common utility for codegen.
*/
#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_
#define TVM_CODEGEN_CODEGEN_COMMON_H_
#include <tvm/arithmetic.h>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
/*!
 * \brief Visit the body of an AssertStmt, exposing the alignment fact implied
 *  by the assert condition while the body is being visited.
 *
 *  When the condition has the form `Var % factor == offset`, with both factor
 *  and offset constant, and factor is larger than the coefficient currently
 *  recorded for Var, the entry of Var in align_map is replaced by
 *  (coeff = factor, base = offset) for the duration of the body visit and is
 *  put back afterwards. In every other case the body is visited without
 *  installing new alignment information.
 *
 *  Constants that cannot be represented as int are ignored, because the entry
 *  is filled through a narrowing conversion to int and a truncated value
 *  would describe an alignment the assert never established.
 *
 * \param op The AssertStmt
 * \param align_map The alignment map keyed by variable; it is dereferenced
 *  unconditionally, so it must not be null.
 * \param fvisit The recursive visitor, invoked exactly once on op->body.
 * \tparam FVisit The type of the recursive visitor.
 */
template<typename FVisit>
inline void VisitAssert(
    const ir::AssertStmt* op,
    std::unordered_map<const Variable*, arith::ModularEntry>* align_map,
    FVisit fvisit) {
  using namespace ir;
  auto& align_map_ = *align_map;
  // Detect useful invariant pattern and use it to visit the child.
  // Pattern: Var % const == const
  // TODO(tqchen) merge these patterns into a generic scope info visitor.
  if (const EQ* eq = op->condition.as<EQ>()) {
    const Mod* mod = eq->a.as<Mod>();
    int64_t factor = 0, offset = 0;
    if (mod && arith::GetConst(eq->b, &offset)) {
      const Variable* var = mod->a.as<Variable>();
      if (var && arith::GetConst(mod->b, &factor)) {
        // Narrow once and verify the round trip: a constant that does not
        // survive the conversion to int must not be recorded.
        const int coeff = static_cast<int>(factor);
        const int base = static_cast<int>(offset);
        const bool fits_in_int = (coeff == factor) && (base == offset);
        // operator[] default-constructs an entry when var is not tracked yet.
        arith::ModularEntry old = align_map_[var];
        if (fits_in_int && coeff > old.coeff) {
          arith::ModularEntry e;
          e.coeff = coeff;
          e.base = base;
          // new alignment info, visible only while the body is visited
          align_map_[var] = e;
          fvisit(op->body);
          // restore old info
          align_map_[var] = old;
          return;
        }
      }
    }
  }
  // No usable pattern: visit the body with the current alignment info.
  fvisit(op->body);
}
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_COMMON_H_