| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file codegen_hybrid.cc |
| */ |
| #include "codegen_hybrid.h" |
| |
| #include <tvm/runtime/registry.h> |
| #include <tvm/tir/builtin.h> |
| |
| #include <cctype> |
| #include <iomanip> |
| |
| namespace tvm { |
| namespace contrib { |
| |
| using runtime::TVMArgs; |
| using runtime::TVMRetValue; |
| |
| using namespace tir; |
| |
/*!
 * \brief Return a copy of \p s in which every '.' is replaced by '_'.
 * \param s The input text; taken by value and edited in place.
 * \return The rewritten string.
 */
std::string dot_to_underscore(std::string s) {
  for (std::string::size_type pos = 0; pos < s.size(); ++pos) {
    if (s[pos] != '.') continue;
    s[pos] = '_';
  }
  return s;
}
| |
| std::string CodeGenHybrid::GetUniqueName(std::string prefix) { |
| prefix = dot_to_underscore(prefix); |
| auto it = ids_allocated_.find(prefix); |
| if (it != ids_allocated_.end()) { |
| while (true) { |
| std::ostringstream os; |
| os << prefix << (++it->second); |
| std::string name = os.str(); |
| if (ids_allocated_.count(name) == 0) { |
| prefix = name; |
| break; |
| } |
| } |
| } |
| ids_allocated_[prefix] = 0; |
| return prefix; |
| } |
| |
| std::string CodeGenHybrid::Finish() { return stream.str(); } |
| |
| void CodeGenHybrid::PrintType(DataType t, std::ostream& os) { |
| if (t.is_float()) { |
| os << "float"; |
| CHECK(t.bits() == 16 || t.bits() == 32 || t.bits() == 64); |
| } else if (t.is_int()) { |
| os << "int"; |
| CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); |
| } else { |
| CHECK(t.is_uint()) << "Unsupported type " << t; |
| os << "uint"; |
| CHECK(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64); |
| } |
| os << t.bits(); |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) |
| os << op->value; |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*) |
| PrintType(op->dtype, os); |
| os << "(" << std::setprecision(20) << op->value << ")"; |
| } |
| void CodeGenHybrid::VisitExpr_(const StringImmNode* op, std::ostream& os) { // NOLINT(*) |
| os << "'" << op->value << "'"; |
| } |
| |
| template <typename T> |
| inline void PrintBinaryExpr(const T* op, const char* opstr, |
| std::ostream& os, // NOLINT(*) |
| CodeGenHybrid* p) { |
| CHECK(op->dtype.lanes() == 1) << "vec bin op not implemented"; |
| if (isalpha(opstr[0])) { |
| os << opstr << '('; |
| p->PrintExpr(op->a, os); |
| os << ", "; |
| p->PrintExpr(op->b, os); |
| os << ')'; |
| } else { |
| os << '('; |
| p->PrintExpr(op->a, os); |
| if (!strcmp(opstr, "&&")) opstr = "and"; |
| if (!strcmp(opstr, "||")) opstr = "or"; |
| os << ' ' << opstr << ' '; |
| p->PrintExpr(op->b, os); |
| os << ')'; |
| } |
| } |
| |
| inline void PrintBinaryIntrinsitc(const CallNode* op, const char* opstr, |
| std::ostream& os, // NOLINT(*) |
| CodeGenHybrid* p) { |
| CHECK(op->dtype.lanes() == 1) << "vec bin intrin not implemented"; |
| CHECK_EQ(op->args.size(), 2U); |
| os << '('; |
| p->PrintExpr(op->args[0], os); |
| os << opstr; |
| p->PrintExpr(op->args[1], os); |
| os << ')'; |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const CastNode* op, std::ostream& os) { // NOLINT(*) |
| if (op->dtype == op->value.dtype()) { |
| PrintExpr(op->value, stream); |
| } else { |
| PrintType(op->dtype, os); |
| os << "("; |
| PrintExpr(op->value, os); |
| os << ")"; |
| } |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const VarNode* op, std::ostream& os) { // NOLINT(*) |
| os << GetVarID(op); |
| } |
| void CodeGenHybrid::VisitExpr_(const AddNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "+", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const SubNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "-", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const MulNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "*", os, this); |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const DivNode* op, std::ostream& os) { // NOLINT(*) |
| if (op->dtype.is_int()) |
| PrintBinaryExpr(op, "//", os, this); |
| else |
| PrintBinaryExpr(op, "/", os, this); |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const FloorDivNode* op, std::ostream& os) { // NOLINT(*) |
| if (op->dtype.is_int()) |
| PrintBinaryExpr(op, "//", os, this); |
| else |
| PrintBinaryExpr(op, "/", os, this); |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const ModNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "%", os, this); |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const FloorModNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "%", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const MinNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "min", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "max", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const EQNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "==", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const NENode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "!=", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const LTNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "<", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const LENode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "<=", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const GTNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, ">", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const GENode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, ">=", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const AndNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "&&", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const OrNode* op, std::ostream& os) { // NOLINT(*) |
| PrintBinaryExpr(op, "||", os, this); |
| } |
| void CodeGenHybrid::VisitExpr_(const NotNode* op, std::ostream& os) { // NOLINT(*) |
| os << "not "; |
| PrintExpr(op->a, os); |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const ProducerLoadNode* op, std::ostream& os) { // NOLINT(*) |
| auto tensor = Downcast<Tensor>(op->producer); |
| |
| os << GetTensorID(tensor); |
| os << "["; |
| for (size_t i = 0; i < op->indices.size(); ++i) { |
| if (i) os << ", "; |
| std::stringstream idx; |
| PrintExpr(op->indices[i], idx); |
| os << idx.str(); |
| } |
| os << "]"; |
| } |
| void CodeGenHybrid::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) |
| if (op->op.same_as(builtin::bitwise_and())) { |
| PrintBinaryIntrinsitc(op, "&", os, this); |
| } else if (op->op.same_as(builtin::bitwise_xor())) { |
| PrintBinaryIntrinsitc(op, "^", os, this); |
| } else if (op->op.same_as(builtin::bitwise_or())) { |
| PrintBinaryIntrinsitc(op, "|", os, this); |
| } else if (op->op.same_as(builtin::shift_left())) { |
| PrintBinaryIntrinsitc(op, "<<", os, this); |
| } else if (op->op.same_as(builtin::shift_right())) { |
| PrintBinaryIntrinsitc(op, ">>", os, this); |
| } else if (op->op.same_as(builtin::bitwise_not())) { |
| CHECK_EQ(op->args.size(), 1U); |
| os << "(~"; |
| PrintExpr(op->args[0], os); |
| os << ')'; |
| } else if (op->op.same_as(builtin::if_then_else())) { |
| PrintExpr(op->args[1], os); |
| os << " if "; |
| PrintExpr(op->args[0], os); |
| os << " else "; |
| PrintExpr(op->args[2], os); |
| } else if (op->op.same_as(builtin::call_pure_extern()) || |
| op->op.same_as(builtin::call_extern())) { |
| StringImm fname = Downcast<StringImm>(op->args[0]); |
| os << fname << "("; |
| for (size_t i = 1; i < op->args.size(); i++) { |
| PrintExpr(op->args[i], os); |
| if (i < op->args.size() - 1) { |
| os << ", "; |
| } |
| } |
| os << ")"; |
| } else { |
| auto* ptr_op = op->op.as<OpNode>(); |
| CHECK(ptr_op != nullptr); |
| std::string name = ptr_op->name; |
| CHECK_EQ(name.compare(0, 4, "tir."), 0); |
| os << name.substr(4) << "("; |
| for (size_t i = 0; i < op->args.size(); i++) { |
| PrintExpr(op->args[i], os); |
| if (i < op->args.size() - 1) { |
| os << ", "; |
| } |
| } |
| os << ")"; |
| } |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) |
| LOG(FATAL) << "Phase 0 has no Load(s)!"; |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "Phase 0 has no Store(s)!"; } |
| |
| void CodeGenHybrid::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) |
| LOG(FATAL) << "Phase 0 has no Let(s)!"; |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const AllocateNode* op) { |
| LOG(FATAL) << "Phase 0 has no Allocate(s)!"; |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*) |
| LOG(FATAL) << "Ramp to be supported yet"; |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NOLINT(*) |
| LOG(FATAL) << "Broadcast: not supported "; |
| } |
| |
| void CodeGenHybrid::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(*) |
| PrintExpr(op->true_value, os); |
| os << " if "; |
| PrintExpr(op->condition, os); |
| os << " else "; |
| PrintExpr(op->false_value, os); |
| os << "\n"; |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const LetStmtNode* op) { |
| std::string value = PrintExpr(op->value); |
| stream << GetVarID(op->var.get()) << " = " << value << ";\n"; |
| PrintStmt(op->body); |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { |
| if (op->attr_key == tir::attr::thread_extent) { |
| auto iter_var = op->node.as<IterVarNode>(); |
| CHECK(iter_var); |
| binds_[iter_var->var.get()] = dot_to_underscore(iter_var->var->name_hint); |
| PrintIndent(); |
| stream << "for " << binds_[iter_var->var.get()] << " in bind('" << iter_var->var->name_hint |
| << "', "; |
| PrintExpr(op->value, stream); |
| stream << "):\n"; |
| indent_ += tab_; |
| PrintStmt(op->body); |
| indent_ -= tab_; |
| } else if (op->attr_key == tir::attr::realize_scope) { |
| auto v = Downcast<Operation>(op->node); |
| alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value; |
| PrintStmt(op->body); |
| } else { |
| // For now we ignore the unsupported AttrStmt |
| PrintStmt(op->body); |
| } |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { |
| auto tensor = Downcast<Tensor>(op->producer); |
| CHECK(alloc_storage_scope_.count(tensor->op)); |
| if (!alloc_storage_scope_[tensor->op].empty()) { |
| PrintIndent(); |
| stream << GetTensorID(tensor) << " = allocate(("; |
| for (size_t i = 0; i < op->bounds.size(); ++i) { |
| if (i) stream << ", "; |
| stream << PrintExpr(op->bounds[i]->extent); |
| } |
| if (op->bounds.size() == 1) stream << ", "; |
| stream << "), '"; |
| PrintType(tensor->dtype, stream); |
| stream << "', '"; |
| stream << alloc_storage_scope_[tensor->op] << "')\n"; |
| } |
| PrintStmt(op->body); |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const AssertStmtNode* op) { |
| PrintIndent(); |
| stream << "assert "; |
| PrintExpr(op->condition, stream); |
| stream << ", "; |
| PrintExpr(op->message, stream); |
| stream << "\n"; |
| PrintStmt(op->body); |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const ProducerStoreNode* op) { |
| auto tensor = Downcast<Tensor>(op->producer); |
| PrintIndent(); |
| stream << GetTensorID(tensor); |
| stream << "["; |
| for (size_t i = 0; i < op->indices.size(); ++i) { |
| if (i) stream << ", "; |
| PrintExpr(op->indices[i], stream); |
| } |
| stream << "] = "; |
| PrintExpr(op->value, stream); |
| stream << "\n"; |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const ForNode* op) { |
| std::string extent = PrintExpr(op->extent); |
| PrintIndent(); |
| std::string vid = GetVarID(op->loop_var.get()); |
| stream << "for " << vid << " in " |
| << "range(" << extent << "):\n"; |
| indent_ += tab_; |
| PrintStmt(op->body); |
| indent_ -= tab_; |
| } |
| |
| bool is_noop(const Stmt& stmt) { |
| if (!stmt.defined()) return true; |
| if (auto eval = stmt.as<EvaluateNode>()) return is_const_int(eval->value); |
| return false; |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) { |
| std::string cond = PrintExpr(op->condition); |
| PrintIndent(); |
| stream << "if " << cond << ":\n"; |
| indent_ += tab_; |
| PrintStmt(op->then_case); |
| indent_ -= tab_; |
| |
| if (!is_noop(op->else_case)) { |
| PrintIndent(); |
| stream << "else:\n"; |
| indent_ += tab_; |
| PrintStmt(op->else_case); |
| indent_ -= tab_; |
| } |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) { |
| for (Stmt stmt : op->seq) { |
| PrintStmt(stmt); |
| } |
| } |
| |
| void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { |
| if (is_const_int(op->value)) return; |
| std::string str = PrintExpr(op->value); |
| if (!str.empty()) stream << str << "\n"; |
| } |
| |
| void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } |
| |
| std::string CodeGenHybrid::GetVarID(const VarNode* v) { |
| if (binds_.count(v)) return binds_[v]; |
| auto key = std::make_pair(static_cast<const Object*>(v), 0); |
| if (id_map_.count(key)) { |
| return id_map_[key]; |
| } |
| return id_map_[key] = GetUniqueName(v->name_hint); |
| } |
| |
| std::string CodeGenHybrid::GetTensorID(const Tensor& tensor) { |
| auto key = std::make_pair(tensor->op.get(), tensor->value_index); |
| if (id_map_.count(key)) { |
| return id_map_[key]; |
| } |
| std::string name_hint = tensor->op->name; |
| if (tensor->op->num_outputs() > 1) { |
| name_hint += "_v" + std::to_string(tensor->value_index); |
| } |
| return id_map_[key] = GetUniqueName(name_hint); |
| } |
| |
| void CodeGenHybrid::ReserveKeywords() { |
| GetUniqueName("def"); |
| GetUniqueName("for"); |
| GetUniqueName("in"); |
| GetUniqueName("range"); |
| GetUniqueName("True"); |
| GetUniqueName("False"); |
| GetUniqueName("unroll"); |
| GetUniqueName("const_range"); |
| GetUniqueName("parallel"); |
| GetUniqueName("vectorize"); |
| GetUniqueName("bind"); |
| GetUniqueName("threadIdx.x"); |
| GetUniqueName("threadIdx.y"); |
| GetUniqueName("threadIdx.z"); |
| GetUniqueName("blockIdx.x"); |
| GetUniqueName("blockIdx.y"); |
| GetUniqueName("blockIdx.z"); |
| GetUniqueName("vthread"); |
| GetUniqueName("allocate"); |
| GetUniqueName("output_tensor"); |
| GetUniqueName("sqrt"); |
| GetUniqueName("log"); |
| GetUniqueName("tanh"); |
| GetUniqueName("power"); |
| GetUniqueName("exp"); |
| GetUniqueName("sigmoid"); |
| GetUniqueName("popcount"); |
| GetUniqueName("likely"); |
| GetUniqueName("int8"); |
| GetUniqueName("int16"); |
| GetUniqueName("int32"); |
| GetUniqueName("int64"); |
| GetUniqueName("uint8"); |
| GetUniqueName("uint16"); |
| GetUniqueName("uint32"); |
| GetUniqueName("uint64"); |
| GetUniqueName("float16"); |
| GetUniqueName("float32"); |
| GetUniqueName("float64"); |
| GetUniqueName("ceil_div"); |
| GetUniqueName("max_num_threads"); |
| } |
| |
| void CodeGenHybrid::DumpStmt(const Stmt& stmt, const Array<ObjectRef>& inputs, |
| const Array<Tensor>& outputs, const std::string& name) { |
| ReserveKeywords(); |
| GetUniqueName(name); |
| |
| stream << "def " << name << "("; |
| for (size_t i = 0; i < inputs.size(); ++i) { |
| if (i) stream << ", "; |
| if (auto tensor = inputs[i].as<TensorNode>()) { |
| stream << GetTensorID(GetRef<Tensor>(tensor)); |
| } else { |
| auto var = inputs[i].as<VarNode>(); |
| CHECK(var) << "Input should either be a tensor or a variable!"; |
| stream << GetVarID(var); |
| } |
| } |
| stream << "):\n"; |
| indent_ += tab_; |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| PrintIndent(); |
| stream << GetTensorID(outputs[i]) << " = output_tensor(("; |
| for (size_t j = 0; j < outputs[i]->shape.size(); ++j) { |
| if (j) stream << ", "; |
| PrintExpr(outputs[i]->shape[j], stream); |
| } |
| if (outputs[i]->shape.size() == 1) stream << ", "; |
| stream << "), '" << outputs[i]->dtype << "')\n"; |
| } |
| PrintStmt(stmt); |
| PrintIndent(); |
| stream << "return "; |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| if (i) stream << ", "; |
| stream << GetTensorID(outputs[i]); |
| } |
| stream << "\n"; |
| } |
| |
| TVM_REGISTER_GLOBAL("hybrid._Dump").set_body([](TVMArgs args, TVMRetValue* rv) { |
| CodeGenHybrid codegen; |
| if (args.size() == 4) |
| codegen.DumpStmt(args[0], args[1], args[2], args[3]); |
| else |
| codegen.DumpStmt(args[0], args[1], args[2]); |
| *rv = codegen.Finish(); |
| }); |
| } // namespace contrib |
| } // namespace tvm |