blob: 925188a7df3a2d815969a3a41d8e05cc00869682 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import unittest
from pyflink.table import DataTypes
from pyflink.table.expression import TimeIntervalUnit, TimePointUnit, JsonExistsOnError, \
JsonValueOnEmptyOrError, JsonType, JsonQueryWrapper, JsonQueryOnEmptyOrError
from pyflink.table.expressions import (col, lit, range_, and_, or_, current_date,
current_time, current_timestamp, local_time,
local_timestamp, temporal_overlaps, date_format,
timestamp_diff, array, row, map_, row_interval, pi, e,
rand, rand_integer, atan2, negative, concat, concat_ws, uuid,
null_of, log, if_then_else, with_columns, call,
to_timestamp_ltz)
from pyflink.testing.test_case_utils import PyFlinkTestCase
class PyFlinkBatchExpressionTests(PyFlinkTestCase):
    """String-rendering checks for PyFlink Table API expressions.

    Every assertion builds an expression with the Python API and compares
    ``str(expression)`` against the expected textual form.
    """

    def test_expression(self):
        """Check operators and methods called on an expression object."""
        expr1 = col('a')
        expr2 = col('b')
        expr3 = col('c')
        expr4 = col('d')
        expr5 = lit(10)

        # comparison functions
        self.assertEqual('equals(a, b)', str(expr1 == expr2))
        self.assertEqual('notEquals(a, b)', str(expr1 != expr2))
        self.assertEqual('lessThan(a, b)', str(expr1 < expr2))
        self.assertEqual('lessThanOrEqual(a, b)', str(expr1 <= expr2))
        self.assertEqual('greaterThan(a, b)', str(expr1 > expr2))
        self.assertEqual('greaterThanOrEqual(a, b)', str(expr1 >= expr2))

        # logic functions
        self.assertEqual('and(a, b)', str(expr1 & expr2))
        self.assertEqual('or(a, b)', str(expr1 | expr2))
        # `~` renders as isNotTrue, the same call `is_not_true` produces
        # (asserted with the other boolean predicates below)
        self.assertEqual('isNotTrue(a)', str(~expr1))

        # arithmetic functions
        self.assertEqual('plus(a, b)', str(expr1 + expr2))
        self.assertEqual('plus(2, b)', str(2 + expr2))
        self.assertEqual('plus(cast(b, DATE), 2)', str(expr2.to_date + 2))
        self.assertEqual('minus(a, b)', str(expr1 - expr2))
        self.assertEqual('minus(cast(b, DATE), 2)', str(expr2.to_date - 2))
        self.assertEqual('times(a, b)', str(expr1 * expr2))
        self.assertEqual('divide(a, b)', str(expr1 / expr2))
        self.assertEqual('mod(a, b)', str(expr1 % expr2))
        # reflected operator: the plain int is on the left-hand side
        self.assertEqual('mod(2, b)', str(2 % expr2))
        self.assertEqual('power(a, b)', str(expr1 ** expr2))
        self.assertEqual('minusPrefix(a)', str(-expr1))
        self.assertEqual('exp(a)', str(expr1.exp))
        self.assertEqual('log10(a)', str(expr1.log10))
        self.assertEqual('log2(a)', str(expr1.log2))
        self.assertEqual('ln(a)', str(expr1.ln))
        self.assertEqual('log(a)', str(expr1.log()))
        self.assertEqual('cosh(a)', str(expr1.cosh))
        self.assertEqual('sinh(a)', str(expr1.sinh))
        self.assertEqual('sin(a)', str(expr1.sin))
        self.assertEqual('cos(a)', str(expr1.cos))
        self.assertEqual('tan(a)', str(expr1.tan))
        self.assertEqual('cot(a)', str(expr1.cot))
        self.assertEqual('asin(a)', str(expr1.asin))
        self.assertEqual('acos(a)', str(expr1.acos))
        self.assertEqual('atan(a)', str(expr1.atan))
        self.assertEqual('tanh(a)', str(expr1.tanh))
        self.assertEqual('degrees(a)', str(expr1.degrees))
        self.assertEqual('radians(a)', str(expr1.radians))
        self.assertEqual('sqrt(a)', str(expr1.sqrt))
        self.assertEqual('abs(a)', str(expr1.abs))
        self.assertEqual('abs(a)', str(abs(expr1)))
        self.assertEqual('sign(a)', str(expr1.sign))
        self.assertEqual('round(a, b)', str(expr1.round(expr2)))

        # range and conditional functions
        self.assertEqual('between(a, b, c)', str(expr1.between(expr2, expr3)))
        self.assertEqual('notBetween(a, b, c)', str(expr1.not_between(expr2, expr3)))
        self.assertEqual('ifThenElse(a, b, c)', str(expr1.then(expr2, expr3)))

        # null and boolean predicates
        self.assertEqual('isNull(a)', str(expr1.is_null))
        self.assertEqual('isNotNull(a)', str(expr1.is_not_null))
        self.assertEqual('isTrue(a)', str(expr1.is_true))
        self.assertEqual('isFalse(a)', str(expr1.is_false))
        self.assertEqual('isNotTrue(a)', str(expr1.is_not_true))
        self.assertEqual('isNotFalse(a)', str(expr1.is_not_false))

        # aggregation
        self.assertEqual('distinct(a)', str(expr1.distinct))
        self.assertEqual('sum(a)', str(expr1.sum))
        self.assertEqual('sum0(a)', str(expr1.sum0))
        self.assertEqual('min(a)', str(expr1.min))
        self.assertEqual('max(a)', str(expr1.max))
        self.assertEqual('count(a)', str(expr1.count))
        self.assertEqual('avg(a)', str(expr1.avg))
        self.assertEqual('stddevPop(a)', str(expr1.stddev_pop))
        self.assertEqual('stddevSamp(a)', str(expr1.stddev_samp))
        self.assertEqual('varPop(a)', str(expr1.var_pop))
        self.assertEqual('varSamp(a)', str(expr1.var_samp))
        self.assertEqual('collect(a)', str(expr1.collect))

        # aliasing, casting, ordering and miscellaneous functions
        self.assertEqual("as(a, 'a', 'b', 'c')", str(expr1.alias('a', 'b', 'c')))
        self.assertEqual('cast(a, INT)', str(expr1.cast(DataTypes.INT())))
        self.assertEqual('asc(a)', str(expr1.asc))
        self.assertEqual('desc(a)', str(expr1.desc))
        self.assertEqual('in(a, b, c, d)', str(expr1.in_(expr2, expr3, expr4)))
        self.assertEqual('start(a)', str(expr1.start))
        self.assertEqual('end(a)', str(expr1.end))
        self.assertEqual('bin(a)', str(expr1.bin))
        self.assertEqual('hex(a)', str(expr1.hex))
        self.assertEqual('truncate(a, 3)', str(expr1.truncate(3)))

        # string functions
        self.assertEqual('substring(a, b, 3)', str(expr1.substring(expr2, 3)))
        self.assertEqual("trim(true, false, ' ', a)", str(expr1.trim_leading()))
        self.assertEqual("trim(false, true, ' ', a)", str(expr1.trim_trailing()))
        self.assertEqual("trim(true, true, ' ', a)", str(expr1.trim()))
        self.assertEqual('replace(a, b, c)', str(expr1.replace(expr2, expr3)))
        self.assertEqual('charLength(a)', str(expr1.char_length))
        self.assertEqual('upper(a)', str(expr1.upper_case))
        self.assertEqual('lower(a)', str(expr1.lower_case))
        self.assertEqual('initCap(a)', str(expr1.init_cap))
        self.assertEqual("like(a, 'Jo_n%')", str(expr1.like('Jo_n%')))
        self.assertEqual("similar(a, 'A+')", str(expr1.similar('A+')))
        self.assertEqual('position(a, b)', str(expr1.position(expr2)))
        self.assertEqual('lpad(a, 4, b)', str(expr1.lpad(4, expr2)))
        self.assertEqual('rpad(a, 4, b)', str(expr1.rpad(4, expr2)))
        self.assertEqual('overlay(a, b, 6, 2)', str(expr1.overlay(expr2, 6, 2)))
        self.assertEqual("regexpReplace(a, b, 'abc')", str(expr1.regexp_replace(expr2, 'abc')))
        self.assertEqual('regexpExtract(a, b, 3)', str(expr1.regexp_extract(expr2, 3)))
        self.assertEqual('fromBase64(a)', str(expr1.from_base64))
        self.assertEqual('toBase64(a)', str(expr1.to_base64))
        self.assertEqual('ltrim(a)', str(expr1.ltrim))
        self.assertEqual('rtrim(a)', str(expr1.rtrim))
        self.assertEqual('repeat(a, 3)', str(expr1.repeat(3)))

        # over windows
        self.assertEqual("over(a, 'w')", str(expr1.over('w')))

        # temporal functions
        self.assertEqual('cast(a, DATE)', str(expr1.to_date))
        self.assertEqual('cast(a, TIME(0))', str(expr1.to_time))
        self.assertEqual('cast(a, TIMESTAMP(3))', str(expr1.to_timestamp))
        self.assertEqual('extract(YEAR, a)', str(expr1.extract(TimeIntervalUnit.YEAR)))
        self.assertEqual('floor(a, YEAR)', str(expr1.floor(TimeIntervalUnit.YEAR)))
        self.assertEqual('ceil(a)', str(expr1.ceil()))

        # advanced type helper functions
        self.assertEqual("get(a, 'col')", str(expr1.get('col')))
        self.assertEqual('flatten(a)', str(expr1.flatten))
        self.assertEqual('at(a, 0)', str(expr1.at(0)))
        self.assertEqual('cardinality(a)', str(expr1.cardinality))
        self.assertEqual('element(a)', str(expr1.element))

        # time definition functions
        self.assertEqual('rowtime(a)', str(expr1.rowtime))
        self.assertEqual('proctime(a)', str(expr1.proctime))

        # interval literals built from lit(10): year/quarter/month units render
        # as a month count (10 years -> 120), week..milli units render as
        # milliseconds (10 weeks -> 10 * 7 * 24 * 3600 * 1000 = 6048000000)
        self.assertEqual('120', str(expr5.year))
        self.assertEqual('120', str(expr5.years))
        self.assertEqual('30', str(expr5.quarter))
        self.assertEqual('30', str(expr5.quarters))
        self.assertEqual('10', str(expr5.month))
        self.assertEqual('10', str(expr5.months))
        self.assertEqual('6048000000', str(expr5.week))
        self.assertEqual('6048000000', str(expr5.weeks))
        self.assertEqual('864000000', str(expr5.day))
        self.assertEqual('864000000', str(expr5.days))
        self.assertEqual('36000000', str(expr5.hour))
        self.assertEqual('36000000', str(expr5.hours))
        self.assertEqual('600000', str(expr5.minute))
        self.assertEqual('600000', str(expr5.minutes))
        self.assertEqual('10000', str(expr5.second))
        self.assertEqual('10000', str(expr5.seconds))
        self.assertEqual('10', str(expr5.milli))
        self.assertEqual('10', str(expr5.millis))

        # hash functions
        self.assertEqual('md5(a)', str(expr1.md5))
        self.assertEqual('sha1(a)', str(expr1.sha1))
        self.assertEqual('sha224(a)', str(expr1.sha224))
        self.assertEqual('sha256(a)', str(expr1.sha256))
        self.assertEqual('sha384(a)', str(expr1.sha384))
        self.assertEqual('sha512(a)', str(expr1.sha512))
        self.assertEqual('sha2(a, 224)', str(expr1.sha2(224)))

        # json functions
        self.assertEqual("IS_JSON('42')", str(lit('42').is_json()))
        self.assertEqual("IS_JSON('42', SCALAR)", str(lit('42').is_json(JsonType.SCALAR)))
        self.assertEqual("JSON_EXISTS('{}', '$.x')", str(lit('{}').json_exists('$.x')))
        self.assertEqual("JSON_EXISTS('{}', '$.x', FALSE)",
                         str(lit('{}').json_exists('$.x', JsonExistsOnError.FALSE)))
        self.assertEqual("JSON_VALUE('{}', '$.x', STRING, NULL, null, NULL, null)",
                         str(lit('{}').json_value('$.x')))
        self.assertEqual("JSON_VALUE('{}', '$.x', INT, DEFAULT, 42, ERROR, null)",
                         str(lit('{}').json_value('$.x', DataTypes.INT(),
                                                  JsonValueOnEmptyOrError.DEFAULT, 42,
                                                  JsonValueOnEmptyOrError.ERROR, None)))
        self.assertEqual("JSON_QUERY('{}', '$.x', WITHOUT_ARRAY, NULL, EMPTY_ARRAY)",
                         str(lit('{}').json_query('$.x', JsonQueryWrapper.WITHOUT_ARRAY,
                                                  JsonQueryOnEmptyOrError.NULL,
                                                  JsonQueryOnEmptyOrError.EMPTY_ARRAY)))

    def test_expressions(self):
        """Check the free functions and constants of ``pyflink.table.expressions``."""
        expr1 = col('a')
        expr2 = col('b')
        expr3 = col('c')

        self.assertEqual('10', str(lit(10, DataTypes.INT(False))))
        self.assertEqual('rangeTo(1, 2)', str(range_(1, 2)))
        self.assertEqual('and(a, b, c)', str(and_(expr1, expr2, expr3)))
        self.assertEqual('or(a, b, c)', str(or_(expr1, expr2, expr3)))

        # over-window bound constants
        from pyflink.table.expressions import UNBOUNDED_ROW, UNBOUNDED_RANGE, CURRENT_ROW, \
            CURRENT_RANGE
        self.assertEqual('unboundedRow()', str(UNBOUNDED_ROW))
        self.assertEqual('unboundedRange()', str(UNBOUNDED_RANGE))
        self.assertEqual('currentRow()', str(CURRENT_ROW))
        self.assertEqual('currentRange()', str(CURRENT_RANGE))

        # temporal functions
        self.assertEqual('currentDate()', str(current_date()))
        self.assertEqual('currentTime()', str(current_time()))
        self.assertEqual('currentTimestamp()', str(current_timestamp()))
        self.assertEqual('localTime()', str(local_time()))
        self.assertEqual('localTimestamp()', str(local_timestamp()))
        self.assertEqual('toTimestampLtz(123, 0)', str(to_timestamp_ltz(123, 0)))
        self.assertEqual("temporalOverlaps(cast('2:55:00', TIME(0)), 3600000, "
                         "cast('3:30:00', TIME(0)), 7200000)",
                         str(temporal_overlaps(
                             lit("2:55:00").to_time,
                             lit(1).hours,
                             lit("3:30:00").to_time,
                             lit(2).hours)))
        self.assertEqual("dateFormat(time, '%Y, %d %M')",
                         str(date_format(col("time"), "%Y, %d %M")))
        self.assertEqual("timestampDiff(DAY, cast('2016-06-15', DATE), cast('2016-06-18', DATE))",
                         str(timestamp_diff(
                             TimePointUnit.DAY,
                             lit("2016-06-15").to_date,
                             lit("2016-06-18").to_date)))

        # value constructors
        self.assertEqual('array(1, 2, 3)', str(array(1, 2, 3)))
        self.assertEqual("row('key1', 1)", str(row("key1", 1)))
        self.assertEqual("map('key1', 1, 'key2', 2, 'key3', 3)",
                         str(map_("key1", 1, "key2", 2, "key3", 3)))
        self.assertEqual('4', str(row_interval(4)))

        # math, string and miscellaneous functions
        self.assertEqual('pi()', str(pi()))
        self.assertEqual('e()', str(e()))
        self.assertEqual('rand(4)', str(rand(4)))
        self.assertEqual('randInteger(4)', str(rand_integer(4)))
        self.assertEqual('atan2(1, 2)', str(atan2(1, 2)))
        self.assertEqual('minusPrefix(a)', str(negative(expr1)))
        self.assertEqual('concat(a, b, c)', str(concat(expr1, expr2, expr3)))
        self.assertEqual("concat_ws(', ', b, c)", str(concat_ws(', ', expr2, expr3)))
        self.assertEqual('uuid()', str(uuid()))
        self.assertEqual('null', str(null_of(DataTypes.BIGINT())))
        self.assertEqual('log(a)', str(log(expr1)))
        self.assertEqual('ifThenElse(a, b, c)', str(if_then_else(expr1, expr2, expr3)))
        self.assertEqual('withColumns(a, b, c)', str(with_columns(expr1, expr2, expr3)))
        # call() renders the given function path followed by its arguments
        self.assertEqual('a.b.c(a)', str(call('a.b.c', expr1)))
if __name__ == "__main__":
    # Emit XML reports when the optional xmlrunner package is installed;
    # otherwise let unittest fall back to its default text runner.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports')
    unittest.main(testRunner=testRunner, verbosity=2)