| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file tvm/tir/op.h |
| * \brief Common operators defined for Expr. |
| * |
| * \note Most of the operators defined here perform simple constant folding |
| * when the type is int32 or int64 for simplifying the index expressions. |
| */ |
| // Acknowledgement: Most operator APIs originate from Halide. |
| #ifndef TVM_TIR_OP_H_ |
| #define TVM_TIR_OP_H_ |
| |
| #include <tvm/ir/expr.h> |
| #include <tvm/ir/op.h> |
| #include <tvm/ir/type.h> |
| #include <tvm/tir/expr.h> |
| #include <tvm/tir/stmt.h> |
| |
| #include <algorithm> |
| #include <limits> |
| #include <type_traits> |
| |
| namespace tvm { |
| |
| // Most common operators can be overloaded by argument type(PrimExpr). |
| // So we put them under the root namespace. |
| // |
| // We put more developer oriented APIs -- make_const and is_const under tir |
| // as they are more specific to the tir namespace. |
| |
| /*! |
| * \brief Get the type of the expression under the unified type system. |
| * |
| * This function could return a more refined type than |
| * the runtime type provided by expr->dtype |
| * |
| * \param expr The input parameter. |
| * \return The result type. |
| * |
| * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. |
| */ |
| TVM_DLL Type GetType(const PrimExpr& expr); |
| |
| /*! |
| * \brief Get the type corresponding to DataType |
| * \param dtype The data type |
| * \return The result type |
| * |
| * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. |
| */ |
| TVM_DLL Type GetTypeFromRuntimeDataType(const DataType& dtype); |
| |
| /*! |
| * \brief Get the implied DataType for storing values with type during runtime. |
| * |
| * \param type The input type. |
| * \return The result runtime::DataType. |
| * |
| * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. |
| */ |
| TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type); |
| |
| /*! |
| * \brief Return the value. |
| * |
| * \param value The returned value. |
| * \param span The location of this operation in the source. |
| * \return The return expression. |
| */ |
| TVM_DLL PrimExpr ret(PrimExpr value, Span span = Span()); |
| |
| /*! |
| * \brief Query the maximum possible value of dtype. |
| * \param dtype The data type. |
| * \param span The location of this operation in the source. |
| * \return the maximum possible value in this format. |
| */ |
| TVM_DLL PrimExpr max_value(const DataType& dtype, Span span = Span()); |
| |
| /*! |
| * \brief Query the minimum possible value of dtype. |
| * \param dtype The data type. |
| * \param span The location of this operation in the source. |
| * \return the minimum possible value in this format. |
| */ |
| TVM_DLL PrimExpr min_value(const DataType& dtype, Span span = Span()); |
| |
| /*! |
| * \brief Get the value of infinity. |
| * \param dtype The data type. |
| * \param span The location of this operation in the source. |
| * \return the infinity value in this format. |
| */ |
| TVM_DLL PrimExpr infinity(const DataType& dtype, Span span = Span()); |
| |
| /*! |
| * \brief cast value to type. |
| * |
| * \param t the target type. |
| * \param value The value |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note This function may return value if the type is the same. |
| */ |
| TVM_DLL PrimExpr cast(const DataType& t, PrimExpr value, Span span = Span()); |
| /*! |
| * \brief perform reinterpret cast value to type. |
| * |
| * \param t the target type. |
| * \param value The value |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note This function may return value if the type is the same. |
| */ |
| TVM_DLL PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span = Span()); |
| /*! |
| * \brief add operator |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr add(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief subtraction operator |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr sub(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief negation. |
| * |
| * \param a input. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr neg(PrimExpr a, Span span = Span()); |
| /*! |
| * \brief multiplication operator |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr mul(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief left shift operator |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief right shift operator |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief greater |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr greater(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief greater_equal |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief less |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr less(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief less_equal |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief equal |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr equal(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief not_equal |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief and |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note This operator does eager constant folding. |
| */ |
| TVM_DLL PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief or |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note This operator does eager constant folding. |
| */ |
| TVM_DLL PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief not |
| * |
| * \param a the input operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note This operator does eager constant folding. |
| */ |
| TVM_DLL PrimExpr logical_not(PrimExpr a, Span span = Span()); |
| /*! |
| * \brief compute division in C semantics. |
| * |
| * a / b as in C/C++. |
| * |
| * When operands are integers, it directly corresponds to truncdiv. |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr div(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute trunc(a / b) |
| * |
| * This is the default integer division behavior in C. |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute the remainder of truncdiv |
| * |
| * This is the default integer remainder behavior in C. |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute floor(a / b) where a and b are non-negative. |
| * |
| * Use this function for index split calculation. |
| * |
| * This function might take advantage of the fact |
| * that a and b are non-negative. |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute ceil(a / b) where a and b are non-negative. |
| * |
| * Use this function for shape split calculation. |
| * |
| * This function might take advantage of the fact |
| * that a and b are non-negative. |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * shape types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute the remainder of floor(a / b) where a and b are non-negative. |
| * |
| * Use this function for index split calculation. |
| * This function might take advantage of the fact |
| * that a and b are non-negative. |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute floor(a / b) |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute ceil(a / b) |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief compute the remainder of floordiv |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr floormod(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief take maximum of two values |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr max(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief take minimum of two values |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr min(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief take bitwise and of two values |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief take bitwise or of two values |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief take bitwise xor of two values |
| * |
| * \param a left operand |
| * \param b right operand |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span = Span()); |
| /*! |
| * \brief take bitwise negation of a value |
| * |
| * \param a the input expression. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr bitwise_neg(PrimExpr a, Span span = Span()); |
| /*! |
| * \brief Conditional expression. |
| * |
| * \param cond The condition |
| * \param true_value The value when the condition is true. |
| * \param false_value The value when the condition is false. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note this function does eager constant folding for |
| * index types(int32, int64) when possible. |
| */ |
| TVM_DLL PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, |
| Span span = Span()); |
| /*! |
| * \brief Mark condition as likely. |
| * \param cond The condition |
| * \param span The location of this operation in the source. |
| * \return The marked expression. |
| */ |
| TVM_DLL PrimExpr likely(PrimExpr cond, Span span = Span()); |
| /*! |
| * \brief Calculate power(x, y) |
| * \param x The left operand. |
| * \param y The right operand. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr pow(PrimExpr x, PrimExpr y, Span span = Span()); |
| /*! |
| * \brief Calculate absolute value of x. |
| * \param x The input data |
| * \param span The location of this operation in the source. |
| * |
| * \return The absolute value of input data x |
| */ |
| TVM_DLL PrimExpr abs(PrimExpr x, Span span = Span()); |
| /*! |
| * \brief Check if x is NaN. |
| * \param x The input data |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr isnan(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Check if x is finite. |
| * \param x The input data |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr isfinite(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Check if x is infinite. |
| * \param x The input data |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief sum of source expression over axis |
| * \param source The source expression. |
| * \param axis List of iteration variables that will be used for reduction. |
| * \param init The value with which to initialize the output. |
| * \param span The location of this operation in the source. |
| * \return The result. |
| */ |
| TVM_DLL PrimExpr sum(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {}, |
| Span span = Span()); |
| |
| /*! |
| * \brief logical And of source expression over axis |
| * \param source The source expression. |
| * \param axis List of iteration variables that will be used for reduction. |
| * \param init The value with which to initialize the output. |
| * \param span The location of this operation in the source. |
| * \return The result. |
| */ |
| TVM_DLL PrimExpr all(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {}, |
| Span span = Span()); |
| |
| /*! |
| * \brief logical Or of source expression over axis |
| * \param source The source expression. |
| * \param axis List of iteration variables that will be used for reduction. |
| * \param init The value with which to initialize the output. |
| * \param span The location of this operation in the source. |
| * \return The result. |
| */ |
| TVM_DLL PrimExpr any(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {}, |
| Span span = Span()); |
| |
| /*! |
| * \brief max of source expression over axis |
| * \param source The source expression. |
| * \param axis List of iteration variables that will be used for reduction. |
| * \param init The value with which to initialize the output. |
| * \param span The location of this operation in the source. |
| * \return The result. |
| */ |
| TVM_DLL PrimExpr max(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {}, |
| Span span = Span()); |
| |
| /*! |
| * \brief min of source expression over axis |
| * \param source The source expression. |
| * \param axis List of iteration variables that will be used for reduction. |
| * \param init The value with which to initialize the output. |
| * \param span The location of this operation in the source. |
| * \return The result. |
| */ |
| TVM_DLL PrimExpr min(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {}, |
| Span span = Span()); |
| |
| /*! |
| * \brief product of source expression over axis |
| * \param source The source expression. |
| * \param axis List of iteration variables that will be used for reduction. |
| * \param init The value with which to initialize the output. |
| * \param span The location of this operation in the source. |
| * \return The result. |
| */ |
| TVM_DLL PrimExpr prod(PrimExpr source, Array<tir::IterVar> axis, Array<PrimExpr> init = {}, |
| Span span = Span()); |
| |
| /*! |
| * \brief Calculate floor(x) |
| * \param x The input expression. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Calculate ceil(x) |
| * \param x The input expression. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Calculate round(x) |
| * \param x The input expression. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr round(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Calculates std::nearbyint(x) |
| * \param x The input expression. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| * \note This is a faster alternative to round. |
| */ |
| TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Calculate trunc(x) |
| * \param x The input expression. |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| TVM_DLL PrimExpr trunc(PrimExpr x, Span span = Span()); |
| |
| /*! |
| * \brief Construct a large uint constant by its low 32 bits and high 32 bits. |
| * \param dtype The final data type. |
| * \param low The lower 32 bits. |
| * \param high The higher 32 bits. |
| * \param span The location of this operation in the source. |
| * \return The constructed expression. |
| */ |
| TVM_DLL PrimExpr LargeUIntImm(DataType dtype, int64_t low, int64_t high, Span span = Span()); |
| |
| /*! |
| * \brief Execute a multiplication between two Q-numbers x and y |
| * followed by a right shift s. The mathematical expression is: |
| * |
| * out = round(x*y*2^-s) |
| * |
| * Please note that the two Q-numbers x and y are supposed to have |
| * the same number of fractional bits q. |
| * |
| * More about Q-numbers here: https://en.wikipedia.org/wiki/Q_(number_format) |
| * |
| * The rounding rule is to the nearest value, rounding half up |
| * (i.e., round(x.1) = x and round (x.5) = x+1) |
| * \param x first Q-number |
| * \param y second Q-number |
| * \param q number of fractional bits in x and y. Needs to be > 0 |
| * \param s integer right shift |
| * \param span The location of this operation in the source. |
| * \return The constructed expression. |
| */ |
| TVM_DLL PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, |
| Span span = Span()); |
| |
| // Intrinsic operators |
| #define TVM_DECLARE_INTRIN_UNARY(OpName) \ |
| inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ |
| static const Op& op = Op::Get("tir." #OpName); \ |
| if (x.dtype().is_bfloat16()) { \ |
| DataType bf16_dtype = x.dtype(); \ |
| DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ |
| PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ |
| PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ |
| return tir::Cast(bf16_dtype, {result_fp32}, span); \ |
| } else { \ |
| return tir::Call(x.dtype(), op, {x}, span); \ |
| } \ |
| } |
| |
| TVM_DECLARE_INTRIN_UNARY(exp); |
| TVM_DECLARE_INTRIN_UNARY(exp2); |
| TVM_DECLARE_INTRIN_UNARY(exp10); |
| TVM_DECLARE_INTRIN_UNARY(erf); |
| TVM_DECLARE_INTRIN_UNARY(tanh); |
| TVM_DECLARE_INTRIN_UNARY(sigmoid); |
| TVM_DECLARE_INTRIN_UNARY(sqrt); |
| TVM_DECLARE_INTRIN_UNARY(rsqrt); |
| TVM_DECLARE_INTRIN_UNARY(log); |
| TVM_DECLARE_INTRIN_UNARY(log2); |
| TVM_DECLARE_INTRIN_UNARY(log10); |
| TVM_DECLARE_INTRIN_UNARY(log1p); |
| TVM_DECLARE_INTRIN_UNARY(popcount); |
| TVM_DECLARE_INTRIN_UNARY(tan); |
| TVM_DECLARE_INTRIN_UNARY(cos); |
| TVM_DECLARE_INTRIN_UNARY(cosh); |
| TVM_DECLARE_INTRIN_UNARY(sin); |
| TVM_DECLARE_INTRIN_UNARY(sinh); |
| TVM_DECLARE_INTRIN_UNARY(asin); |
| TVM_DECLARE_INTRIN_UNARY(acos); |
| TVM_DECLARE_INTRIN_UNARY(atan); |
| TVM_DECLARE_INTRIN_UNARY(acosh); |
| TVM_DECLARE_INTRIN_UNARY(asinh); |
| TVM_DECLARE_INTRIN_UNARY(atanh); |
| TVM_DECLARE_INTRIN_UNARY(clz); |
| |
| #define TVM_DECLARE_INTRIN_BINARY(OpName) \ |
| inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ |
| static const Op& op = Op::Get("tir." #OpName); \ |
| return tir::Call(x.dtype(), op, {x, y}, span); \ |
| } |
| |
| TVM_DECLARE_INTRIN_BINARY(atan2); |
| TVM_DECLARE_INTRIN_BINARY(nextafter); |
| TVM_DECLARE_INTRIN_BINARY(copysign); |
| TVM_DECLARE_INTRIN_BINARY(hypot); |
| TVM_DECLARE_INTRIN_BINARY(ldexp); |
| |
| namespace tir { |
| |
| /*! |
| * \brief Check if type is a pointer to a runtime element type. |
| * \param type The type to be checked. |
| * \param element_type The corresponding element type. |
| * \return The check results |
| */ |
| inline bool IsPointerType(const Type& type, const DataType& element_type) { |
| if (!type.defined()) return false; |
| if (const auto* ptr_type = type.as<PointerTypeNode>()) { |
| if (const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>()) { |
| return prim_type->dtype == element_type; |
| } |
| } |
| return false; |
| } |
| |
| /*! |
| * \brief Make a const value with certain data type. |
| * \param t The target type. |
| * \param value The input value |
| * \return the result expression. |
| * \tparam ValueType The constant value type |
| * \param span The location of this operation in the source. |
| */ |
| template <typename ValueType, |
| typename = typename std::enable_if<std::is_pod<ValueType>::value>::type> |
| inline PrimExpr make_const(DataType t, ValueType value, Span span = Span()); |
| /*! |
| * \brief Make a const zero expr. |
| * \param t The target type. |
| * \param span The location of this operation in the source. |
| * \return the result expression. |
| */ |
| inline PrimExpr make_zero(DataType t, Span span = Span()); |
| /*! |
| * \brief Make a constant true expression. |
| * \param lanes The number of lanes in the bool |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| inline PrimExpr const_true(int lanes = 1, Span span = Span()) { |
| return make_const(DataType::UInt(1, lanes), 1); |
| } |
| /*! |
| * \brief Make a constant false expression. |
| * \param lanes The number of lanes in the bool |
| * \param span The location of this operation in the source. |
| * \return The result expression. |
| */ |
| inline PrimExpr const_false(int lanes = 1, Span span = Span()) { |
| return make_const(DataType::UInt(1, lanes), 0); |
| } |
| /*! |
| * \brief Get x as constant int expression. |
| * \param x The expression |
| * \return the address to the int expression, |
| * return nullptr, if x is not IntImm. |
| */ |
| inline const int64_t* as_const_int(const PrimExpr& x) { |
| if (!x.defined()) return nullptr; |
| if (const tir::IntImmNode* op = x.as<tir::IntImmNode>()) { |
| return &(op->value); |
| } |
| |
| return nullptr; |
| } |
| |
| /*! |
| * \brief Check whether x is a constant integer expression. |
| * \param x The input argument |
| * \param value the value to be compared against. |
| * \return whether x is constant expression. |
| */ |
| inline bool is_const_int(const PrimExpr& x, int64_t value); |
| |
| /*! |
| * \brief Check whether stmt is nop. |
| * \param stmt The input statement |
| * \return whether stmt is nop |
| */ |
| inline bool is_no_op(const tir::Stmt& stmt); |
| |
| /*! |
| * \brief Check whether x is a constant integer 1 |
| * \param x The input argument. |
| * \note This only return true for integer types. |
| * \return whether x is constant 1 |
| */ |
| inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); } |
| |
| /*! |
| * \brief Check whether x is a constant integer 0 |
| * \param x The input argument |
| * \return whether x is constant 0 |
| * \note This only return true for integer types. |
| */ |
| inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); } |
| |
| /*! |
| * \brief Check whether x is an integer constant. |
| * \param x The input argument |
| * \note This only returns true for integer types. |
| * \return whether x is an integer constant |
| */ |
| inline bool is_const_int(const PrimExpr& x); |
| |
| /*! |
| * \brief Check whether x is an integer/float constant. |
| * \param x The input argument |
| * \note This returns true for integer and float immediates, and for broadcasts of them. |
| * \return whether x is constant |
| */ |
| inline bool is_const_number(const PrimExpr& x); |
| |
| /*! |
| * \brief Left fold. |
| * \param freduce The reduction function. |
| * \param init_value The initial value. |
| * \param values The values to be folded. |
| * \param span The location of the fold in the source. |
| * \return The result. |
| * \tparam FReduce The type of the reduction. |
| */ |
| template <typename FReduce> |
| inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values, |
| Span span = Span()); |
| |
| /*! |
| * \brief Check whether x is a constant power of two |
| * If x is power of two, write the power to the shift. |
| * |
| * \param x The input expression. |
| * \param shift The output shift if x is power of two. |
| * \return whether x is constant power of two |
| */ |
| TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift); |
| |
| // Implementation details after this |
| inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); } |
| |
| inline bool is_const_number(const PrimExpr& x) { |
| if (x.as<tir::IntImmNode>()) { |
| return true; |
| } else if (x.as<tir::FloatImmNode>()) { |
| return true; |
| } else if (const auto* op = x.as<tir::BroadcastNode>()) { |
| return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>()); |
| } |
| return false; |
| } |
| |
| inline bool is_positive_const(const PrimExpr& a) { |
| const int64_t* as_int = as_const_int(a); |
| return as_int && (*as_int > 0); |
| } |
| |
| inline bool is_negative_const(const PrimExpr& a) { |
| const int64_t* as_int = as_const_int(a); |
| return as_int && (*as_int < 0); |
| } |
| |
| inline bool is_const_int(const PrimExpr& x, int64_t value) { |
| const int64_t* as_int = as_const_int(x); |
| return as_int && (*as_int == value); |
| } |
| |
| inline bool is_no_op(const tir::Stmt& stmt) { |
| if (!stmt.defined()) return true; |
| if (const auto* op = stmt.as<tir::EvaluateNode>()) { |
| return is_const_int(op->value); |
| } |
| if (const auto* op = stmt.as<tir::SeqStmtNode>()) { |
| return op->seq.size() == 0; |
| } |
| return false; |
| } |
| |
| template <typename ValueType> |
| inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) { |
| if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span); |
| if (t.is_uint()) { |
| // Use IntImm if it is a small integer |
| uint64_t uval = static_cast<uint64_t>(value); |
| if (uval <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) { |
| return IntImm(t, static_cast<int64_t>(value), span); |
| } else { |
| uint64_t mask = (static_cast<uint64_t>(1) << 32U) - 1U; |
| uint64_t low = uval & mask; |
| uint64_t high = uval >> 32U; |
| return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high), span); |
| } |
| } |
| if (t.is_float() || t.is_bfloat16()) return FloatImm(t, static_cast<double>(value), span); |
| // For now, we store const scalar values of custom datatypes within doubles; later, during the |
| // datatypes lowering pass, we will lower the value to its true representation in the format |
| // specified by the datatype. |
| // TODO(gus) when do we need to start worrying about doubles not being precise enough? |
| if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) { |
| return FloatImm(t, static_cast<double>(value), span); |
| } |
| LOG(FATAL) << "cannot make const for type " << t; |
| return PrimExpr(); |
| } |
| |
| template <typename ValueType, typename> |
| inline PrimExpr make_const(DataType t, ValueType value, Span span) { |
| if (t.lanes() == 1) { |
| return MakeConstScalar(t, value, span); |
| } else { |
| return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); |
| } |
| } |
| |
| inline PrimExpr make_zero(DataType t, Span span) { |
| if (t.is_handle()) { |
| return reinterpret(t, make_const(DataType::UInt(64), 0, span)); |
| } |
| return make_const(t, 0, span); |
| } |
| |
| template <typename FReduce> |
| inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr>& values, |
| Span span) { |
| for (PrimExpr val : values) { |
| init_value = freduce(init_value, val, span); |
| } |
| return init_value; |
| } |
| |
| } // namespace tir |
| |
| // additional const expression overloading |
// Defines the compound-assignment function `Name` in terms of `OpFunc`:
// OpFunc(a, b) is stored back into `a`, and the updated value is returned by value.
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
  inline PrimExpr Name(PrimExpr& a, PrimExpr b) { \
    return a = OpFunc(a, b); \
  }
| |
// Defines overloads of the binary function `Name` that mix a PrimExpr with a C++ constant.
//  - int constants take the dtype of the PrimExpr operand.
//  - float constants are converted through PrimExpr(float).
//  - double constants (right-hand side only) become Float(64) constants.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b) { \
    return Name(tir::make_const(b.dtype(), a), b); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b) { \
    return Name(a, tir::make_const(a.dtype(), b)); \
  } \
  inline PrimExpr Name(float a, const PrimExpr& b) { \
    return Name(PrimExpr(a), b); \
  } \
  inline PrimExpr Name(const PrimExpr& a, float b) { \
    return Name(a, PrimExpr(b)); \
  } \
  inline PrimExpr Name(const PrimExpr& a, double b) { \
    return Name(a, tir::make_const(DataType::Float(64), b)); \
  }
| |
// Same overload set as TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD, for functions that
// additionally take a trailing Span (defaulted to Span()) and forward it to `Name`.
#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
    return Name(tir::make_const(b.dtype(), a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
    return Name(a, tir::make_const(a.dtype(), b), span); \
  } \
  inline PrimExpr Name(float a, const PrimExpr& b, Span span = Span()) { \
    return Name(PrimExpr(a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, float b, Span span = Span()) { \
    return Name(a, PrimExpr(b), span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \
    return Name(a, tir::make_const(DataType::Float(64), b), span); \
  }
| |
// Defines overloads of the logical function `Name` that accept a C++ bool on
// either side; the bool is converted through PrimExpr(bool).
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(bool a, const PrimExpr& b) { \
    return Name(PrimExpr(a), b); \
  } \
  inline PrimExpr Name(const PrimExpr& a, bool b) { \
    return Name(a, PrimExpr(b)); \
  }
| |
// Same overload set as TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD, for functions
// that additionally take a trailing Span (defaulted to Span()) and forward it.
#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
  inline PrimExpr Name(bool a, const PrimExpr& b, Span span = Span()) { \
    return Name(PrimExpr(a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, bool b, Span span = Span()) { \
    return Name(a, PrimExpr(b), span); \
  }
| |
// Defines overloads of the integer function `Name` that accept a C++ int on
// either side; the int becomes a constant with the dtype of the PrimExpr operand.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b) { \
    return Name(tir::make_const(b.dtype(), a), b); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b) { \
    return Name(a, tir::make_const(a.dtype(), b)); \
  }
| |
// Same overload set as TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD, for functions that
// additionally take a trailing Span (defaulted to Span()) and forward it to `Name`.
#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \
  inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \
    return Name(tir::make_const(b.dtype(), a), b, span); \
  } \
  inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \
    return Name(a, tir::make_const(a.dtype(), b), span); \
  }
| |
| TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); |
| TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-); |
| TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*) |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*) |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(max); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(min); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(div); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(add); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(sub); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(mul); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(greater); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(greater_equal); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(less); |
| TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(less_equal); |
| // integer related ops |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexdiv); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(indexmod); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncdiv); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(truncmod); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floordiv); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(floormod); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(right_shift); // NOLINT(*) |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(left_shift); // NOLINT(*) |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_and); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_or); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(bitwise_xor); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); |
| TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); |
| // logical ops |
| TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&); |
| TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); |
| TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(logical_and); |
| TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD_SPANNED(logical_or); |
| |
/*!
 * \brief Compile-time trap used to report ambiguous integer division.
 *
 * Instantiating this function evaluates a static_assert that holds only when
 * TA is not a class type, so instantiating it with a class type stops the
 * build with the message below.
 *
 * \param a The offending operand; only its type is used.
 * \tparam TA The operand type; the assertion fails when it is a class type.
 */
template <typename TA>
inline void DivAmbiguityError(const TA& a) {
  // The condition depends on TA, so it is only checked on instantiation.
  using IsClassOperand = std::is_class<TA>;
  static_assert(!IsClassOperand::value,
                "TVM supports multiple types of integer divisions, "
                "please call div, indexdiv/indexmod, "
                "floordiv/floormod or truncdiv/truncmod directly "
                "to avoid ambiguity in the code. "
                "Checkout these functions in tir/op.h.");
}
| |
| // The following operators are not intended to be used in the codebase. |
| // Instead, they generate clear compiler errors that ask developers |
| // to use the specific division function. |
| // The right-hand operand type is a template parameter so that each body |
| // is only instantiated, and the error only raised, when the operator is used. |
| template <typename TB> |
| inline PrimExpr operator/(const PrimExpr& a, const TB& b) { |
| DivAmbiguityError(a); |
| return a; |
| } |
| |
| template <typename TB> |
| inline PrimExpr operator/=(const PrimExpr& a, const TB& b) { |
| DivAmbiguityError(a); |
| return a; |
| } |
| |
| template <typename TB> |
| inline PrimExpr operator%(const PrimExpr& a, const TB& b) { |
| DivAmbiguityError(a); |
| return a; |
| } |
| } // namespace tvm |
| #endif // TVM_TIR_OP_H_ |