| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /*! |
| * \file intrin_rule_default.cc |
| * \brief Default intrinsic rules. |
| */ |
| #include "intrin_rule.h" |
| |
| #include <tvm/runtime/logging.h> |
| #include <tvm/tirx/buffer.h> |
| #include <tvm/tirx/op.h> |
| #include <tvm/tirx/op_attr_types.h> |
| |
| namespace tvm { |
| namespace codegen { |
| namespace intrin { |
| using tirx::FLowerIntrinsic; |
| |
// Default lowering rules for scalar floating-point math intrinsics.
//
// Every op below registers the same generic dispatcher, DispatchPureExtern<FloatSuffix>, as
// its "default.FLowerIntrinsic" attribute. NOTE(review): DispatchPureExtern and FloatSuffix
// are not defined in this file (presumably in intrin_rule.h). Judging by the names, they
// lower the call to a pure extern function whose name carries a dtype-dependent suffix
// (e.g. expf for float32) — confirm against that header.
TVM_REGISTER_OP("tirx.exp")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.erf")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.log")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.log2")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.log10")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.log1p")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.tanh")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.tan")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.trunc")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.atan")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.atanh")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.atan2")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.cos")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.acos")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.cosh")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.acosh")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.sin")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.asin")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.sinh")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.asinh")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.hypot")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.nextafter")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.copysign")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.ldexp")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.sqrt")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.floor")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.ceil")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.round")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.nearbyint")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);

TVM_REGISTER_OP("tirx.pow")
    .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", DispatchPureExtern<FloatSuffix>);
| |
| TVM_REGISTER_OP("tirx.tvm_access_ptr") |
| .set_attr<FLowerIntrinsic>("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { |
| const CallNode* call = e.as<CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| TVM_FFI_ICHECK_EQ(call->args.size(), 5U); |
| DataType dtype = call->args[0].dtype(); |
| Var buffer_var = Downcast<Var>(call->args[1]); |
| PrimExpr offset = call->args[2]; |
| TVM_FFI_ICHECK(call->dtype.is_handle()); |
| if (dtype.lanes() != 1) { |
| offset = offset * make_const(offset.dtype(), dtype.lanes()); |
| offset = Ramp(offset, make_const(offset.dtype(), 1), dtype.lanes()); |
| } |
| Buffer dummy_buf(buffer_var, dtype.element_of(), {offset + 1}, {}, 0, buffer_var->name_hint, |
| 0, 0, kDefault); |
| BufferLoad buf_load(dummy_buf, {offset}); |
| return Call(DataType::Handle(), builtin::address_of(), {buf_load}); |
| }); |
| |
| PrimExpr DispatchFastErf(const PrimExpr& e) { |
| DLOG(WARNING) << "fast_erf will be used instead of erf"; |
| const CallNode* call = e.as<CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| TVM_FFI_ICHECK_EQ(call->args.size(), 1); |
| PrimExpr arg = call->args[0]; |
| int bits = arg.dtype().bits(); |
| PrimExpr res; |
| if (arg.dtype().is_float() && (bits == 16 || bits == 32)) { |
| res = fast_erf_float_expr(arg, bits); |
| } else { |
| TVM_FFI_THROW(InternalError) << "Unsupported type in Metal fast_erf"; |
| } |
| return res; |
| } |
| |
| PrimExpr DispatchNumericalStableTanh(const PrimExpr& e) { |
| using tirx::make_const; |
| using tirx::make_zero; |
| const tirx::CallNode* call = e.as<tirx::CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| const PrimExpr& x = call->args[0]; |
| PrimExpr one = make_const(x.dtype(), 1); |
| PrimExpr two = make_const(x.dtype(), 2); |
| PrimExpr neg_two = make_const(x.dtype(), -2); |
| |
| PrimExpr exp_neg2x = exp(neg_two * x); |
| PrimExpr exp_pos2x = exp(two * x); |
| |
| PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); |
| PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); |
| return tirx::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); |
| } |
| |
| } // namespace intrin |
| |
| namespace legalize { |
| |
| using namespace tirx; |
| |
| TVM_REGISTER_OP("tirx.rsqrt") |
| .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { |
| const CallNode* call = e.as<CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| auto one = make_const(call->args[0].dtype(), 1); |
| return one / sqrt(call->args[0]); |
| }); |
| |
| TVM_REGISTER_OP("tirx.sigmoid") |
| .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { |
| const CallNode* call = e.as<CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| auto one = make_const(call->args[0].dtype(), 1); |
| return one / (one + exp(-call->args[0])); |
| }); |
| |
| TVM_REGISTER_OP("tirx.isfinite") |
| .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { |
| const CallNode* call = e.as<CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| return isfinite(call->args[0]); |
| }); |
| |
| TVM_REGISTER_OP("tirx.isinf") |
| .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { |
| const CallNode* call = e.as<CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| return isinf(call->args[0]); |
| }); |
| |
| /*! |
| * \brief Makes fixed point multiplication. |
| * \param x Input tensor. |
| * \param y Integer multiplier. |
| * \param left_shift Integer left shift. |
| * \param right_shift Integer right shift. |
| * \param is_left_shift_required Flag whether we need to do left shift or not. |
| * \return Calculated expression. |
| */ |
| static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left_shift, |
| PrimExpr right_shift, PrimExpr is_left_shift_required) { |
| // Only int32 types are supported (any number of lanes is allowed) |
| TVM_FFI_ICHECK(y.dtype().code() == DLDataTypeCode::kDLInt && y.dtype().bits() == 32); |
| TVM_FFI_ICHECK(left_shift.dtype().code() == DLDataTypeCode::kDLInt && |
| left_shift.dtype().bits() == 32); |
| TVM_FFI_ICHECK(right_shift.dtype().code() == DLDataTypeCode::kDLInt && |
| right_shift.dtype().bits() == 32); |
| |
| DataType hp_dtype = DataType::Int(64, x.dtype().lanes()); |
| DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); |
| |
| // 1) Cast and Multiply the integer multiplier |
| PrimExpr one = make_const(hp_dtype, 1); |
| x = cast(hp_dtype, x); |
| y = cast(hp_dtype, y); |
| x = tirx::Select(is_left_shift_required, x << left_shift, x); |
| |
| // 2) Perform the multiplication in higher precision. |
| x = x * y; |
| |
| // 3) Find the rounding scalar |
| PrimExpr total_right_shift = right_shift + q; |
| PrimExpr pos_rounding_value = (one << (total_right_shift - 1)); |
| x = x + pos_rounding_value; |
| |
| // 4) Simply right shift the result to get the final output. |
| x = x >> total_right_shift; |
| |
| // 5) The fixed point multiplication keeps the value in int32 range. Casting back to int32. |
| return cast(lp_dtype, x); |
| } |
| |
| TVM_REGISTER_OP("tirx.q_multiply_shift") |
| .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { |
| using tirx::make_const; |
| |
| const tirx::CallNode* call = e.as<tirx::CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| |
| PrimExpr x = call->args[0]; |
| PrimExpr y = call->args[1]; |
| PrimExpr q = call->args[2]; |
| PrimExpr s = call->args[3]; |
| |
| // Lambda function to extract the int value from PrimExpr |
| auto get_int_value = [](const PrimExpr node) { |
| if (auto int_node = node.as<IntImmNode>()) { |
| return int_node->value; |
| } |
| auto broadcast_node = node.as<BroadcastNode>(); |
| TVM_FFI_ICHECK(broadcast_node != nullptr); |
| auto int_node = broadcast_node->value.as<IntImmNode>(); |
| TVM_FFI_ICHECK(int_node != nullptr); |
| return int_node->value; |
| }; |
| // Power of 2 is determined by the fixed_point_multiplier == 1 << 30. In case of power of |
| // 2, fixed point multiplier will represent a float value of 0.5. In fixed point, this is |
| // represented by 1 << 30. |
| if (get_int_value(y) == (1 << 30)) { |
| PrimExpr exp = s - 1; |
| int exp_val = get_int_value(s) - 1; |
| if (exp_val > 0) { |
| // power of 2 is greater than 0, apply left shift. |
| return x << exp; |
| } else { |
| // power of 2 is less than 0, round and then apply right shift. |
| DataType lp_dtype = DataType::Int(32, x.dtype().lanes()); |
| PrimExpr one = make_const(lp_dtype, 1); |
| exp = -exp; |
| PrimExpr rounding_factor = one << (exp - 1); |
| PrimExpr rounded_t = x + rounding_factor; |
| return rounded_t >> exp; |
| } |
| } else { |
| // Only int32 types are supported (any number of lanes is allowed) |
| TVM_FFI_ICHECK(s.dtype().code() == DLDataTypeCode::kDLInt && s.dtype().bits() == 32); |
| |
| // Calculating integer shifts |
| PrimExpr zero = make_const(s.dtype(), 0); |
| PrimExpr left_shift = tirx::Select(s > zero, s, zero); |
| PrimExpr right_shift = tirx::Select(s > zero, zero, -s); |
| PrimExpr is_left_shift_required = (left_shift != zero); |
| |
| return QMultiplyShift(x, y, q, left_shift, right_shift, is_left_shift_required); |
| } |
| }); |
| |
| TVM_REGISTER_OP("tirx.q_multiply_shift_per_axis") |
| .set_attr<FLegalize>("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { |
| const tirx::CallNode* call = e.as<tirx::CallNode>(); |
| TVM_FFI_ICHECK(call != nullptr); |
| |
| PrimExpr x = call->args[0]; |
| PrimExpr y = call->args[1]; |
| PrimExpr left_shift = call->args[2]; |
| PrimExpr right_shift = call->args[3]; |
| PrimExpr q = call->args[4]; |
| PrimExpr is_lshift_required = call->args[5]; |
| // Note, 7th argument is "is_rshift_required" flag, but we don't need that here. |
| // PrimExpr is_rshift_required = call->args[6]; |
| |
| return QMultiplyShift(x, y, q, left_shift, right_shift, is_lshift_required); |
| }); |
| } // namespace legalize |
| } // namespace codegen |
| } // namespace tvm |