| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file src/ir/expr.cc |
| * \brief The expression AST nodes for the common IR infra. |
| */ |
| #include <tvm/arith/analyzer.h> |
| #include <tvm/ffi/function.h> |
| #include <tvm/ffi/reflection/registry.h> |
| #include <tvm/ir/expr.h> |
| #include <tvm/ir/function.h> |
| #include <tvm/te/tensor.h> |
| #include <tvm/tir/expr.h> |
| |
| #include "../support/scalars.h" |
| |
| namespace tvm { |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| BaseExprNode::RegisterReflection(); |
| PrimExprNode::RegisterReflection(); |
| RelaxExprNode::RegisterReflection(); |
| BaseFuncNode::RegisterReflection(); |
| GlobalVarNode::RegisterReflection(); |
| IntImmNode::RegisterReflection(); |
| FloatImmNode::RegisterReflection(); |
| RangeNode::RegisterReflection(); |
| } |
| |
| PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} |
| |
| PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} |
| |
| PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringImm(value); } |
| |
| IntImm::IntImm(DataType dtype, int64_t value, Span span) { |
| TVM_FFI_CHECK(dtype.is_scalar(), ValueError) |
| << "IntImm can only take scalar, but " << dtype << " was supplied."; |
| TVM_FFI_CHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool(), ValueError) |
| << "IntImm supports only int or uint or bool type, but " << dtype << " was supplied."; |
| if (dtype.is_uint()) { |
| TVM_FFI_CHECK_GE(value, 0U, ValueError) |
| << "Literal value " << value << " is negative for unsigned integer type " << dtype; |
| if (dtype.bits() < 64) { |
| TVM_FFI_CHECK_LT(value, 1LL << dtype.bits(), ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| } |
| } else if (dtype.bits() == 1 || dtype.is_bool()) { |
| // int(1) |
| TVM_FFI_CHECK(value == 0 || value == 1, ValueError) << value << " exceeds range of " << dtype; |
| } else if (dtype.bits() < 64) { |
| TVM_FFI_CHECK_GE(value, -(1LL << (dtype.bits() - 1)), ValueError) |
| << "Literal value " << value << " exceeds minimum of " << dtype; |
| TVM_FFI_CHECK_LT(value, 1LL << (dtype.bits() - 1), ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| } |
| ObjectPtr<IntImmNode> node = ffi::make_object<IntImmNode>(); |
| node->dtype = dtype; |
| node->value = value; |
| node->span = span; |
| data_ = std::move(node); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("ir.IntImm", [](DataType dtype, int64_t value, Span span) { |
| return IntImm(dtype, value, span); |
| }); |
| } |
| |
| FloatImm::FloatImm(DataType dtype, double value, Span span) { |
| TVM_FFI_CHECK_EQ(dtype.lanes(), 1, ValueError) << "FloatImm can only take scalar."; |
| |
| TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || |
| dtype.is_float4() || dtype.code() >= DataType::kCustomBegin, |
| ValueError) |
| << "FloatImm supports only float, but " << dtype << " was supplied."; |
| |
| // check range for float32 and float16 since they have specified range. |
| if (!std::isinf(value) && !std::isnan(value)) { |
| if (dtype.bits() == 32) { |
| TVM_FFI_CHECK_GE(value, std::numeric_limits<float>::lowest(), ValueError) |
| << "Literal value " << value << " exceeds minimum of " << dtype; |
| TVM_FFI_CHECK_LE(value, std::numeric_limits<float>::max(), ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| } else if (dtype.is_float16()) { |
| TVM_FFI_CHECK_GE(value, -support::kMaxFloat16, ValueError) |
| << "Literal value " << value << " exceeds minimum of " << dtype; |
| TVM_FFI_CHECK_LE(value, support::kMaxFloat16, ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| } else if (dtype.is_bfloat16()) { |
| TVM_FFI_CHECK_GE(value, -support::kMaxBFloat16, ValueError) |
| << "Literal value " << value << " exceeds minimum of " << dtype; |
| TVM_FFI_CHECK_LE(value, support::kMaxBFloat16, ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| } else if (dtype.is_float8_e3m4() || dtype.is_float8_e4m3() || dtype.is_float8_e4m3b11fnuz() || |
| dtype.is_float8_e4m3fn() || dtype.is_float8_e4m3fnuz() || dtype.is_float8_e5m2() || |
| dtype.is_float8_e5m2fnuz() || dtype.is_float8_e8m0fnu()) { |
| double bound = 0.0; |
| bool nonneg = false; |
| |
| switch (dtype.code()) { |
| case DataType::TypeCode::kFloat8_e3m4: |
| bound = support::kMaxE3M4; |
| break; |
| case DataType::TypeCode::kFloat8_e4m3: |
| bound = support::kMaxE4M3; |
| break; |
| case DataType::TypeCode::kFloat8_e4m3b11fnuz: |
| bound = support::kMaxE4M3B11FNUZ; |
| nonneg = true; |
| break; |
| case DataType::TypeCode::kFloat8_e4m3fn: |
| bound = support::kMaxE4M3FN; |
| break; |
| case DataType::TypeCode::kFloat8_e4m3fnuz: |
| bound = support::kMaxE4M3FNUZ; |
| nonneg = true; |
| break; |
| case DataType::TypeCode::kFloat8_e5m2: |
| bound = support::kMaxE5M2; |
| break; |
| case DataType::TypeCode::kFloat8_e5m2fnuz: |
| bound = support::kMaxE5M2FNUZ; |
| nonneg = true; |
| break; |
| case DataType::TypeCode::kFloat8_e8m0fnu: |
| bound = support::kMaxE8M0FNU; |
| nonneg = true; |
| break; |
| default: |
| TVM_FFI_THROW(InternalError) << "Unhandled float8 type: " << dtype; |
| } |
| |
| if (nonneg) { |
| TVM_FFI_CHECK_GE(value, 0, ValueError) |
| << "Literal value " << value << " below zero for unsigned " << dtype; |
| } else { |
| TVM_FFI_CHECK_GE(value, -bound, ValueError) |
| << "Literal value " << value << " below minimum of " << dtype; |
| } |
| TVM_FFI_CHECK_LE(value, bound, ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| |
| } else if (dtype.is_float6_e2m3fn() || dtype.is_float6_e3m2fn()) { |
| double bound = (dtype.code() == DataType::TypeCode::kFloat6_e2m3fn) ? support::kMaxE2M3FN |
| : support::kMaxE3M2FN; |
| TVM_FFI_CHECK_GE(value, -bound, ValueError) |
| << "Literal value " << value << " below minimum of " << dtype; |
| TVM_FFI_CHECK_LE(value, bound, ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| |
| } else if (dtype.is_float4_e2m1fn()) { |
| double bound = support::kMaxE2M1FN; |
| TVM_FFI_CHECK_GE(value, -bound, ValueError) |
| << "Literal value " << value << " below minimum of " << dtype; |
| TVM_FFI_CHECK_LE(value, bound, ValueError) |
| << "Literal value " << value << " exceeds maximum of " << dtype; |
| } |
| } |
| ObjectPtr<FloatImmNode> node = ffi::make_object<FloatImmNode>(); |
| node->dtype = dtype; |
| node->value = value; |
| node->span = span; |
| data_ = std::move(node); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef().def("ir.FloatImm", [](DataType dtype, double value, Span span) { |
| return FloatImm(dtype, value, span); |
| }); |
| } |
| |
| Range::Range(PrimExpr begin, PrimExpr end, Span span) |
| : Range(ffi::make_object<RangeNode>(begin, tir::is_zero(begin) ? end : (end - begin), span)) {} |
| |
| Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { |
| return Range(ffi::make_object<RangeNode>(min, extent, span)); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("ir.Range_from_min_extent", Range::FromMinExtent) |
| .def("ir.Range", [](PrimExpr begin, ffi::Optional<PrimExpr> end, Span span) -> Range { |
| if (end.defined()) { |
| return Range(begin, end.value(), span); |
| } else { |
| return Range(IntImm(begin->dtype, 0), begin, span); |
| } |
| }); |
| } |
| |
| GlobalVar::GlobalVar(ffi::String name_hint, Span span) { |
| ObjectPtr<GlobalVarNode> n = ffi::make_object<GlobalVarNode>(); |
| n->name_hint = std::move(name_hint); |
| n->span = std::move(span); |
| data_ = std::move(n); |
| } |
| |
| TVM_FFI_STATIC_INIT_BLOCK() { |
| namespace refl = tvm::ffi::reflection; |
| refl::GlobalDef() |
| .def("ir.GlobalVar", [](ffi::String name) { return GlobalVar(name); }) |
| .def("ir.DebugPrint", [](ObjectRef ref) { |
| std::stringstream ss; |
| ss << ref; |
| return ss.str(); |
| }); |
| } |
| |
| } // namespace tvm |