blob: f648aca18e46bece95189d3a7310a52ca1c38597 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file expr.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <limits>
#include <memory>
#include "../../support/str_escape.h"
namespace tvm {
namespace tir {
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b) { \
using T = Name::ContainerType; \
CHECK(a.defined()) << "ValueError: a is undefined\n"; \
CHECK(b.defined()) << "ValueError: b is undefined\n"; \
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = a.dtype(); \
node->a = std::move(a); \
node->b = std::move(b); \
data_ = std::move(node); \
}
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b) { \
using T = Name::ContainerType; \
CHECK(a.defined()) << "ValueError: a is undefined\n"; \
CHECK(b.defined()) << "ValueError: b is undefined\n"; \
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = DataType::Bool(a.dtype().lanes()); \
node->a = std::move(a); \
node->b = std::move(b); \
data_ = std::move(node); \
}
// Var
Var::Var(String name_hint, DataType dtype) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}
Var::Var(String name_hint, Type type_annotation) {
auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = GetRuntimeDataType(type_annotation);
n->type_annotation = std::move(type_annotation);
data_ = std::move(n);
}
Var Var::copy_with_suffix(const String& suffix) const {
const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
if (auto* ptr = this->as<SizeVarNode>()) {
new_ptr = make_object<SizeVarNode>(*ptr);
} else {
new_ptr = make_object<VarNode>(*node);
}
new_ptr->name_hint = new_ptr->name_hint + suffix;
return Var(new_ptr);
}
TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMArgValue type) {
if (type.IsObjectRef<Type>()) {
return Var(name_hint, type.operator Type());
} else {
return Var(name_hint, type.operator DataType());
}
});
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const VarNode*>(node.get());
// omit the type
// stream << op->name << "." << op->type;
p->stream << op->name_hint;
});
// SizeVar
SizeVar::SizeVar(String name_hint, DataType dtype) {
auto n = make_object<SizeVarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t) {
return SizeVar(s, t);
});
TVM_REGISTER_NODE_TYPE(SizeVarNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SizeVarNode*>(node.get());
p->stream << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
});
// IterVar
IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag) {
ObjectPtr<IterVarNode> n = make_object<IterVarNode>();
n->dom = dom;
n->var = var;
n->iter_type = t;
n->thread_tag = thread_tag;
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("tir.IterVar")
.set_body_typed([](Range dom, Var var, int iter_type, String thread_tag) {
return IterVar(dom, var, static_cast<IterVarType>(iter_type), thread_tag);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const IterVarNode*>(node.get());
p->stream << "iter_var(";
if (op->var->name_hint.length() != 0) {
p->stream << op->var->name_hint << ", ";
}
if (op->dom.defined()) {
p->stream << op->dom;
}
if (op->thread_tag.length() != 0) {
p->stream << ", " << op->thread_tag;
}
p->stream << ")";
});
TVM_REGISTER_NODE_TYPE(IterVarNode);
// StringImm
StringImm::StringImm(String value) {
ObjectPtr<StringImmNode> node = make_object<StringImmNode>();
node->dtype = DataType::Handle();
node->value = std::move(value);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value) { return StringImm(value); });
TVM_REGISTER_NODE_TYPE(StringImmNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StringImmNode*>(node.get());
p->stream << '\"' << support::StrEscape(op->value) << '\"';
});
// Cast
Cast::Cast(DataType t, PrimExpr value) {
CHECK(value.defined());
CHECK_EQ(t.lanes(), value.dtype().lanes());
ObjectPtr<CastNode> node = make_object<CastNode>();
node->dtype = t;
node->value = std::move(value);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Cast").set_body_typed([](DataType dtype, PrimExpr value) {
return Cast(dtype, value);
});
TVM_REGISTER_NODE_TYPE(CastNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CastNode*>(node.get());
p->stream << op->dtype << '(';
p->Print(op->value);
p->stream << ')';
});
// Add
TVM_DEFINE_BINOP_CONSTRUCTOR(Add);
TVM_REGISTER_GLOBAL("tir.Add").set_body_typed([](PrimExpr a, PrimExpr b) { return Add(a, b); });
TVM_REGISTER_NODE_TYPE(AddNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AddNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AddNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " + ";
p->Print(op->b);
p->stream << ')';
});
// Sub
TVM_DEFINE_BINOP_CONSTRUCTOR(Sub);
TVM_REGISTER_GLOBAL("tir.Sub").set_body_typed([](PrimExpr a, PrimExpr b) { return Sub(a, b); });
TVM_REGISTER_NODE_TYPE(SubNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SubNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SubNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " - ";
p->Print(op->b);
p->stream << ')';
});
// Mul
TVM_DEFINE_BINOP_CONSTRUCTOR(Mul);
TVM_REGISTER_GLOBAL("tir.Mul").set_body_typed([](PrimExpr a, PrimExpr b) { return Mul(a, b); });
TVM_REGISTER_NODE_TYPE(MulNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MulNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MulNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << "*";
p->Print(op->b);
p->stream << ')';
});
// Div
TVM_DEFINE_BINOP_CONSTRUCTOR(Div);
TVM_REGISTER_GLOBAL("tir.Div").set_body_typed([](PrimExpr a, PrimExpr b) { return Div(a, b); });
TVM_REGISTER_NODE_TYPE(DivNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DivNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DivNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << "/";
p->Print(op->b);
p->stream << ')';
});
// Mod
TVM_DEFINE_BINOP_CONSTRUCTOR(Mod);
TVM_REGISTER_GLOBAL("tir.Mod").set_body_typed([](PrimExpr a, PrimExpr b) { return Mod(a, b); });
TVM_REGISTER_NODE_TYPE(ModNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ModNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ModNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " % ";
p->Print(op->b);
p->stream << ')';
});
// FloorDiv
TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv);
TVM_REGISTER_GLOBAL("tir.FloorDiv").set_body_typed([](PrimExpr a, PrimExpr b) {
return FloorDiv(a, b);
});
TVM_REGISTER_NODE_TYPE(FloorDivNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloorDivNode*>(node.get());
p->stream << "floordiv(" << op->a << ", " << op->b << ")";
});
// FloorMod
TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod);
TVM_REGISTER_GLOBAL("tir.FloorMod").set_body_typed([](PrimExpr a, PrimExpr b) {
return FloorMod(a, b);
});
TVM_REGISTER_NODE_TYPE(FloorModNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloorModNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const FloorModNode*>(node.get());
p->stream << "floormod(" << op->a << ", " << op->b << ")";
});
// Min
TVM_DEFINE_BINOP_CONSTRUCTOR(Min);
TVM_REGISTER_GLOBAL("tir.Min").set_body_typed([](PrimExpr a, PrimExpr b) { return Min(a, b); });
TVM_REGISTER_NODE_TYPE(MinNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MinNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MinNode*>(node.get());
p->stream << "min(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ")";
});
// Max
TVM_DEFINE_BINOP_CONSTRUCTOR(Max);
TVM_REGISTER_GLOBAL("tir.Max").set_body_typed([](PrimExpr a, PrimExpr b) { return Max(a, b); });
TVM_REGISTER_NODE_TYPE(MaxNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MaxNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MaxNode*>(node.get());
p->stream << "max(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ")";
});
// EQ
TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ);
TVM_REGISTER_GLOBAL("tir.EQ").set_body_typed([](PrimExpr a, PrimExpr b) { return EQ(a, b); });
TVM_REGISTER_NODE_TYPE(EQNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<EQNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const EQNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " == ";
p->Print(op->b);
p->stream << ')';
});
// NE
TVM_DEFINE_CMPOP_CONSTRUCTOR(NE);
TVM_REGISTER_GLOBAL("tir.NE").set_body_typed([](PrimExpr a, PrimExpr b) { return NE(a, b); });
TVM_REGISTER_NODE_TYPE(NENode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<NENode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const NENode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " != ";
p->Print(op->b);
p->stream << ')';
});
// LT
TVM_DEFINE_CMPOP_CONSTRUCTOR(LT);
TVM_REGISTER_GLOBAL("tir.LT").set_body_typed([](PrimExpr a, PrimExpr b) { return LT(a, b); });
TVM_REGISTER_NODE_TYPE(LTNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LTNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LTNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " < ";
p->Print(op->b);
p->stream << ')';
});
// LE
TVM_DEFINE_CMPOP_CONSTRUCTOR(LE);
TVM_REGISTER_GLOBAL("tir.LE").set_body_typed([](PrimExpr a, PrimExpr b) { return LE(a, b); });
TVM_REGISTER_NODE_TYPE(LENode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LENode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LENode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " <= ";
p->Print(op->b);
p->stream << ')';
});
// GT
TVM_DEFINE_CMPOP_CONSTRUCTOR(GT);
TVM_REGISTER_GLOBAL("tir.GT").set_body_typed([](PrimExpr a, PrimExpr b) { return GT(a, b); });
TVM_REGISTER_NODE_TYPE(GTNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GTNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const GTNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " > ";
p->Print(op->b);
p->stream << ')';
});
// GE
TVM_DEFINE_CMPOP_CONSTRUCTOR(GE);
TVM_REGISTER_GLOBAL("tir.GE").set_body_typed([](PrimExpr a, PrimExpr b) { return GE(a, b); });
TVM_REGISTER_NODE_TYPE(GENode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GENode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const GENode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " >= ";
p->Print(op->b);
p->stream << ')';
});
// And
And::And(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
ObjectPtr<AndNode> node = make_object<AndNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.And").set_body_typed([](PrimExpr a, PrimExpr b) { return And(a, b); });
TVM_REGISTER_NODE_TYPE(AndNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AndNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AndNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " && ";
p->Print(op->b);
p->stream << ')';
});
// Or
Or::Or(PrimExpr a, PrimExpr b) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(b.defined()) << "ValueError: b is undefined";
CHECK(a.dtype().is_bool());
CHECK(b.dtype().is_bool());
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
ObjectPtr<OrNode> node = make_object<OrNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
node->b = std::move(b);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Or").set_body_typed([](PrimExpr a, PrimExpr b) { return Or(a, b); });
TVM_REGISTER_NODE_TYPE(OrNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<OrNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const OrNode*>(node.get());
p->stream << '(';
p->Print(op->a);
p->stream << " || ";
p->Print(op->b);
p->stream << ')';
});
// Not
Not::Not(PrimExpr a) {
CHECK(a.defined()) << "ValueError: a is undefined";
CHECK(a.dtype().is_bool());
ObjectPtr<NotNode> node = make_object<NotNode>();
node->dtype = DataType::Bool(a.dtype().lanes());
node->a = std::move(a);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Not").set_body_typed([](PrimExpr a) { return Not(a); });
TVM_REGISTER_NODE_TYPE(NotNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<NotNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const NotNode*>(node.get());
p->stream << '!';
p->Print(op->a);
});
// Select
Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
CHECK(condition.defined()) << "ValueError: condition is undefined";
CHECK(true_value.defined()) << "ValueError: true_value is undefined";
CHECK(false_value.defined()) << "ValueError: true_value is undefined";
CHECK(condition.dtype().is_bool());
CHECK(condition.dtype().lanes() == true_value.dtype().lanes() || condition.dtype().lanes() == 1);
CHECK(false_value.dtype() == true_value.dtype()) << "TypeError: mismatched types";
ObjectPtr<SelectNode> node = make_object<SelectNode>();
node->dtype = true_value.dtype();
node->condition = std::move(condition);
node->true_value = std::move(true_value);
node->false_value = std::move(false_value);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Select")
.set_body_typed([](PrimExpr condition, PrimExpr true_value, PrimExpr false_value) {
return Select(condition, true_value, false_value);
});
TVM_REGISTER_NODE_TYPE(SelectNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SelectNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SelectNode*>(node.get());
p->stream << "select(";
p->Print(op->condition);
p->stream << ", ";
p->Print(op->true_value);
p->stream << ", ";
p->Print(op->false_value);
p->stream << ")";
});
// Load
Load::Load(DataType dtype, Var buffer_var, PrimExpr index, PrimExpr predicate) {
CHECK(buffer_var.defined());
CHECK(predicate.defined());
CHECK(index.defined());
CHECK_EQ(dtype.lanes(), index.dtype().lanes());
CHECK_EQ(dtype.lanes(), predicate.dtype().lanes());
ObjectPtr<LoadNode> node = make_object<LoadNode>();
node->dtype = dtype;
node->buffer_var = std::move(buffer_var);
node->index = std::move(index);
node->predicate = std::move(predicate);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Load").set_body([](TVMArgs args, TVMRetValue* ret) {
DataType t = args[0];
if (args.size() == 3) {
*ret = Load(t, args[1], args[2], const_true(t.lanes()));
} else {
*ret = Load(t, args[1], args[2], args[3]);
}
});
TVM_REGISTER_NODE_TYPE(LoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LoadNode*>(node.get());
p->stream << op->buffer_var << "[";
p->Print(op->index);
p->stream << "]";
if (!is_one(op->predicate)) {
p->stream << " if ";
p->Print(op->predicate);
}
});
// Ramp
Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes) {
CHECK(base.defined());
CHECK(stride.defined());
CHECK(base.dtype().is_scalar());
CHECK(stride.dtype().is_scalar());
CHECK_GT(lanes, 1);
CHECK_EQ(stride.dtype(), base.dtype());
ObjectPtr<RampNode> node = make_object<RampNode>();
node->dtype = base.dtype().with_lanes(lanes);
node->base = base;
node->stride = stride;
node->lanes = lanes;
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Ramp").set_body_typed([](PrimExpr base, PrimExpr stride, int lanes) {
return Ramp(base, stride, lanes);
});
TVM_REGISTER_NODE_TYPE(RampNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RampNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RampNode*>(node.get());
p->stream << "ramp(";
p->Print(op->base);
p->stream << ", ";
p->Print(op->stride);
p->stream << ", " << op->lanes << ")";
});
// Broadcast
Broadcast::Broadcast(PrimExpr value, int lanes) {
CHECK(value.defined());
CHECK(value.dtype().is_scalar());
CHECK_GT(lanes, 1);
ObjectPtr<BroadcastNode> node = make_object<BroadcastNode>();
node->dtype = value.dtype().with_lanes(lanes);
node->value = std::move(value);
node->lanes = lanes;
data_ = node;
}
TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int lanes) {
return Broadcast(value, lanes);
});
TVM_REGISTER_NODE_TYPE(BroadcastNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BroadcastNode*>(node.get());
p->stream << "x" << op->lanes << "(";
p->Print(op->value);
p->stream << ")";
});
// Let
Let::Let(Var var, PrimExpr value, PrimExpr body) {
CHECK(value.defined());
CHECK(body.defined());
CHECK_EQ(value.dtype(), var.dtype());
ObjectPtr<LetNode> node = make_object<LetNode>();
node->dtype = body.dtype();
node->var = std::move(var);
node->value = std::move(value);
node->body = std::move(body);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Let").set_body_typed([](Var var, PrimExpr value, PrimExpr body) {
return Let(var, value, body);
});
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetNode*>(node.get());
p->stream << "(let " << op->var << " = ";
p->Print(op->value);
p->stream << " in ";
p->Print(op->body);
p->stream << ")";
});
// Call
Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args) {
for (size_t i = 0; i < args.size(); ++i) {
CHECK(args[i].defined());
}
ObjectPtr<CallNode> node = make_object<CallNode>();
node->dtype = dtype;
node->op = std::move(op);
node->args = std::move(args);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Call")
.set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args) {
Array<PrimExpr> prim_expr_args;
for (const auto& it : args) {
CHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>());
if (const auto* str = it.as<runtime::StringObj>()) {
prim_expr_args.push_back(StringImm(str->data));
} else {
prim_expr_args.push_back(Downcast<PrimExpr>(it));
}
}
return Call(type, op, prim_expr_args);
});
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CallNode*>(node.get());
if (auto* ptr_op = op->op.as<OpNode>()) {
p->stream << ptr_op->name << "(";
} else {
auto* ptr_gvar = op->op.as<GlobalVarNode>();
CHECK(ptr_gvar != nullptr);
p->stream << "@" << ptr_gvar->name_hint << "(";
}
for (size_t i = 0; i < op->args.size(); ++i) {
p->Print(op->args[i]);
if (i < op->args.size() - 1) {
p->stream << ", ";
}
}
p->stream << ")";
});
// Shuffle
Shuffle::Shuffle(Array<PrimExpr> vectors, Array<PrimExpr> indices) {
CHECK_NE(vectors.size(), 0U);
CHECK_NE(indices.size(), 0U);
DataType base_type = vectors[0].dtype().element_of();
int total_lanes = 0;
for (PrimExpr val : vectors) {
CHECK(val.dtype().element_of() == base_type);
total_lanes += val.dtype().lanes();
}
CHECK_LE(indices.size(), static_cast<size_t>(total_lanes));
ObjectPtr<ShuffleNode> node = make_object<ShuffleNode>();
node->dtype = base_type.with_lanes(static_cast<int>(indices.size()));
node->vectors = std::move(vectors);
node->indices = std::move(indices);
data_ = node;
}
PrimExpr Shuffle::Concat(Array<PrimExpr> vectors) {
CHECK_NE(vectors.size(), 0);
if (vectors.size() == 1) {
return vectors[0];
}
Array<PrimExpr> indices;
int index = 0;
for (const PrimExpr& e : vectors) {
for (int i = 0; i < e.dtype().lanes(); ++i) {
indices.push_back(IntImm(DataType::Int(32), index++));
}
}
return Shuffle(vectors, indices);
}
PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index) {
return Shuffle({vector}, {Integer(index)});
}
TVM_REGISTER_GLOBAL("tir.Shuffle")
.set_body_typed([](Array<PrimExpr> vectors, Array<PrimExpr> indices) {
return Shuffle(vectors, indices);
});
TVM_REGISTER_NODE_TYPE(ShuffleNode);
template <typename T>
void PrintList(const Array<T>& exprs, ReprPrinter* p) {
for (size_t i = 0; i < exprs.size(); ++i) {
p->Print(exprs[i]);
if (i < exprs.size() - 1) {
p->stream << ", ";
}
}
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ShuffleNode*>(node.get());
p->stream << "shuffle(";
PrintList(op->vectors, p);
p->stream << ", ";
PrintList(op->indices, p);
p->stream << ")";
});
// CommReducer
CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element) {
auto node = make_object<CommReducerNode>();
node->lhs = lhs;
node->rhs = rhs;
node->result = result;
node->identity_element = identity_element;
data_ = std::move(node);
}
Array<PrimExpr> CommReducerNode::operator()(Array<PrimExpr> a, Array<PrimExpr> b) const {
CHECK_EQ(a.size(), b.size());
CHECK_EQ(lhs.size(), a.size());
CHECK_EQ(rhs.size(), b.size());
Map<Var, PrimExpr> value_map;
for (size_t i = 0; i < a.size(); ++i) {
value_map.Set(lhs[i], a[i]);
value_map.Set(rhs[i], b[i]);
}
auto ret = this->result;
ret.MutateByApply([&value_map](const PrimExpr& e) { return Substitute(e, value_map); });
return ret;
}
TVM_REGISTER_GLOBAL("tir.CommReducer")
.set_body_typed([](Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
Array<PrimExpr> identity_element) {
return CommReducer(lhs, rhs, result, identity_element);
});
TVM_REGISTER_GLOBAL("tir.CommReducerCombine")
.set_body_method<tir::CommReducer>(&tir::CommReducerNode::operator());
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CommReducerNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const CommReducerNode*>(node.get());
p->stream << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs
<< ", rhs=" << op->rhs << ", identity_element=" << op->identity_element << ")";
});
// Reduce
Reduce::Reduce(CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
PrimExpr condition, int value_index, Array<PrimExpr> init) {
for (size_t i = 0; i < axis.size(); ++i) {
CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis";
}
if (!condition.defined()) {
condition = const_true();
}
auto n = make_object<ReduceNode>();
CHECK(source.defined());
for (size_t i = 0; i < axis.size(); ++i) {
CHECK(axis[i].defined());
}
if (!init.empty()) {
CHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs";
for (size_t i = 0; i < init.size(); i++) {
CHECK(init[i]->IsInstance<ProducerLoadNode>() || init[i]->IsInstance<IntImmNode>() ||
init[i]->IsInstance<FloatImmNode>())
<< "init can only be a IntImm, FloatImm or ProducerLoad";
}
}
n->dtype = source[value_index].dtype();
n->combiner = std::move(combiner);
n->source = std::move(source);
n->init = std::move(init);
n->axis = std::move(axis);
n->condition = condition;
n->value_index = value_index;
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("tir.Reduce")
.set_body_typed([](CommReducer combiner, Array<PrimExpr> source, Array<IterVar> axis,
PrimExpr condition, int value_index, Array<PrimExpr> init) {
return Reduce(combiner, source, axis, condition, value_index, init);
});
TVM_REGISTER_NODE_TYPE(ReduceNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ReduceNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ReduceNode*>(node.get());
p->stream << "reduce(combiner=" << op->combiner;
p->stream << ", source=" << op->source;
p->stream << ", init=" << op->init;
p->stream << ", axis=" << op->axis;
p->stream << ", where=" << op->condition;
p->stream << ", value_index=" << op->value_index;
p->stream << ")";
});
// Any
Any::Any() {
auto n = make_object<AnyNode>();
n->dtype = DataType::Int(32);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("tir.Any").set_body_typed([]() { return Any(); });
TVM_REGISTER_NODE_TYPE(AnyNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AnyNode>([](const ObjectRef& node, ReprPrinter* p) { p->stream << "?"; });
// BufferLoad
BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices) {
ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
node->dtype = buffer->dtype;
node->buffer = std::move(buffer);
node->indices = std::move(indices);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.BufferLoad").set_body_typed([](Buffer buffer, Array<PrimExpr> indices) {
return BufferLoad(buffer, indices);
});
TVM_REGISTER_NODE_TYPE(BufferLoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const BufferLoadNode*>(node.get());
p->stream << op->buffer->name << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) {
p->stream << ", ";
}
}
p->stream << "]";
});
// ProducerLoad
ProducerLoad::ProducerLoad(DataProducer producer, Array<PrimExpr> indices) {
ObjectPtr<ProducerLoadNode> node = make_object<ProducerLoadNode>();
node->dtype = producer->GetDataType();
node->producer = std::move(producer);
node->indices = std::move(indices);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.ProducerLoad")
.set_body_typed([](DataProducer producer, Array<PrimExpr> indices) {
return ProducerLoad(producer, indices);
});
TVM_REGISTER_NODE_TYPE(ProducerLoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ProducerLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ProducerLoadNode*>(node.get());
p->stream << op->producer->GetNameHint() << "[";
for (size_t i = 0; i < op->indices.size(); ++i) {
p->Print(op->indices[i]);
if (i < op->indices.size() - 1) {
p->stream << ", ";
}
}
p->stream << "]";
});
} // namespace tir
} // namespace tvm