blob: 6d94a08cad5d2869c7a59c1894b7e0a515e63d3b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tir/op/op.cc
*
* Common operator definitions for ops in tir/op.h
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <cmath>
// Centralized header for constant folders.
#include "../../arith/const_fold.h"
#include "../../target/datatype/registry.h"
namespace tvm {
using namespace tir;
// Macro to register a unary op.
// Expands to a TVM_REGISTER_OP chain that sets the number of inputs to 1 and
// attaches the "TCallEffectKind" attribute with value CallEffectKind::kPure.
// The expansion is an expression, so callers may chain further .set_attr calls.
#define TIR_REGISTER_PURE_UNARY_OP(OpName) \
TVM_REGISTER_OP(OpName).set_num_inputs(1).set_attr<TCallEffectKind>( \
"TCallEffectKind", Integer(CallEffectKind::kPure))
// Macro to register a binary op.
// Expands to a TVM_REGISTER_OP chain that sets the number of inputs to 2 and
// attaches the "TCallEffectKind" attribute with value CallEffectKind::kPure.
// The expansion is an expression, so callers may chain further .set_attr calls.
#define TIR_REGISTER_PURE_BINARY_OP(OpName) \
TVM_REGISTER_OP(OpName).set_num_inputs(2).set_attr<TCallEffectKind>( \
"TCallEffectKind", Integer(CallEffectKind::kPure))
// Translate an IR Type into the runtime::DataType that represents it:
//  - a PrimType yields its stored dtype,
//  - a PointerType yields DataType::Handle(),
//  - the void type yields DataType::Void().
// Any other type is reported through LOG(FATAL).
runtime::DataType GetRuntimeDataType(const Type& type) {
  const auto* prim = type.as<PrimTypeNode>();
  if (prim != nullptr) {
    return prim->dtype;
  }
  if (type.as<PointerTypeNode>()) {
    return DataType::Handle();
  }
  if (IsVoidType(type)) {
    return DataType::Void();
  }
  LOG(FATAL) << "Type " << type << " does not have a corresponding runtime::DataType";
  return DataType::Handle();
}
// Compute the Type of a PrimExpr.
// A Var that carries a defined type annotation reports that annotation;
// everything else is described by its dtype (VoidType for a void dtype,
// PrimType otherwise).
Type GetType(const PrimExpr& expr) {
  // TODO(tqchen): add recursive type inference for Call here
  // once we introduced the corresponding fields to the IR.
  const tir::VarNode* var = expr.as<tir::VarNode>();
  if (var != nullptr && var->type_annotation.defined()) {
    // The annotation is more refined than the plain dtype, so prefer it.
    return var->type_annotation;
  }
  // Fallback: derive the type from the dtype alone.
  const runtime::DataType dtype = expr.dtype();
  if (!dtype.is_void()) {
    return PrimType(dtype);
  }
  return VoidType();
}
// Minimal cast helper: wraps `value` in a tir::Cast node unless it already
// has dtype `t`, in which case it is returned untouched. No constant folding
// or lane handling is performed here.
inline PrimExpr SimpleCast(const DataType& t, PrimExpr value) {
  if (value.dtype() != t) {
    return tir::Cast(t, value);
  }
  return value;
}
// LargeUIntImm
// Builds a call to the large_uint_imm builtin with result type `t`.
// `low` and `high` are each turned into a UInt(32) constant and passed as the
// first and second call argument respectively.
PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high) {
  const DataType half_type = DataType::UInt(32);
  const PrimExpr low_half = make_const(half_type, low);
  const PrimExpr high_half = make_const(half_type, high);
  return tir::Call(t, tir::builtin::large_uint_imm(), {low_half, high_half});
}
// Q-multiplication
// Builds a call to the q_multiply_shift builtin with arguments (x, y, q, s).
// The result type is Int(32) with the same number of lanes as `x`.
PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s) {
  const DataType result_type = DataType::Int(32, x.dtype().lanes());
  return tir::Call(result_type, tir::builtin::q_multiply_shift(), {x, y, q, s});
}
// Public entry point for binary-operand type matching, with a fast path for
// operands whose dtypes already agree.
//
// Both operands are updated in place:
//  1. A scalar operand is broadcast to the lane count of a vector operand;
//     two operands with different lane counts (neither scalar) fail a CHECK.
//  2. If the dtypes still differ, only simple conversions are applied:
//     non-float -> float (or custom registered type), narrower int -> wider
//     int of the same signedness, and mixed int/uint -> signed int of the
//     larger bit width. Anything else is reported through LOG(FATAL).
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs) {  // NOLINT(*)
  if (lhs.dtype() == rhs.dtype()) return;

  // Keep the incoming dtypes; they are used for lane matching and diagnostics.
  DataType ltype = lhs.dtype();
  DataType rtype = rhs.dtype();

  // Step 1: make the lane counts agree.
  if (ltype.lanes() == 1 && rtype.lanes() != 1) {
    lhs = tir::Broadcast(lhs, rtype.lanes());
  } else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
    rhs = tir::Broadcast(rhs, ltype.lanes());
  } else {
    CHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype << " vs " << rtype;
  }
  if (lhs.dtype() == rhs.dtype()) return;

  // Step 2: only do very simple type conversion
  //   int -> float, Int(32) -> Int(64)
  // The types are required to be relatively consistent. This reduces the
  // amount of code generated by operators and also helps users find
  // potential type conversion problems.
  auto is_custom = [](const DataType& d) {
    return datatype::Registry::Global()->GetTypeRegistered(d.code());
  };
  const DataType lhs_t = lhs.dtype();
  const DataType rhs_t = rhs.dtype();

  if (!lhs_t.is_float() && (rhs_t.is_float() || is_custom(rhs_t))) {
    // lhs: int -> float (or custom type)
    lhs = cast(rhs_t, lhs);
  } else if ((lhs_t.is_float() || is_custom(lhs_t)) && !rhs_t.is_float()) {
    // rhs: int -> float (or custom type)
    rhs = cast(lhs_t, rhs);
  } else if ((lhs_t.is_int() && rhs_t.is_int()) || (lhs_t.is_uint() && rhs_t.is_uint())) {
    // Same signedness: promote the narrower side to the wider one.
    if (lhs_t.bits() < rhs_t.bits()) {
      lhs = cast(rhs_t, lhs);
    } else {
      rhs = cast(lhs_t, rhs);
    }
  } else if ((lhs_t.is_int() && rhs_t.is_uint()) || (lhs_t.is_uint() && rhs_t.is_int())) {
    // Mixed signedness: both sides become signed with the larger bit width.
    const int bits = std::max(lhs_t.bits(), rhs_t.bits());
    lhs = SimpleCast(DataType::Int(bits, lhs_t.lanes()), lhs);
    rhs = SimpleCast(DataType::Int(bits, rhs_t.lanes()), rhs);
  } else {
    LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype;
  }
}
// maximum and min limits
//
// Returns an immediate holding the maximum value of scalar type `dtype`.
//  - signed int, bits <= 64:   2^(bits-1) - 1
//  - unsigned int, bits <= 64: 2^bits - 1
//  - float64 / float32:        std::numeric_limits<...>::max()
//  - float16:                  65504.0 (largest finite IEEE half value)
// `dtype` must have exactly one lane (CHECK_EQ). Any other type is reported
// through LOG(FATAL).
PrimExpr max_value(const DataType& dtype) {
  using namespace tir;
  CHECK_EQ(dtype.lanes(), 1);
  if (dtype.is_int()) {
    if (dtype.bits() == 64) {
      return IntImm(dtype, std::numeric_limits<int64_t>::max());
    } else if (dtype.bits() < 64) {
      // Shifting by bits-1 (< 63) cannot overflow the 64-bit accumulator.
      int64_t val = 1;
      val = (val << (dtype.bits() - 1)) - 1;
      return IntImm(dtype, val);
    }
  } else if (dtype.is_uint()) {
    if (dtype.bits() == 64) {
      // The value does not fit in int64_t, so it goes through make_const.
      return make_const(dtype, std::numeric_limits<uint64_t>::max());
    } else if (dtype.bits() < 64) {
      uint64_t val = 1;
      val = (val << static_cast<uint64_t>(dtype.bits())) - 1;
      return IntImm(dtype, static_cast<int64_t>(val));
    }
  } else if (dtype.is_float()) {
    if (dtype.bits() == 64) {
      return FloatImm(dtype, std::numeric_limits<double>::max());
    } else if (dtype.bits() == 32) {
      return FloatImm(dtype, std::numeric_limits<float>::max());
    } else if (dtype.bits() == 16) {
      return FloatImm(dtype, 65504.0);
    }
  }
  // Separate the message from the printed dtype (the space was missing).
  LOG(FATAL) << "Cannot decide max_value for type " << dtype;
  return PrimExpr();
}
// Returns an immediate holding the minimum (lowest) value of scalar type `dtype`.
//  - custom registered type:  result of the registered minimum function,
//                             called with the bit width
//  - signed int, bits <= 64:  -2^(bits-1)
//  - unsigned int:            0
//  - float64 / float32:       std::numeric_limits<...>::lowest()
//  - float16:                 -65504.0
// `dtype` must have exactly one lane (CHECK_EQ). Any other type is reported
// through LOG(FATAL).
PrimExpr min_value(const DataType& dtype) {
  using namespace tir;
  CHECK_EQ(dtype.lanes(), 1);
  if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) {
    // Custom datatypes supply their own minimum through the datatype registry.
    auto f = datatype::GetMinFunc(dtype.code());
    CHECK(f) << "No minimum function registered for custom dtype "
             << static_cast<unsigned int>(dtype.code());
    // TODO(@hypercubestart) Document this change (and others associated with the overflowing
    // floatimm min bug)
    return (*f)(dtype.bits());
  } else if (dtype.is_int()) {
    if (dtype.bits() == 64) {
      return IntImm(dtype, std::numeric_limits<int64_t>::lowest());
    } else if (dtype.bits() < 64) {
      // Shifting by bits-1 (< 63) cannot overflow the 64-bit accumulator.
      int64_t val = 1;
      val = -(val << (dtype.bits() - 1));
      return IntImm(dtype, val);
    }
  } else if (dtype.is_uint()) {
    return IntImm(dtype, 0);
  } else if (dtype.is_float()) {
    if (dtype.bits() == 64) {
      return FloatImm(dtype, std::numeric_limits<double>::lowest());
    } else if (dtype.bits() == 32) {
      return FloatImm(dtype, std::numeric_limits<float>::lowest());
    } else if (dtype.bits() == 16) {
      return FloatImm(dtype, -65504.0);
    }
  }
  // Separate the message from the printed dtype (the space was missing).
  LOG(FATAL) << "Cannot decide min_value for type " << dtype;
  return PrimExpr();
}
// infinity
// Returns a FloatImm of scalar type `dtype` holding positive infinity.
// Supported: float64 (double infinity) and float32/float16 (float infinity).
// `dtype` must have exactly one lane; every other type hits LOG(FATAL).
PrimExpr infinity(const DataType& dtype) {
  using namespace tir;
  CHECK_EQ(dtype.lanes(), 1);
  if (dtype.is_float()) {
    switch (dtype.bits()) {
      case 64:
        return FloatImm(dtype, std::numeric_limits<double>::infinity());
      case 32:
      case 16:
        return FloatImm(dtype, std::numeric_limits<float>::infinity());
      default:
        break;
    }
  }
  LOG(FATAL) << "Cannot decide infinity for type " << dtype;
  return PrimExpr();
}
namespace tir {
// Tests whether `val` is a positive power of two.
//
// For val <= 0 the function returns false and leaves *shift untouched.
// For val > 0, *shift receives the number of trailing zero bits of `val`
// (which is log2(val) when the function returns true); the result is true
// exactly when a single bit remains after stripping those zeros.
template <typename ValueType>
inline bool ConstPowerHelper(ValueType val, int* shift) {
  if (val <= 0) return false;
  int trailing_zeros = 0;
  // val > 0 guarantees a set bit, so this loop terminates.
  for (; (val & 1) == 0; val = val >> 1) {
    ++trailing_zeros;
  }
  *shift = trailing_zeros;
  return val == 1;
}
bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) {
if (const auto* op = x.as<tir::IntImmNode>()) {
return ConstPowerHelper(op->value, shift);
} else {
return false;
}
}
} // namespace tir
// Cast `value` to type `t`.
//  - Same dtype: `value` is returned unchanged.
//  - Vector target and vector source: lane counts must match (CHECK), and a
//    tir::Cast node is produced.
//  - Scalar target: IntImm / FloatImm constants are folded with make_const,
//    anything else becomes a tir::Cast.
//  - Vector target and scalar source: the scalar is first converted to the
//    element type of `t` (folding constants the same way) and then broadcast
//    to t.lanes().
PrimExpr cast(const DataType& t, PrimExpr value) {
  using tir::FloatImmNode;
  if (value.dtype() == t) return value;

  const bool vector_target = (t.lanes() != 1);
  if (vector_target && value.dtype().lanes() != 1) {
    CHECK(value.dtype().lanes() == t.lanes());
    return tir::Cast(t, value);
  }

  // From here on `value` is converted element-wise: either to `t` itself
  // (scalar target) or to the element type of `t` (manually unrolled cast).
  DataType elem_type = t;
  if (vector_target) {
    elem_type = t.element_of();
  }
  if (value.dtype() != elem_type) {
    // const fold IntImm/FloatImm as they are used in index computations
    if (const IntImmNode* int_imm = value.as<IntImmNode>()) {
      value = make_const(elem_type, int_imm->value);
    } else if (const FloatImmNode* float_imm = value.as<FloatImmNode>()) {
      value = make_const(elem_type, float_imm->value);
    } else {
      value = tir::Cast(elem_type, value);
    }
  }
  if (vector_target) {
    return tir::Broadcast(value, t.lanes());
  }
  return value;
}
// reinterpret
// Wraps `value` in a call to the reinterpret builtin with result type `t`,
// unless `value` already has that dtype, in which case it is returned as is.
PrimExpr reinterpret(const DataType& t, PrimExpr value) {
  if (value.dtype() != t) {
    return tir::Call(t, tir::builtin::reinterpret(), {value});
  }
  return value;
}
// operator+
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::Add> when it is defined, or a tir::Add node otherwise.
PrimExpr operator+(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Add>(a, b);
  if (!folded.defined()) {
    return tir::Add(a, b);
  }
  return folded;
}
// negation
// IntImm and FloatImm operands are negated directly into a new immediate of
// the same dtype; any other expression becomes `0 - a` via the binary minus.
PrimExpr operator-(PrimExpr a) {
  using tir::FloatImmNode;
  using tir::IntImmNode;
  if (const IntImmNode* int_imm = a.as<IntImmNode>()) {
    return IntImm(a.dtype(), -int_imm->value);
  }
  if (const FloatImmNode* float_imm = a.as<FloatImmNode>()) {
    return FloatImm(a.dtype(), -float_imm->value);
  }
  return make_zero(a.dtype()) - a;
}
// Binary subtraction: matches the operand types, then returns the result of
// arith::TryConstFold<tir::Sub> when it is defined, or a tir::Sub node otherwise.
PrimExpr operator-(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Sub>(a, b);
  if (!folded.defined()) {
    return tir::Sub(a, b);
  }
  return folded;
}
// Multiplication: matches the operand types, then returns the result of
// arith::TryConstFold<tir::Mul> when it is defined, or a tir::Mul node otherwise.
PrimExpr operator*(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Mul>(a, b);
  if (!folded.defined()) {
    return tir::Mul(a, b);
  }
  return folded;
}
// Division: matches the operand types, then returns the result of
// arith::TryConstFold<tir::Div> when it is defined, or a tir::Div node otherwise.
PrimExpr div(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Div>(a, b);
  if (!folded.defined()) {
    return tir::Div(a, b);
  }
  return folded;
}
// Integer-only division: both operands must be int or uint (CHECK, with the
// offending operand printed on failure); the work is then done by div().
PrimExpr truncdiv(PrimExpr a, PrimExpr b) {
  // Validate `a` first, then `b`.
  CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
  CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
  PrimExpr quotient = div(a, b);
  return quotient;
}
// Modulo: matches the operand types, then returns the result of
// arith::TryConstFold<tir::Mod> when it is defined, or a tir::Mod node otherwise.
PrimExpr truncmod(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Mod>(a, b);
  if (!folded.defined()) {
    return tir::Mod(a, b);
  }
  return folded;
}
// `a / b` is an alias for div(a, b).
PrimExpr operator/(PrimExpr a, PrimExpr b) {
  return div(a, b);
}
// `a % b` is an alias for truncmod(a, b).
PrimExpr operator%(PrimExpr a, PrimExpr b) {
  return truncmod(a, b);
}
// TODO(tqchen): switch to floordiv
// Index division; currently forwards to floordiv(a, b).
PrimExpr indexdiv(PrimExpr a, PrimExpr b) {
  return floordiv(a, b);
}
// Index modulo; currently forwards to floormod(a, b).
PrimExpr indexmod(PrimExpr a, PrimExpr b) {
  return floormod(a, b);
}
// Floor division on integer operands.
// Both operands must be int or uint (CHECK, printing the offending operand).
// After type matching, the result of arith::TryConstFold<tir::FloorDiv> is
// returned when defined; otherwise a tir::FloorDiv node is built.
PrimExpr floordiv(PrimExpr a, PrimExpr b) {
  CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
  CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::FloorDiv>(a, b);
  if (!folded.defined()) {
    return tir::FloorDiv(a, b);
  }
  return folded;
}
// Floor modulo on integer operands.
// Both operands must be int or uint (CHECK, printing the offending operand).
// After type matching, the result of arith::TryConstFold<tir::FloorMod> is
// returned when defined; otherwise a tir::FloorMod node is built.
PrimExpr floormod(PrimExpr a, PrimExpr b) {
  CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
  CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::FloorMod>(a, b);
  if (!folded.defined()) {
    return tir::FloorMod(a, b);
  }
  return folded;
}
// Element-wise minimum of two expressions.
// Infinite operands are resolved before any type matching; otherwise the
// result of arith::TryConstFold<tir::Min> is returned when defined, and a
// tir::Min node is built as the fallback.
PrimExpr min(PrimExpr a, PrimExpr b) {
  // inf-aware simplification (the test order below is significant)
  using arith::is_neg_inf;
  using arith::is_pos_inf;
  if (is_pos_inf(a)) return b;
  if (is_neg_inf(a) || is_pos_inf(b)) return a;
  if (is_neg_inf(b)) return b;

  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Min>(a, b);
  if (!folded.defined()) {
    return tir::Min(a, b);
  }
  return folded;
}
// Element-wise maximum of two expressions.
// Infinite operands are resolved before any type matching; otherwise the
// result of arith::TryConstFold<tir::Max> is returned when defined, and a
// tir::Max node is built as the fallback.
PrimExpr max(PrimExpr a, PrimExpr b) {
  // inf-aware simplification (the test order below is significant)
  using arith::is_neg_inf;
  using arith::is_pos_inf;
  if (is_pos_inf(a)) return a;
  if (is_neg_inf(a) || is_pos_inf(b)) return b;
  if (is_neg_inf(b)) return a;

  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::Max>(a, b);
  if (!folded.defined()) {
    return tir::Max(a, b);
  }
  return folded;
}
// if_then_else
// Conditional expression. `cond` must have dtype Bool(1) (CHECK). The two
// value operands are type-matched first. A constant condition (IntImm)
// selects the corresponding operand directly; otherwise a call to the
// if_then_else builtin is produced with the dtype of `true_value`.
PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
  CHECK(cond.dtype() == DataType::Bool(1))
      << "if_then_else only accept the condition to be boolean type.";
  BinaryOpMatchTypes(true_value, false_value);
  const IntImmNode* const_cond = cond.as<IntImmNode>();
  if (const_cond != nullptr) {
    if (const_cond->value == 0) {
      return false_value;
    }
    return true_value;
  }
  return tir::Call(true_value.dtype(), tir::builtin::if_then_else(),
                   {cond, true_value, false_value});
}
// likely
// Wraps `cond` in a call to the likely builtin (same dtype as `cond`).
// A condition for which is_const_int() holds is returned unchanged.
PrimExpr likely(PrimExpr cond) {
  if (!is_const_int(cond)) {
    return tir::Call(cond.dtype(), tir::builtin::likely(), {cond});
  }
  return cond;
}
// operator>
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::GT> when it is defined, or a tir::GT node otherwise.
PrimExpr operator>(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::GT>(a, b);
  if (!folded.defined()) {
    return tir::GT(a, b);
  }
  return folded;
}
// operator>=
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::GE> when it is defined, or a tir::GE node otherwise.
PrimExpr operator>=(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::GE>(a, b);
  if (!folded.defined()) {
    return tir::GE(a, b);
  }
  return folded;
}
// operator<
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::LT> when it is defined, or a tir::LT node otherwise.
PrimExpr operator<(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::LT>(a, b);
  if (!folded.defined()) {
    return tir::LT(a, b);
  }
  return folded;
}
// operator<=
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::LE> when it is defined, or a tir::LE node otherwise.
PrimExpr operator<=(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::LE>(a, b);
  if (!folded.defined()) {
    return tir::LE(a, b);
  }
  return folded;
}
// operator==
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::EQ> when it is defined, or a tir::EQ node otherwise.
PrimExpr operator==(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::EQ>(a, b);
  if (!folded.defined()) {
    return tir::EQ(a, b);
  }
  return folded;
}
// operator!=
// Matches the operand types, then returns the result of
// arith::TryConstFold<tir::NE> when it is defined, or a tir::NE node otherwise.
PrimExpr operator!=(PrimExpr a, PrimExpr b) {
  BinaryOpMatchTypes(a, b);
  PrimExpr folded = arith::TryConstFold<tir::NE>(a, b);
  if (!folded.defined()) {
    return tir::NE(a, b);
  }
  return folded;
}
// Logical and. Both operands must be boolean (CHECK); no type matching is
// performed. Returns the result of arith::TryConstFold<tir::And> when it is
// defined, or a tir::And node otherwise.
PrimExpr operator&&(PrimExpr a, PrimExpr b) {
  CHECK(a.dtype().is_bool());
  CHECK(b.dtype().is_bool());
  PrimExpr folded = arith::TryConstFold<tir::And>(a, b);
  if (!folded.defined()) {
    return tir::And(a, b);
  }
  return folded;
}
// Logical or. Both operands must be boolean (CHECK); no type matching is
// performed. Returns the result of arith::TryConstFold<tir::Or> when it is
// defined, or a tir::Or node otherwise.
PrimExpr operator||(PrimExpr a, PrimExpr b) {
  CHECK(a.dtype().is_bool());
  CHECK(b.dtype().is_bool());
  PrimExpr folded = arith::TryConstFold<tir::Or>(a, b);
  if (!folded.defined()) {
    return tir::Or(a, b);
  }
  return folded;
}
// Logical not. The operand must be boolean (CHECK). Returns the result of
// arith::TryConstFold<tir::Not> when it is defined, or a tir::Not node otherwise.
PrimExpr operator!(PrimExpr a) {
  CHECK(a.dtype().is_bool());
  PrimExpr folded = arith::TryConstFold<tir::Not>(a);
  if (!folded.defined()) {
    return tir::Not(a);
  }
  return folded;
}
// shift right
// Builds `a >> b`. Both operands must be int or uint (CHECK) and are
// type-matched before anything else happens.
PrimExpr operator>>(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
// NOTE(review): `pa` and `pb` are not declared in this function; presumably
// TVM_INDEX_CONST_PROPAGATION binds them to the constant forms of `a` and `b`
// (null when not constant) -- confirm against the macro definition.
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
// A constant shift amount must lie in [0, rtype.bits()).
if (pb)
CHECK(pb->value >= 0 && pb->value < rtype.bits())
<< "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
<< rtype;
// Both operands constant: fold into an immediate.
if (pa && pb) return IntImm(rtype, (pa->value >> pb->value));
// Shifting by a constant zero returns `a` itself.
if (pb) {
if (pb->value == 0) return a;
}
});
// General case: emit a call to the shift_right builtin.
return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b});
}
// shift left
// Builds `a << b`. Both operands must be int or uint (CHECK) and are
// type-matched before anything else happens.
PrimExpr operator<<(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
// NOTE(review): `pa` and `pb` are not declared in this function; presumably
// TVM_INDEX_CONST_PROPAGATION binds them to the constant forms of `a` and `b`
// (null when not constant) -- confirm against the macro definition.
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
// A constant shift amount must lie in [0, rtype.bits()).
if (pb)
CHECK(pb->value >= 0 && pb->value < rtype.bits())
<< "Shift amount must be non-negative and less than " << rtype.bits() << " for type "
<< rtype;
// Both operands constant: fold into an immediate.
if (pa && pb) return IntImm(rtype, (pa->value << pb->value));
// Shifting by a constant zero returns `a` itself.
if (pb) {
if (pb->value == 0) return a;
}
});
// General case: emit a call to the shift_left builtin.
return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b});
}
// bitwise and
// Builds `a & b`. Both operands must be int or uint (CHECK) and are
// type-matched first.
PrimExpr operator&(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
// NOTE(review): `pa` / `pb` presumably come from TVM_INDEX_CONST_PROPAGATION
// (constant forms of `a` / `b`) -- confirm against the macro definition.
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
// Both operands constant: fold into an immediate.
if (pa && pb) return IntImm(rtype, (pa->value & pb->value));
});
// General case: emit a call to the bitwise_and builtin.
return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b});
}
// bitwise_or
// Builds `a | b`. Both operands must be int or uint (CHECK) and are
// type-matched first.
PrimExpr operator|(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
// NOTE(review): `pa` / `pb` presumably come from TVM_INDEX_CONST_PROPAGATION
// (constant forms of `a` / `b`) -- confirm against the macro definition.
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
// Both operands constant: fold into an immediate.
if (pa && pb) return IntImm(rtype, (pa->value | pb->value));
});
// General case: emit a call to the bitwise_or builtin.
return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b});
}
// bitwise_xor
// Builds `a ^ b`. Both operands must be int or uint (CHECK) and are
// type-matched first.
PrimExpr operator^(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
BinaryOpMatchTypes(a, b);
// NOTE(review): `pa` / `pb` presumably come from TVM_INDEX_CONST_PROPAGATION
// (constant forms of `a` / `b`) -- confirm against the macro definition.
TVM_INDEX_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
// Both operands constant: fold into an immediate.
if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value));
});
// General case: emit a call to the bitwise_xor builtin.
return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b});
}
// bitwise_not
// Builds a call to the bitwise_not builtin with the dtype of `a`.
// The operand must be int or uint (CHECK).
PrimExpr operator~(PrimExpr a) {
  CHECK(a.dtype().is_int() || a.dtype().is_uint());
  const DataType result_type = a.dtype();
  return tir::Call(result_type, tir::builtin::bitwise_not(), {a});
}
// Expose operator~ through the global function registry.
TVM_REGISTER_GLOBAL("tir.bitwise_not").set_body_typed([](PrimExpr operand) {
  return ~operand;
});
// pow
// Builds a call to the "tir.pow" op. Operand types are matched first; the
// matched type must be a float type (CHECK). The call has the dtype of `x`.
PrimExpr pow(PrimExpr x, PrimExpr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.dtype().is_float()) << "power only applies to float";
  // Looked up once on first use.
  static auto pow_op = Op::Get("tir.pow");
  const DataType result_type = x.dtype();
  return tir::Call(result_type, pow_op, {x, y});
}
// Pure, two-input op that is marked vectorizable.
TIR_REGISTER_PURE_BINARY_OP("tir.pow").set_attr<TVectorizable>("TVectorizable", true);
// abs
// Absolute value.
//  - signed int: constants are folded with std::abs; otherwise a
//    Select(x >= 0, x, -x) is built.
//  - float: constants are folded with std::fabs; otherwise a call to the
//    "tir.fabs" op is built.
//  - unsigned int: returned unchanged.
// Any other dtype is reported through LOG(FATAL).
PrimExpr abs(PrimExpr x) {
  if (x.dtype().is_int()) {
    if (const tir::IntImmNode* int_imm = x.as<tir::IntImmNode>()) {
      return IntImm(x.dtype(), std::abs(int_imm->value));
    }
    return tir::Select(x >= make_zero(x.dtype()), x, -x);
  }
  if (x.dtype().is_float()) {
    if (const tir::FloatImmNode* float_imm = x.as<tir::FloatImmNode>()) {
      return FloatImm(x.dtype(), std::fabs(float_imm->value));
    }
    // Looked up once on first use.
    static auto fabs_op = Op::Get("tir.fabs");
    return tir::Call(x.dtype(), fabs_op, {x});
  }
  if (x.dtype().is_uint()) {
    return x;
  }
  LOG(FATAL) << "Data type " << x.dtype()
             << " not supported for absolute op. Skipping absolute op...";
  return x;
}
// Pure, one-input op that is marked vectorizable.
TIR_REGISTER_PURE_UNARY_OP("tir.fabs").set_attr<TVectorizable>("TVectorizable", true);
// isnan
// Returns a boolean expression (same lane count as `x`) telling whether `x`
// is NaN.
//  - int / uint: constant false.
//  - float: FloatImm constants are folded with std::isnan; otherwise a call to
//    the "tir.isnan" op is built. A 16-bit float operand is cast to Float(32)
//    before being passed to the op.
// Any other dtype is reported through LOG(FATAL).
PrimExpr isnan(PrimExpr x) {
  DataType t = DataType::Bool(x.dtype().lanes());
  if (x.dtype().is_int() || x.dtype().is_uint()) {
    return make_const(t, false);
  }
  if (!x.dtype().is_float()) {
    LOG(FATAL) << "Data type " << x.dtype() << " not supported for isnan op. Skipping isnan op...";
    return x;
  }
  if (const tir::FloatImmNode* float_imm = x.as<tir::FloatImmNode>()) {
    return make_const(t, std::isnan(float_imm->value));
  }
  // Looked up once on first use.
  static auto isnan_op = Op::Get("tir.isnan");
  if (x.dtype().bits() != 16) {
    return tir::Call(t, isnan_op, {x});
  }
  return tir::Call(t, isnan_op, {cast(DataType::Float(32, t.lanes()), std::move(x))});
}
// isinf
// Returns a boolean expression (same lane count as `x`) telling whether `x`
// is infinite.
//  - int / uint: constant false.
//  - float: `abs(x) == infinity(dtype) && !isnan(x)`.
// Any other dtype is reported through LOG(FATAL).
PrimExpr isinf(PrimExpr x) {
  DataType t = DataType::Bool(x.dtype().lanes());
  if (x.dtype().is_int() || x.dtype().is_uint()) {
    return make_const(t, false);
  }
  if (x.dtype().is_float()) {
    PrimExpr infX = infinity(x.dtype());
    return abs(x) == infX && !isnan(x);
  }
  LOG(FATAL) << "Data type " << x.dtype() << " not supported for finiteness ops. Skipping it...";
  return x;
}
// isfinite
// A value is finite when it is neither infinite nor NaN.
PrimExpr isfinite(PrimExpr x) {
  return !isinf(x) && !isnan(x);
}
// Sum reduction of `source` over `rdom`.
// The combiner is tir::Add on two fresh variables of the source dtype, with
// zero as the identity element. The Reduce node uses a constant-true
// condition, value index 0, and the supplied `init`.
PrimExpr sum(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
  const DataType dtype = source.dtype();
  Var lhs("x", dtype), rhs("y", dtype);
  PrimExpr combined = tir::Add(lhs, rhs);
  PrimExpr identity = make_zero(dtype);
  tir::CommReducer reducer = tir::CommReducer({lhs}, {rhs}, {combined}, {identity});
  PrimExpr always_true = make_const(DataType::Bool(1), true);
  return tir::Reduce(reducer, {source}, rdom, always_true, 0, init);
}
// Logical-and reduction of `source` over `rdom`; `source` must be boolean (CHECK).
// The combiner is tir::And on two fresh variables of the source dtype, with
// `true` as the identity element. The Reduce node uses a constant-true
// condition, value index 0, and the supplied `init`.
PrimExpr all(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
  CHECK(source.dtype().is_bool());
  const DataType dtype = source.dtype();
  Var lhs("x", dtype), rhs("y", dtype);
  PrimExpr combined = tir::And(lhs, rhs);
  PrimExpr identity = make_const(dtype, true);
  tir::CommReducer reducer = tir::CommReducer({lhs}, {rhs}, {combined}, {identity});
  PrimExpr always_true = make_const(DataType::Bool(1), true);
  return tir::Reduce(reducer, {source}, rdom, always_true, 0, init);
}
// Logical-or reduction of `source` over `rdom`; `source` must be boolean (CHECK).
// The combiner is tir::Or on two fresh variables of the source dtype, with
// `false` as the identity element. The Reduce node uses a constant-true
// condition, value index 0, and the supplied `init`.
PrimExpr any(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
  CHECK(source.dtype().is_bool());
  const DataType dtype = source.dtype();
  Var lhs("x", dtype), rhs("y", dtype);
  PrimExpr combined = tir::Or(lhs, rhs);
  PrimExpr identity = make_const(dtype, false);
  tir::CommReducer reducer = tir::CommReducer({lhs}, {rhs}, {combined}, {identity});
  PrimExpr always_true = make_const(DataType::Bool(1), true);
  return tir::Reduce(reducer, {source}, rdom, always_true, 0, init);
}
// Maximum reduction of `source` over `rdom`.
// The combiner is tir::Max on two fresh variables of the source dtype, with
// min_value(dtype) as the identity element. The Reduce node uses a
// constant-true condition, value index 0, and the supplied `init`.
PrimExpr max(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
  const DataType dtype = source.dtype();
  Var lhs("x", dtype), rhs("y", dtype);
  PrimExpr combined = tir::Max(lhs, rhs);
  PrimExpr identity = min_value(dtype);
  tir::CommReducer reducer = tir::CommReducer({lhs}, {rhs}, {combined}, {identity});
  PrimExpr always_true = make_const(DataType::Bool(1), true);
  return tir::Reduce(reducer, {source}, rdom, always_true, 0, init);
}
// Minimum reduction of `source` over `rdom`.
// The combiner is tir::Min on two fresh variables of the source dtype, with
// max_value(dtype) as the identity element. The Reduce node uses a
// constant-true condition, value index 0, and the supplied `init`.
PrimExpr min(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
  const DataType dtype = source.dtype();
  Var lhs("x", dtype), rhs("y", dtype);
  PrimExpr combined = tir::Min(lhs, rhs);
  PrimExpr identity = max_value(dtype);
  tir::CommReducer reducer = tir::CommReducer({lhs}, {rhs}, {combined}, {identity});
  PrimExpr always_true = make_const(DataType::Bool(1), true);
  return tir::Reduce(reducer, {source}, rdom, always_true, 0, init);
}
// Product reduction of `source` over `rdom`.
// The combiner is tir::Mul on two fresh variables of the source dtype, with
// the constant 1 as the identity element. The Reduce node uses a
// constant-true condition, value index 0, and the supplied `init`.
PrimExpr prod(PrimExpr source, Array<IterVar> rdom, Array<PrimExpr> init) {
  const DataType dtype = source.dtype();
  Var lhs("x", dtype), rhs("y", dtype);
  PrimExpr combined = tir::Mul(lhs, rhs);
  PrimExpr identity = make_const(dtype, 1);
  tir::CommReducer reducer = tir::CommReducer({lhs}, {rhs}, {combined}, {identity});
  PrimExpr always_true = make_const(DataType::Bool(1), true);
  return tir::Reduce(reducer, {source}, rdom, always_true, 0, init);
}
// fmod
// Builds a call to the "tir.fmod" op with arguments (x, y).
// Operand types are matched first; the matched type must be a float type
// (CHECK). The call has the dtype of `x`.
PrimExpr fmod(PrimExpr x, PrimExpr y) {
  BinaryOpMatchTypes(x, y);
  CHECK(x.dtype().is_float()) << "fmod only applies to float";
  // Looked up once on first use.
  static auto op = Op::Get("tir.fmod");
  return tir::Call(x.dtype(), op, {x, y});
}
// fmod is always called with two arguments ({x, y} above), so it is
// registered as a two-input op, matching the registration of "tir.pow".
TIR_REGISTER_PURE_BINARY_OP("tir.fmod");
// floor
// Integer (int / uint) operands are returned unchanged. FloatImm constants
// are folded with std::floor; anything else becomes a call to the
// "tir.floor" op with the dtype of `x`.
PrimExpr floor(PrimExpr x) {
  using tir::FloatImmNode;
  const bool is_integral = x.dtype().is_int() || x.dtype().is_uint();
  if (is_integral) {
    return x;
  }
  if (const FloatImmNode* float_imm = x.as<FloatImmNode>()) {
    return FloatImm(x.dtype(), std::floor(float_imm->value));
  }
  // Looked up once on first use.
  static auto floor_op = Op::Get("tir.floor");
  return tir::Call(x.dtype(), floor_op, {x});
}
// Pure, one-input op that is marked vectorizable.
TIR_REGISTER_PURE_UNARY_OP("tir.floor").set_attr<TVectorizable>("TVectorizable", true);
// ceil
// Integer (int / uint) operands are returned unchanged. FloatImm constants
// are folded with std::ceil; anything else becomes a call to the
// "tir.ceil" op with the dtype of `x`.
PrimExpr ceil(PrimExpr x) {
  using tir::FloatImmNode;
  const bool is_integral = x.dtype().is_int() || x.dtype().is_uint();
  if (is_integral) {
    return x;
  }
  if (const FloatImmNode* float_imm = x.as<FloatImmNode>()) {
    return FloatImm(x.dtype(), std::ceil(float_imm->value));
  }
  // Looked up once on first use.
  static auto ceil_op = Op::Get("tir.ceil");
  return tir::Call(x.dtype(), ceil_op, {x});
}
// Pure, one-input op that is marked vectorizable.
TIR_REGISTER_PURE_UNARY_OP("tir.ceil").set_attr<TVectorizable>("TVectorizable", true);
// round
// Integer (int / uint) operands are returned unchanged. FloatImm constants
// are folded with std::nearbyint; anything else becomes a call to the
// "tir.round" op with the dtype of `x`.
PrimExpr round(PrimExpr x) {
  using tir::FloatImmNode;
  const bool is_integral = x.dtype().is_int() || x.dtype().is_uint();
  if (is_integral) {
    return x;
  }
  if (const FloatImmNode* float_imm = x.as<FloatImmNode>()) {
    return FloatImm(x.dtype(), std::nearbyint(float_imm->value));
  }
  // Looked up once on first use.
  static auto round_op = Op::Get("tir.round");
  return tir::Call(x.dtype(), round_op, {x});
}
// Pure, one-input op that is marked vectorizable.
TIR_REGISTER_PURE_UNARY_OP("tir.round").set_attr<TVectorizable>("TVectorizable", true);
// nearbyint
// Integer (int / uint) operands are returned unchanged. FloatImm constants
// are folded with std::nearbyint; anything else becomes a call to the
// "tir.nearbyint" op with the dtype of `x`.
PrimExpr nearbyint(PrimExpr x) {
  using tir::FloatImmNode;
  const bool is_integral = x.dtype().is_int() || x.dtype().is_uint();
  if (is_integral) {
    return x;
  }
  if (const FloatImmNode* float_imm = x.as<FloatImmNode>()) {
    return FloatImm(x.dtype(), std::nearbyint(float_imm->value));
  }
  // Looked up once on first use.
  static auto nearbyint_op = Op::Get("tir.nearbyint");
  return tir::Call(x.dtype(), nearbyint_op, {x});
}
// Pure, one-input op (no vectorizable attribute is set here).
TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint");
// trunc
// Integer (int / uint) operands are returned unchanged. FloatImm constants
// are folded by rounding toward zero (ceil for negative values, floor
// otherwise); anything else becomes a call to the "tir.trunc" op with the
// dtype of `x`.
PrimExpr trunc(PrimExpr x) {
  using tir::FloatImmNode;
  const bool is_integral = x.dtype().is_int() || x.dtype().is_uint();
  if (is_integral) {
    return x;
  }
  if (const FloatImmNode* float_imm = x.as<FloatImmNode>()) {
    if (float_imm->value < 0) {
      return FloatImm(x.dtype(), std::ceil(float_imm->value));
    }
    return FloatImm(x.dtype(), std::floor(float_imm->value));
  }
  // Looked up once on first use.
  static auto trunc_op = Op::Get("tir.trunc");
  return tir::Call(x.dtype(), trunc_op, {x});
}
// Pure, one-input op that is marked vectorizable.
TIR_REGISTER_PURE_UNARY_OP("tir.trunc").set_attr<TVectorizable>("TVectorizable", true);
// unary op registration.
// Each line registers a pure one-input op through TIR_REGISTER_PURE_UNARY_OP.
// Ops that additionally chain set_attr<TVectorizable>("TVectorizable", true)
// are marked vectorizable; the others carry no TVectorizable attribute here.
TIR_REGISTER_PURE_UNARY_OP("tir.exp").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.exp2").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.exp10").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.erf");
TIR_REGISTER_PURE_UNARY_OP("tir.tanh").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.sigmoid");
TIR_REGISTER_PURE_UNARY_OP("tir.sqrt").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.rsqrt");
TIR_REGISTER_PURE_UNARY_OP("tir.log").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.log2").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.log1p");
TIR_REGISTER_PURE_UNARY_OP("tir.log10").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.tan").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.cos").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.cosh").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.sin").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.sinh").set_attr<TVectorizable>("TVectorizable", true);
TIR_REGISTER_PURE_UNARY_OP("tir.asin");
TIR_REGISTER_PURE_UNARY_OP("tir.acos");
TIR_REGISTER_PURE_UNARY_OP("tir.atan");
TIR_REGISTER_PURE_UNARY_OP("tir.acosh");
TIR_REGISTER_PURE_UNARY_OP("tir.asinh");
TIR_REGISTER_PURE_UNARY_OP("tir.atanh");
// binary intrinsics
// Each line registers a pure two-input op through TIR_REGISTER_PURE_BINARY_OP.
// None of them sets the TVectorizable attribute here.
TIR_REGISTER_PURE_BINARY_OP("tir.atan2");
TIR_REGISTER_PURE_BINARY_OP("tir.nextafter");
TIR_REGISTER_PURE_BINARY_OP("tir.hypot");
TIR_REGISTER_PURE_BINARY_OP("tir.copysign");
TIR_REGISTER_PURE_BINARY_OP("tir.ldexp");
// expose basic functions to node namespace
// "node._const": builds a constant through tir::make_const. args[0] is the
// value (an int is read as int64_t, a float as double) and args[1] is passed
// as the type argument. Any other value kind is reported through LOG(FATAL).
TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) {
  switch (args[0].type_code()) {
    case kDLInt:
      *ret = tir::make_const(args[1], args[0].operator int64_t());
      break;
    case kDLFloat:
      *ret = tir::make_const(args[1], args[0].operator double());
      break;
    default:
      LOG(FATAL) << "only accept int or float";
      break;
  }
});
// Expose the helpers defined in this file through the global function
// registry, each under the name given as the string argument.
TVM_REGISTER_GLOBAL("node.LargeUIntImm").set_body_typed(LargeUIntImm);
TVM_REGISTER_GLOBAL("tir.min_value").set_body_typed(min_value);
TVM_REGISTER_GLOBAL("tir.max_value").set_body_typed(max_value);
TVM_REGISTER_GLOBAL("tir.abs").set_body_typed(tvm::abs);
TVM_REGISTER_GLOBAL("tir.isnan").set_body_typed(tvm::isnan);
TVM_REGISTER_GLOBAL("tir.isfinite").set_body_typed(tvm::isfinite);
TVM_REGISTER_GLOBAL("tir.isinf").set_body_typed(tvm::isinf);
TVM_REGISTER_GLOBAL("tir.floor").set_body_typed(tvm::floor);
TVM_REGISTER_GLOBAL("tir.ceil").set_body_typed(tvm::ceil);
TVM_REGISTER_GLOBAL("tir.round").set_body_typed(tvm::round);
TVM_REGISTER_GLOBAL("tir.nearbyint").set_body_typed(tvm::nearbyint);
TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc);
TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast);
// operator overloading, smarter than make
// Registers a global function named "tir." + Node that takes two PrimExpr
// arguments and returns Func(a, b).
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b) { \
return (Func(a, b)); \
})
// Registers a global function named "tir." + Node for a bit-level operator.
// The packed arguments are dispatched on their type code:
//  - args[0] is an int: Func(int, PrimExpr) is called;
//  - otherwise, args[1] is an int: Func(PrimExpr, int) is called;
//  - otherwise both arguments are read as PrimExpr.
// Note that when both arguments are ints, the first branch is taken and
// args[1] is still converted to PrimExpr.
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("tir." #Node).set_body([](TVMArgs args, TVMRetValue* ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator PrimExpr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator PrimExpr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator PrimExpr(), args[1].operator PrimExpr())); \
} \
})
// Arithmetic, comparison and logical operators, registered as "tir._Op*".
// Note: _OpMod and _OpTruncMod both map to truncmod; _OpDiv maps to div.
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
REGISTER_MAKE_BINARY_OP(_OpSub, operator-);
REGISTER_MAKE_BINARY_OP(_OpMul, operator*);
REGISTER_MAKE_BINARY_OP(_OpDiv, div);
REGISTER_MAKE_BINARY_OP(_OpMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpIndexDiv, indexdiv);
REGISTER_MAKE_BINARY_OP(_OpIndexMod, indexmod);
REGISTER_MAKE_BINARY_OP(_OpFloorDiv, floordiv);
REGISTER_MAKE_BINARY_OP(_OpFloorMod, floormod);
REGISTER_MAKE_BINARY_OP(_OpTruncDiv, truncdiv);
REGISTER_MAKE_BINARY_OP(_OpTruncMod, truncmod);
REGISTER_MAKE_BINARY_OP(_OpPow, pow);
REGISTER_MAKE_BINARY_OP(_OpMin, min);
REGISTER_MAKE_BINARY_OP(_OpMax, max);
REGISTER_MAKE_BINARY_OP(_OpEQ, operator==);
REGISTER_MAKE_BINARY_OP(_OpNE, operator!=);
REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*)
REGISTER_MAKE_BINARY_OP(_OpGE, operator>=);
REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&);
REGISTER_MAKE_BINARY_OP(_OpOr, operator||);
// Bit-level operators; these accept a plain int on either side
// (see REGISTER_MAKE_BIT_OP).
REGISTER_MAKE_BIT_OP(bitwise_and, operator&);
REGISTER_MAKE_BIT_OP(bitwise_or, operator|);
REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
// Expose if_then_else(cond, true_value, false_value) as the global function
// "tir._OpIfThenElse".
TVM_REGISTER_GLOBAL("tir._OpIfThenElse")
.set_body_typed([](PrimExpr cond, PrimExpr true_value, PrimExpr false_value) {
return if_then_else(cond, true_value, false_value);
});
} // namespace tvm