blob: 5c968966e2f0ba87d5eada2100a912e438a8d196 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/scalable_expression.cc
* \brief Analyze scalable expressions.
*/
#include "scalable_expression.h"
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <vector>
#include "../tir/analysis/check_contains.h"
#include "../tir/transforms/replace_selected_expr.h"
#include "./pattern_match.h"
namespace tvm {
namespace arith {
bool IsVScaleCall(const PrimExpr& expr) {
if (auto call = expr.as<tir::CallNode>()) {
return call->op.same_as(tir::builtin::vscale());
}
return false;
}
bool ContainsVscaleCall(const PrimExpr& expr) {
return tir::CheckContains::ExprContains(expr, IsVScaleCall);
}
PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) {
std::function<bool(const PrimExpr&)> predicate_selector = [](const PrimExpr& current_expr) {
return IsVScaleCall(current_expr);
};
std::function<bool(const PrimExpr&)> can_replace_inside = [](const PrimExpr& current_expr) {
return true;
};
return tir::ReplaceSelectedExpr::ReplaceSelectedExprInExpr(
expr, predicate_selector, tir::MakeConstScalar(DataType::Int(32), vscale_value),
can_replace_inside);
}
std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
PVar<IntImm> multiplier;
PCallExpr<PVscaleOp> vscale;
if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
return multiplier.Eval()->value;
} else {
return std::nullopt;
}
}
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
const std::vector<unsigned int>& vscale_values) {
bool can_prove_expr = true;
for (const unsigned int vscale_value : vscale_values) {
PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value);
result = analyzer->Simplify(result);
const int64_t* as_int = tir::as_const_int(result);
if (!as_int || *as_int == 0) {
can_prove_expr = false;
break;
}
}
return can_prove_expr;
}
bool TargetHasVLA(ffi::Optional<Target> target) {
if (!target.defined()) {
target = Target::Current();
}
bool has_vla{false};
if (target.defined()) {
// aarch64
has_vla = Downcast<Target>(target)->GetFeature<Bool>("has_sve").value_or(Bool(false));
// riscv{32,64}
static auto target_has_feature_fn =
tvm::ffi::Function::GetGlobalRequired("target.target_has_feature");
has_vla |= target_has_feature_fn("v", target).cast<bool>();
}
return has_vla;
}
const std::vector<unsigned int> GetVScaleValues(ffi::Optional<Target> target) {
unsigned int vector_width = 0;
std::vector<unsigned int> kVScaleValues;
if (!target.defined()) {
target = Target::Current();
}
if (target.defined()) {
static auto llvm_get_vector_width_fn =
tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
vector_width = llvm_get_vector_width_fn(target).cast<int>();
}
// scale list with powers of two
for (unsigned int i = 0;; ++i) {
auto power = static_cast<unsigned int>(std::pow(2, i));
if (power > (vector_width / 8)) break;
kVScaleValues.push_back(power);
}
return kVScaleValues;
}
} // namespace arith
} // namespace tvm