blob: 5d737d7640d163007c0296bf40a02c448e4a1145 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file s_tir/transform/canonicalize_loop.cc
* \brief Canonicalize all loops to start from zero and step one.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <utility>
namespace tvm {
namespace s_tir {
using namespace tvm::tir;
/*!
 * \brief Rewrites loops into canonical form: start at zero, step one.
 *
 * A loop `for (x = min; x < min + extent; x += step)` becomes
 * `for (x = 0; x < ceildiv(extent, step); ++x)`, and every use of `x`
 * inside the body is replaced with `x * step + min`.
 *
 * Loops that already start at zero with a trivial step are not rewritten,
 * but their bounds and body are still visited so that references to
 * enclosing (canonicalized) loop variables get substituted.
 */
class LoopCanonicalizer : public StmtExprMutator {
 public:
  LoopCanonicalizer() = default;

 private:
  Stmt VisitStmt_(const ForNode* op) final {
    if (is_zero(op->min) && op->HasTrivialStep()) {
      // Already canonical; the default visitor still rewrites min/extent/body.
      return StmtExprMutator::VisitStmt_(op);
    }
    const VarNode* loop_var = op->loop_var.get();
    // The bounds may reference variables of enclosing loops that have already
    // been canonicalized, so they must be mutated before use. In particular,
    // `min` is stored as the substitution offset and is inserted verbatim
    // into the body (VisitExpr_ below does not re-visit it).
    PrimExpr min = VisitExpr(op->min);
    PrimExpr extent = VisitExpr(op->extent);
    PrimExpr step = VisitExpr(op->step.value_or(make_const(loop_var->dtype, 1)));
    // A non-positive step would make the original loop non-terminating and the
    // canonical extent meaningless, so reject any step we cannot prove >= 1.
    if (!analyzer_.CanProveGreaterEqual(step, 1)) {
      // TODO(tvm): prove dynamic shaped step
      LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step;
    }
    new_iter_info_[loop_var] = std::make_pair(step, min);
    auto n = CopyOnWrite(op);
    n->body = VisitStmt(op->body);
    n->min = make_zero(loop_var->dtype);
    // Number of iterations of the original strided loop.
    n->extent = analyzer_.Simplify(ceildiv(extent, step));
    n->step = std::nullopt;
    // The substitution only applies inside this loop's body.
    new_iter_info_.erase(loop_var);
    return For(n);
  }

  /*! \brief Replace uses of a canonicalized loop variable `x` with `x * stride + offset`. */
  PrimExpr VisitExpr_(const VarNode* op) final {
    auto it = new_iter_info_.find(op);
    if (it != new_iter_info_.end()) {
      const auto& [stride, offset] = it->second;
      // `stride` and `offset` were already mutated when recorded.
      return ffi::GetRef<Var>(op) * stride + offset;
    }
    return ffi::GetRef<Var>(op);
  }

  arith::Analyzer analyzer_;
  /*! \brief Map iter variable `x` to `x * stride + offset`. */
  std::unordered_map<const VarNode*, std::pair<PrimExpr, PrimExpr>> new_iter_info_;
};
namespace transform {
/*!
 * \brief Build the pass that canonicalizes every loop of each PrimFunc so it
 *        starts at zero and steps by one.
 * \return A PrimFunc-level pass registered as "s_tir.CanonicalizeLoop".
 */
Pass CanonicalizeLoop() {
  // Only the function body is rewritten; the module and pass context are unused.
  auto canonicalize = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) -> PrimFunc {
    PrimFuncNode* node = func.CopyOnWrite();
    node->body = LoopCanonicalizer()(std::move(node->body));
    return func;
  };
  return CreatePrimFuncPass(canonicalize, /*opt_level=*/0, "s_tir.CanonicalizeLoop",
                            /*required=*/{});
}
TVM_FFI_STATIC_INIT_BLOCK() {
  // Register the pass constructor as a global FFI function.
  tvm::ffi::reflection::GlobalDef().def("s_tir.transform.CanonicalizeLoop", CanonicalizeLoop);
}
} // namespace transform
} // namespace s_tir
} // namespace tvm