blob: e914190c06456493f6d9008967072c14fb1853b0 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.util.ToNumberParser
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType}
import org.apache.spark.unsafe.types.UTF8String
/**
 * Shared implementation for expressions that parse a string into a decimal using a number
 * format string.
 *
 * @param left expression producing the string to parse
 * @param right expression producing the format string; type checking rejects it unless it is
 *              foldable
 * @param errorOnFail flag handed to [[ToNumberParser]] when the parser is constructed
 */
abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Boolean)
  extends BinaryExpression with Serializable with ImplicitCastInputTypes with NullIntolerant {

  // Parser built from the upper-cased format value; stays null when the format evaluates to null.
  private lazy val numberFormatter: ToNumberParser =
    right.eval() match {
      case null => null
      case fmt => new ToNumberParser(fmt.toString.toUpperCase(Locale.ROOT), errorOnFail)
    }

  /**
   * The decimal type reported by the parser for the given format, or the user-default decimal
   * type when no parser could be built (null format).
   */
  override def dataType: DataType =
    Option(numberFormatter).map(_.parsedDecimalType).getOrElse(DecimalType.USER_DEFAULT)

  /** Both the value and the format are expected to be strings (any collation). */
  override def inputTypes: Seq[AbstractDataType] =
    Seq(StringTypeAnyCollation, StringTypeAnyCollation)

  /**
   * Runs the inherited input-type check first, then requires the format to be foldable, and
   * finally delegates to the parser's own validation when a parser exists.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val baseCheck = super.checkInputDataTypes()
    if (!baseCheck.isSuccess) {
      baseCheck
    } else if (!right.foldable) {
      val mismatchParams = Map(
        "inputName" -> toSQLId(right.prettyName),
        "inputType" -> toSQLType(right.dataType),
        "inputExpr" -> toSQLExpr(right))
      DataTypeMismatch(errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = mismatchParams)
    } else {
      // The parser is only touched once the format is known to be foldable.
      Option(numberFormatter)
        .map(_.checkInputDataTypes())
        .getOrElse(TypeCheckResult.TypeCheckSuccess)
    }
  }

  /** Parses the input string with the prepared parser; the `format` argument is not consulted. */
  override def nullSafeEval(string: Any, format: Any): Any =
    numberFormatter.parse(string.asInstanceOf[UTF8String])

  /**
   * Emits Java code that parses the left child's value through the referenced parser. The result
   * is null when the input is null, when the parser reference is null, or when parsing itself
   * yields null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Register the parser before generating the child's code, as reference order matters.
    val parserRef =
      ctx.addReferenceObj("builder", numberFormatter, classOf[ToNumberParser].getName)
    val input = left.genCode(ctx)
    val resultType = CodeGenerator.javaType(dataType)
    val resultDefault = CodeGenerator.defaultValue(dataType)
    ev.copy(code =
      code"""
         |${input.code}
         |boolean ${ev.isNull} = ${input.isNull} || ($parserRef == null);
         |$resultType ${ev.value} = $resultDefault;
         |if (!${ev.isNull}) {
         | ${ev.value} = $parserRef.parse(${input.value});
         | ${ev.isNull} = ${ev.isNull} || (${ev.value} == null);
         |}
         |""".stripMargin)
  }
}
/**
* A function that converts strings to decimal values, throwing an exception if the input string
* fails to match the format string.
*/
@ExpressionDescription(
usage = """
_FUNC_(expr, fmt) - Convert string 'expr' to a number based on the string format 'fmt'.
Throws an exception if the conversion fails. The format can consist of the following
characters, case insensitive:
'0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format
string matches a sequence of digits in the input string. If the 0/9 sequence starts with
0 and is before the decimal point, it can only match a digit sequence of the same size.
Otherwise, if the sequence starts with 9 or is after the decimal point, it can match a
digit sequence that has the same or smaller size.
'.' or 'D': Specifies the position of the decimal point (optional, only allowed once).
',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be
a 0 or 9 to the left and right of each grouping separator. 'expr' must match the
grouping separator relevant for the size of the number.
'$': Specifies the location of the $ currency sign. This character may only be specified
once.
'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at
the beginning or end of the format string). Note that 'S' allows '-' but 'MI' does not.
'PR': Only allowed at the end of the format string; specifies that 'expr' indicates a
negative number with wrapping angled brackets.
('<1>').
""",
examples = """
Examples:
> SELECT _FUNC_('454', '999');
454
> SELECT _FUNC_('454.00', '000.00');
454.00
> SELECT _FUNC_('12,454', '99,999');
12454
> SELECT _FUNC_('$78.12', '$99.99');
78.12
> SELECT _FUNC_('12,454.8-', '99,999.9S');
-12454.8
""",
since = "3.3.0",
group = "string_funcs")
case class ToNumber(left: Expression, right: Expression)
  extends ToNumberBase(left, right, errorOnFail = true) {

  /** Name under which this expression is displayed. */
  override def prettyName: String = "to_number"

  /** Creates a new instance of this expression around the supplied children. */
  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): ToNumber =
    ToNumber(newLeft, newRight)
}
/**
* A function that converts strings to decimal values, returning NULL if the input string fails to
* match the format string.
*/
@ExpressionDescription(
usage = """
_FUNC_(expr, fmt) - Convert string 'expr' to a number based on the string format `fmt`.
Returns NULL if the string 'expr' does not match the expected format. The format follows the
same semantics as the to_number function.
""",
examples = """
Examples:
> SELECT _FUNC_('454', '999');
454
> SELECT _FUNC_('454.00', '000.00');
454.00
> SELECT _FUNC_('12,454', '99,999');
12454
> SELECT _FUNC_('$78.12', '$99.99');
78.12
> SELECT _FUNC_('12,454.8-', '99,999.9S');
-12454.8
""",
since = "3.3.0",
group = "string_funcs")
case class TryToNumber(left: Expression, right: Expression)
  extends ToNumberBase(left, right, errorOnFail = false) {

  /** Always reported as nullable, independent of the children. */
  override def nullable: Boolean = true

  /** Name under which this expression is displayed. */
  override def prettyName: String = "try_to_number"

  /** Creates a new instance of this expression around the supplied children. */
  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): TryToNumber =
    TryToNumber(newLeft, newRight)
}
/**
* A function that converts decimal/datetime/binary values to strings based on the format string.
* Per the usage text below, an exception is thrown if the conversion fails.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(expr, format) - Convert `expr` to a string based on the `format`.
Throws an exception if the conversion fails. The format can consist of the following
characters, case insensitive:
'0' or '9': Specifies an expected digit between 0 and 9. A sequence of 0 or 9 in the format
string matches a sequence of digits in the input value, generating a result string of the
same length as the corresponding sequence in the format string. The result string is
left-padded with zeros if the 0/9 sequence comprises more digits than the matching part of
the decimal value, starts with 0, and is before the decimal point. Otherwise, it is
padded with spaces.
'.' or 'D': Specifies the position of the decimal point (optional, only allowed once).
',' or 'G': Specifies the position of the grouping (thousands) separator (,). There must be
a 0 or 9 to the left and right of each grouping separator.
'$': Specifies the location of the $ currency sign. This character may only be specified
once.
'S' or 'MI': Specifies the position of a '-' or '+' sign (optional, only allowed once at
the beginning or end of the format string). Note that 'S' prints '+' for positive values
but 'MI' prints a space.
'PR': Only allowed at the end of the format string; specifies that the result string will be
wrapped by angle brackets if the input value is negative.
('<1>').
If `expr` is a datetime, `format` shall be a valid datetime pattern, see <a href="https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html">Datetime Patterns</a>.
If `expr` is a binary, it is converted to a string in one of the formats:
'base64': a base 64 string.
'hex': a string in the hexadecimal format.
'utf-8': the input binary is decoded to UTF-8 string.
""",
examples = """
Examples:
> SELECT _FUNC_(454, '999');
454
> SELECT _FUNC_(454.00, '000D00');
454.00
> SELECT _FUNC_(12454, '99G999');
12,454
> SELECT _FUNC_(78.12, '$99.99');
$78.12
> SELECT _FUNC_(-12454.8, '99G999D9S');
12,454.8-
> SELECT _FUNC_(date'2016-04-08', 'y');
2016
> SELECT _FUNC_(x'537061726b2053514c', 'base64');
U3BhcmsgU1FM
> SELECT _FUNC_(x'537061726b2053514c', 'hex');
537061726B2053514C
> SELECT _FUNC_(encode('abc', 'utf-8'), 'utf-8');
abc
""",
since = "3.4.0",
group = "string_funcs")
// scalastyle:on line.size.limit
object ToCharacterBuilder extends ExpressionBuilder {

  /**
   * Chooses the concrete expression for a two-argument call based on the data type of the first
   * argument: datetime inputs become [[DateFormatClass]], binary inputs are routed to a
   * base64/hex/utf-8 conversion, and every other input becomes [[ToCharacter]].
   *
   * @param funcName name of the invoked function, used in error messages
   * @param expressions the call arguments; exactly two are required
   * @return the expression implementing the call
   */
  override def build(funcName: String, expressions: Seq[Expression]): Expression =
    expressions match {
      case Seq(inputExpr, format) =>
        inputExpr.dataType match {
          case _: DatetimeType => DateFormatClass(inputExpr, format)
          case _: BinaryType => buildBinaryConversion(funcName, inputExpr, format)
          case _ => ToCharacter(inputExpr, format)
        }
      case _ =>
        throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2), expressions.length)
    }

  // Maps a binary input plus a constant, non-null string format to the matching conversion.
  private def buildBinaryConversion(
      funcName: String,
      binaryExpr: Expression,
      format: Expression): Expression = {
    // The format has to be a foldable string so that it can be evaluated right here.
    val isConstantString = format.dataType.isInstanceOf[StringType] && format.foldable
    if (!isConstantString) {
      throw QueryCompilationErrors.nonFoldableArgumentError(funcName, "format",
        format.dataType)
    }
    val formatValue = Option(format.eval()).getOrElse {
      throw QueryCompilationErrors.nullArgumentError(funcName, "format")
    }
    // Matching is done on the lower-cased, trimmed format text.
    val normalized = formatValue.asInstanceOf[UTF8String].toString.toLowerCase(Locale.ROOT).trim
    normalized match {
      case "base64" => Base64(binaryExpr)
      case "hex" => Hex(binaryExpr)
      // The original (un-normalized) format expression is what gets handed to Decode.
      case "utf-8" => new Decode(Seq(binaryExpr, format))
      case invalid => throw QueryCompilationErrors.binaryFormatError(funcName, invalid)
    }
  }
}
/**
 * Formats a decimal value as a string according to a number format string.
 *
 * @param left expression producing the decimal value to format
 * @param right expression producing the format string; type checking rejects it unless it is
 *              foldable
 */
case class ToCharacter(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  // Formatter built from the upper-cased format value; stays null when the format is null.
  private lazy val numberFormatter: ToNumberParser =
    right.eval() match {
      case null => null
      case fmt => new ToNumberParser(fmt.toString.toUpperCase(Locale.ROOT), true)
    }

  /** The result uses the session's default string type. */
  override def dataType: DataType = SQLConf.get.defaultStringType

  /** Expects a decimal value followed by a string format (any collation). */
  override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringTypeAnyCollation)

  /**
   * Runs the inherited input-type check first, then requires the format to be foldable, and
   * finally delegates to the formatter's own validation when a formatter exists.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val baseCheck = super.checkInputDataTypes()
    if (!baseCheck.isSuccess) {
      baseCheck
    } else if (!right.foldable) {
      val mismatchParams = Map(
        "inputName" -> toSQLId(right.prettyName),
        "inputType" -> toSQLType(right.dataType),
        "inputExpr" -> toSQLExpr(right))
      DataTypeMismatch(errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = mismatchParams)
    } else {
      // The formatter is only touched once the format is known to be foldable.
      Option(numberFormatter)
        .map(_.checkInputDataTypes())
        .getOrElse(TypeCheckResult.TypeCheckSuccess)
    }
  }

  /** Name under which this expression is displayed. */
  override def prettyName: String = "to_char"

  /** Formats the decimal with the prepared formatter; the `format` argument is not consulted. */
  override def nullSafeEval(decimal: Any, format: Any): Any =
    numberFormatter.format(decimal.asInstanceOf[Decimal])

  /**
   * Emits Java code that formats the left child's value through the referenced formatter. The
   * result is null when the input is null or when the formatter reference is null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Register the formatter before generating the child's code, as reference order matters.
    val formatterRef =
      ctx.addReferenceObj("builder", numberFormatter, classOf[ToNumberParser].getName)
    val input = left.genCode(ctx)
    val resultType = CodeGenerator.javaType(dataType)
    val resultDefault = CodeGenerator.defaultValue(dataType)
    ev.copy(code =
      code"""
         |${input.code}
         |boolean ${ev.isNull} = ${input.isNull} || ($formatterRef == null);
         |$resultType ${ev.value} = $resultDefault;
         |if (!${ev.isNull}) {
         | ${ev.value} = $formatterRef.format(${input.value});
         |}
         |""".stripMargin)
  }

  /** Creates a new instance of this expression around the supplied children. */
  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): ToCharacter =
    ToCharacter(newLeft, newRight)
}