blob: 0c8cd3c123714c7b51d3132ca9d366a8aa3f79fa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/relax/distributed/struct_info.h>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include "./utils.h"
namespace tvm {
namespace script {
namespace printer {
// Docsifier for relax::PrimValue: wraps the docsified `value` field in a call to
// the Relax-dialect `prim_value` callee.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::PrimValue>(  //
        "", [](relax::PrimValue n, AccessPath n_p, IRDocsifier d) -> Doc {
          // TODO(@junrushao): float numbers
          ExprDoc callee = Relax(d, "prim_value");
          ExprDoc value_doc = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
          return callee->Call({value_doc});
        });
// Docsifier for relax::StringImm: passes the string `value` as a string literal
// to the Relax-dialect `str` callee.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::StringImm>(  //
        "", [](relax::StringImm n, AccessPath n_p, IRDocsifier d) -> Doc {
          ExprDoc callee = Relax(d, "str");
          LiteralDoc text_doc = LiteralDoc::Str(n->value, n_p->Attr("value"));
          return callee->Call({text_doc});
        });
// Docsifier for relax::DataTypeImm: passes the data type `value` as a dtype
// literal to the Relax-dialect `dtype` callee.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::DataTypeImm>(  //
        "", [](relax::DataTypeImm n, AccessPath n_p, IRDocsifier d) -> Doc {
          ExprDoc callee = Relax(d, "dtype");
          LiteralDoc dtype_doc = LiteralDoc::DataType(n->value, n_p->Attr("value"));
          return callee->Call({dtype_doc});
        });
// Docsifier for relax::Tuple.
//  - An empty tuple becomes a zero-argument call to the Relax-dialect `tuple` callee.
//  - A non-empty tuple becomes a TupleDoc holding each field docsified in order.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::Tuple>(  //
        "", [](relax::Tuple n, AccessPath n_p, IRDocsifier d) -> Doc {
          // TODO(@junrushao): revisit tuple printing
          if (n->fields.empty()) {
            return Relax(d, "tuple")->Call({});
          }
          ffi::Array<ExprDoc> elems;
          AccessPath elems_path = n_p->Attr("fields");
          const int num_fields = n->fields.size();
          for (int idx = 0; idx < num_fields; ++idx) {
            // Each field is docsified with the access path of its array slot.
            ExprDoc elem = d->AsDoc<ExprDoc>(n->fields[idx], elems_path->ArrayItem(idx));
            elems.push_back(elem);
          }
          return TupleDoc(elems);
        });
// Docsifier for relax::TupleGetItem: indexes the docsified `tuple` expression
// with the integer literal `index`.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::TupleGetItem>(  //
        "", [](relax::TupleGetItem n, AccessPath n_p, IRDocsifier d) -> Doc {
          ExprDoc index_doc = LiteralDoc::Int(n->index, n_p->Attr("index"));
          ExprDoc tuple_doc = d->AsDoc<ExprDoc>(n->tuple, n_p->Attr("tuple"));
          return tuple_doc[{index_doc}];
        });
// Docsifier for relax::ShapeExpr: prints every entry of `values` through
// PrintShapeVar and passes the resulting list to the Relax-dialect `shape` callee.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::ShapeExpr>(  //
        "", [](relax::ShapeExpr n, AccessPath n_p, IRDocsifier d) -> Doc {
          ffi::Array<ExprDoc> dims;
          AccessPath dims_path = n_p->Attr("values");
          const int num_dims = n->values.size();
          for (int idx = 0; idx < num_dims; ++idx) {
            dims.push_back(PrintShapeVar(n->values[idx], dims_path->ArrayItem(idx), d));
          }
          ExprDoc callee = Relax(d, "shape");
          return callee->Call({ListDoc(dims)});
        });
/*!
 * \brief Copy a single value of type T out of a raw byte buffer.
 *
 * std::memcpy is used rather than dereferencing a reinterpret_cast'ed pointer so
 * that the read stays well-defined even if `src` is not suitably aligned for T.
 *
 * \param src Pointer to at least sizeof(T) readable bytes.
 * \return The value stored at `src`, interpreted as T.
 */
template <typename T>
static T ReadScalarAs(const void* src) {
  T out;
  std::memcpy(&out, src, sizeof(T));
  return out;
}

/*!
 * \brief Try to render a tensor as an inline literal.
 *
 * Only 0-dimensional tensors that live on a CPU device are considered, and only
 * for dtypes int8/int16/int32/int64, float16/float32/float64 and bool.
 *
 * \param n The tensor to inspect.
 * \param p The access path attached to the produced literal.
 * \return A literal doc holding the scalar value, or std::nullopt when the
 *         tensor is not a CPU scalar or its dtype is not one of the above.
 */
ffi::Optional<ExprDoc> SpecialScalar(const runtime::Tensor& n, const AccessPath& p) {
  DataType dtype = n.DataType();
  if (n->ndim != 0 || n->device.device_type != kDLCPU) {
    return std::nullopt;
  }
  // Per DLPack, the first element lives `byte_offset` bytes past `data`;
  // reading from `data` directly would return the wrong element for a tensor
  // that is a view into a larger buffer.
  const void* data = static_cast<const char*>(n->data) + n->byte_offset;
  if (dtype == DataType::Int(8)) {
    return LiteralDoc::Int(ReadScalarAs<int8_t>(data), p);
  } else if (dtype == DataType::Int(16)) {
    return LiteralDoc::Int(ReadScalarAs<int16_t>(data), p);
  } else if (dtype == DataType::Int(32)) {
    return LiteralDoc::Int(ReadScalarAs<int32_t>(data), p);
  } else if (dtype == DataType::Int(64)) {
    return LiteralDoc::Int(ReadScalarAs<int64_t>(data), p);
  } else if (dtype == DataType::Float(16)) {
    // From IEEE-754 float16 definition
    //
    // Ref: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
    const uint16_t bits = ReadScalarAs<uint16_t>(data);
    const uint16_t sign_bit = (bits & 0b1000'0000'0000'0000) >> 15;
    const uint16_t exponent = (bits & 0b0111'1100'0000'0000) >> 10;
    const uint16_t fraction = (bits & 0b0000'0011'1111'1111) >> 0;
    double value;
    if (exponent == 0b1'1111 && fraction == 0) {
      value = std::numeric_limits<double>::infinity();
    } else if (exponent == 0b1'1111) {
      value = std::numeric_limits<double>::quiet_NaN();
    } else if (exponent == 0 && fraction == 0) {
      value = 0.0;
    } else if (exponent == 0) {
      // Subnormal: fraction * 2^-24. std::ldexp scales by a power of two exactly.
      value = std::ldexp(static_cast<double>(fraction), -24);
    } else {
      // Normal: (1024 + fraction) * 2^(exponent - 25), i.e. the implicit leading
      // one is restored by setting bit 10 of the significand.
      value = std::ldexp(static_cast<double>(fraction | (1 << 10)),
                         static_cast<int>(exponent) - 25);
    }
    if (sign_bit) {
      value *= -1.0;
    }
    return LiteralDoc::Float(value, p);
  } else if (dtype == DataType::Float(32)) {
    return LiteralDoc::Float(ReadScalarAs<float>(data), p);
  } else if (dtype == DataType::Float(64)) {
    return LiteralDoc::Float(ReadScalarAs<double>(data), p);
  } else if (dtype == DataType::Bool()) {
    return LiteralDoc::Boolean(ReadScalarAs<uint8_t>(data), p);
  } else {
    return std::nullopt;
  }
}
// Docsifier for relax::Constant.
//  - If SpecialScalar cannot turn the tensor into a literal, the constant is
//    registered through IRDocsifier::AddMetadata.
//  - If the constant's struct info is a DTensorStructInfoNode, the literal and the
//    docsified struct info are passed to the Relax-dialect `dist.const` callee.
//  - Otherwise the literal and the tensor dtype are passed to the `const` callee.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<relax::Constant>(  //
        "", [](relax::Constant n, AccessPath n_p, IRDocsifier d) -> Doc {
          ffi::Optional<ExprDoc> scalar = SpecialScalar(n->data, n_p->Attr("data"));
          if (!scalar) {
            return d->AddMetadata(n);
          }
          if (!n->struct_info_.as<relax::distributed::DTensorStructInfoNode>()) {
            ExprDoc callee = Relax(d, "const");
            LiteralDoc dtype_doc =
                LiteralDoc::DataType(n->data.DataType(), n_p->Attr("data")->Attr("dtype"));
            return callee->Call({scalar.value(), dtype_doc});
          }
          ExprDoc sinfo_doc = d->AsDoc<ExprDoc>(n->struct_info_, n_p->Attr("struct_info_"));
          ExprDoc dist_callee = Relax(d, "dist.const");
          return dist_callee->Call({scalar.value(), sinfo_doc});
        });
/*!
 * \brief Docsify a Relax variable, defining it on first use.
 *
 * If the variable is not yet defined in the docsifier, it is defined in the
 * innermost frame and an annotation-only assignment carrying its docsified
 * struct info is appended to that frame's statements.
 *
 * \param n The variable to print.
 * \param p The access path of the variable.
 * \param d The docsifier.
 * \return The doc the docsifier associates with the variable.
 */
Doc PrintRelaxVar(relax::Var n, AccessPath p, IRDocsifier d) {
  if (d->IsVarDefined(n)) {
    return d->GetVarDoc(n).value();
  }
  ExprDoc annotation = d->AsDoc<ExprDoc>(n->struct_info_, p->Attr("struct_info_"));
  Frame frame = d->frames.back();
  ExprDoc lhs = DefineVar(n, frame, d);
  // No right-hand side: only the annotation is attached to the new variable.
  frame->stmts.push_back(AssignDoc(lhs, std::nullopt, annotation));
  return d->GetVarDoc(n).value();
}
// relax::Var and relax::DataflowVar share one docsifier: both dispatch to PrintRelaxVar.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<relax::Var>("", PrintRelaxVar);
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<relax::DataflowVar>("", PrintRelaxVar);
// Associate ReprPrintRelax with every Relax expression node type that has a
// docsifier registered in this file.
// NOTE(review): TVM_SCRIPT_REPR is defined outside this file; presumably it hooks the
// given function up as the node's repr printer -- confirm against the macro definition.
TVM_SCRIPT_REPR(relax::PrimValueNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::StringImmNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::DataTypeImmNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::TupleNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::TupleGetItemNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::ShapeExprNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::VarNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::DataflowVarNode, ReprPrintRelax);
TVM_SCRIPT_REPR(relax::ConstantNode, ReprPrintRelax);
} // namespace printer
} // namespace script
} // namespace tvm