| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import math |
| |
| import numpy as np |
| import pyarrow as pa |
| import pytest |
| from datetime import datetime |
| |
| from datafusion import SessionContext, column |
| from datafusion import functions as f |
| from datafusion import literal |
| |
| |
@pytest.fixture
def df():
    """Provide the three-row DataFrame shared by the tests in this module.

    Columns:
        a: the strings "Hello", "World", "!"
        b: the integers 4, 5, 6
        c: strings carrying leading and/or trailing spaces
        d: timestamps built from naive ``datetime`` values
    """
    columns = {
        "a": pa.array(["Hello", "World", "!"]),
        "b": pa.array([4, 5, 6]),
        "c": pa.array(["hello ", " world ", " !"]),
        "d": pa.array(
            [
                datetime(2022, 12, 31),
                datetime(2027, 6, 26),
                datetime(2020, 7, 2),
            ]
        ),
    }
    record_batch = pa.RecordBatch.from_arrays(
        list(columns.values()),
        names=list(columns),
    )
    # create_dataframe receives the single batch wrapped in a nested list.
    return SessionContext().create_dataframe([[record_batch]])
| |
| |
def test_literal(df):
    """Literals of several Python types are repeated once per input row."""
    scalars = [1, "1", "OK", 3.14, True, b"hello world"]

    batches = df.select(*[literal(value) for value in scalars]).collect()
    assert len(batches) == 1

    batch = batches[0]
    for position, value in enumerate(scalars):
        # The fixture has three rows, so each constant shows up three times.
        assert batch.column(position) == pa.array([value] * 3)
| |
| |
def test_lit_arith(df):
    """
    Combine literals with columns: integer addition and string concat.
    """
    incremented = literal(1) + column("b")
    exclaimed = f.concat(column("a"), literal("!"))

    batches = df.select(incremented, exclaimed).collect()
    assert len(batches) == 1

    batch = batches[0]
    assert batch.column(0) == pa.array([5, 6, 7])
    assert batch.column(1) == pa.array(["Hello!", "World!", "!!"])
| |
| |
def test_math_functions():
    """Compare the math functions against their NumPy/``math`` counterparts.

    Every case is a triple of (expression, NumPy assertion helper, expected
    value).  The expressions are projected in list order and each output
    column is checked with the helper and expectation from the same triple.
    """
    values = np.array([0.1, -0.7, 0.55])
    na_values = np.array([np.nan, 0, 2.0])
    shifted = values + 1.0

    batch = pa.RecordBatch.from_arrays(
        [pa.array([0.1, -0.7, 0.55]), pa.array([float("nan"), 0, 2.0])],
        names=["value", "na_value"],
    )
    frame = SessionContext().create_dataframe([[batch]])

    v = column("value")
    nav = column("na_value")
    v_plus_one = v + literal(pa.scalar(1))

    almost_equal = np.testing.assert_array_almost_equal
    strictly_less = np.testing.assert_array_less

    # NOTE(review): some inputs lie outside the domain of sqrt/acosh; the
    # NumPy side yields NaN there and presumably so does the engine, which
    # assert_array_almost_equal accepts when the NaNs line up.
    cases = [
        (f.abs(v), almost_equal, np.abs(values)),
        (f.sin(v), almost_equal, np.sin(values)),
        (f.cos(v), almost_equal, np.cos(values)),
        (f.tan(v), almost_equal, np.tan(values)),
        (f.asin(v), almost_equal, np.arcsin(values)),
        (f.acos(v), almost_equal, np.arccos(values)),
        (f.exp(v), almost_equal, np.exp(values)),
        (f.ln(v_plus_one), almost_equal, np.log(shifted)),
        (f.log2(v_plus_one), almost_equal, np.log2(shifted)),
        (f.log10(v_plus_one), almost_equal, np.log10(shifted)),
        # random() has no fixed value; only the upper bound is checked.
        (f.random(), strictly_less, np.ones_like(values)),
        (f.atan(v), almost_equal, np.arctan(values)),
        (
            f.atan2(v, literal(pa.scalar(1.1))),
            almost_equal,
            np.arctan2(values, 1.1),
        ),
        (f.ceil(v), almost_equal, np.ceil(values)),
        (f.floor(v), almost_equal, np.floor(values)),
        (f.power(v, literal(pa.scalar(3))), almost_equal, np.power(values, 3)),
        (f.pow(v, literal(pa.scalar(4))), almost_equal, np.power(values, 4)),
        (f.round(v), almost_equal, np.round(values)),
        (f.sqrt(v), almost_equal, np.sqrt(values)),
        (f.signum(v), almost_equal, np.sign(values)),
        (f.trunc(v), almost_equal, np.trunc(values)),
        (f.asinh(v), almost_equal, np.arcsinh(values)),
        (f.acosh(v), almost_equal, np.arccosh(values)),
        (f.atanh(v), almost_equal, np.arctanh(values)),
        (f.cbrt(v), almost_equal, np.cbrt(values)),
        (f.cosh(v), almost_equal, np.cosh(values)),
        (f.degrees(v), almost_equal, np.degrees(values)),
        (f.gcd(literal(9), literal(3)), almost_equal, np.gcd(9, 3)),
        (f.lcm(literal(6), literal(4)), almost_equal, np.lcm(6, 4)),
        (
            f.nanvl(nav, literal(5)),
            almost_equal,
            np.where(np.isnan(na_values), 5, na_values),
        ),
        (f.pi(), almost_equal, np.pi),
        (f.radians(v), almost_equal, np.radians(values)),
        (f.sinh(v), almost_equal, np.sinh(values)),
        (f.tanh(v), almost_equal, np.tanh(values)),
        (f.factorial(literal(6)), almost_equal, math.factorial(6)),
        (f.isnan(nav), almost_equal, np.isnan(na_values)),
        (f.iszero(nav), almost_equal, na_values == 0),
        (f.log(literal(3), v_plus_one), almost_equal, np.emath.logn(3, shifted)),
    ]

    batches = frame.select(*[expr for expr, _, _ in cases]).collect()
    assert len(batches) == 1
    result = batches[0]

    for index, (_, check, expected) in enumerate(cases):
        check(result.column(index), expected)
| |
| |
def test_string_functions(df):
    """Check the scalar string functions against hand-computed results.

    Column ``a`` holds ("Hello", "World", "!") and column ``c`` holds the
    space-padded ("hello ", " world ", " !").  Each case pairs an expression
    with the array it must produce; the expressions are projected in list
    order and every output column is compared with its paired expectation.
    """
    a = column("a")
    c = column("c")
    words = ["Hello", "World", "!"]
    int32 = pa.int32()

    cases = [
        (f.ascii(a), pa.array([72, 87, 33], type=int32)),  # H, W, !
        (f.bit_length(a), pa.array([40, 40, 8], type=int32)),
        (f.btrim(literal(" World ")), pa.array(["World", "World", "World"])),
        (f.character_length(a), pa.array([5, 5, 1], type=int32)),
        (f.chr(literal(68)), pa.array(["D", "D", "D"])),
        (
            f.concat_ws("-", a, literal("test")),
            pa.array(["Hello-test", "World-test", "!-test"]),
        ),
        (f.concat(a, literal("?")), pa.array(["Hello?", "World?", "!?"])),
        (f.initcap(c), pa.array(["Hello ", " World ", " !"])),
        (f.left(a, literal(3)), pa.array(["Hel", "Wor", "!"])),
        (f.length(c), pa.array([6, 7, 2], type=int32)),
        (f.lower(a), pa.array(["hello", "world", "!"])),
        # lpad/rpad fill with spaces up to the requested width.  The
        # expectations are derived with str.rjust/str.ljust so each entry is
        # exactly 7 (resp. 8) characters wide rather than relying on a
        # hand-typed run of spaces.
        (f.lpad(a, literal(7)), pa.array([word.rjust(7) for word in words])),
        (f.ltrim(c), pa.array(["hello ", "world ", "!"])),
        (
            f.md5(a),
            pa.array(
                [
                    "8b1a9953c4611296a827abf8c47804d7",
                    "f5a7924e621e84c9280a9a27e1bcb7f6",
                    "9033e0e305f247c0c3c80d0c7848c8b3",
                ]
            ),
        ),
        (f.octet_length(a), pa.array([5, 5, 1], type=int32)),
        (
            f.repeat(a, literal(2)),
            pa.array(["HelloHello", "WorldWorld", "!!"]),
        ),
        (
            f.replace(a, literal("l"), literal("?")),
            pa.array(["He??o", "Wor?d", "!"]),
        ),
        (f.reverse(a), pa.array(["olleH", "dlroW", "!"])),
        (f.right(a, literal(4)), pa.array(["ello", "orld", "!"])),
        (f.rpad(a, literal(8)), pa.array([word.ljust(8) for word in words])),
        (f.rtrim(c), pa.array(["hello", " world", " !"])),
        (
            f.split_part(a, literal("l"), literal(1)),
            pa.array(["He", "Wor", "!"]),
        ),
        (f.starts_with(a, literal("Wor")), pa.array([False, True, False])),
        (f.strpos(a, literal("o")), pa.array([5, 2, 0], type=int32)),
        (f.substr(a, literal(3)), pa.array(["llo", "rld", ""])),
        (
            f.translate(a, literal("or"), literal("ld")),
            pa.array(["Helll", "Wldld", "!"]),
        ),
        (f.trim(c), pa.array(["hello", "world", "!"])),
        (f.upper(c), pa.array(["HELLO ", " WORLD ", " !"])),
    ]

    batches = df.select(*[expr for expr, _ in cases]).collect()
    assert len(batches) == 1
    batch = batches[0]

    for position, (_, expected) in enumerate(cases):
        assert batch.column(position) == expected
| |
| |
def test_hash_functions(df):
    """Verify ``digest`` per algorithm and the dedicated SHA-2 helpers.

    ``digest`` is applied to column ``a`` once per algorithm name, followed
    by ``sha224``/``sha256``/``sha384``/``sha512`` on the same column.  The
    ``digest`` columns are compared with fixed hex digests and the helper
    columns with the matching ``digest`` columns.
    """
    # Expected digests for the three rows of column ``a``, keyed by the
    # algorithm name handed to digest(); dict order is the column order.
    expected_hex = {
        "md5": [
            "8B1A9953C4611296A827ABF8C47804D7",
            "F5A7924E621E84C9280A9A27E1BCB7F6",
            "9033E0E305F247C0C3C80D0C7848C8B3",
        ],
        "sha224": [
            "4149DA18AA8BFC2B1E382C6C26556D01A92C261B6436DAD5E3BE3FCC",
            "12972632B6D3B6AA52BD6434552F08C1303D56B817119406466E9236",
            "6641A7E8278BCD49E476E7ACAE158F4105B2952D22AEB2E0B9A231A0",
        ],
        "sha256": [
            "185F8DB32271FE25F561A6FC938B2E26"
            "4306EC304EDA518007D1764826381969",
            "78AE647DC5544D227130A0682A51E30B"
            "C7777FBB6D8A8F17007463A3ECD1D524",
            "BB7208BC9B5D7C04F1236A82A0093A5E"
            "33F40423D5BA8D4266F7092C3BA43B62",
        ],
        "sha384": [
            "3519FE5AD2C596EFE3E276A6F351B8FC"
            "0B03DB861782490D45F7598EBD0AB5FD"
            "5520ED102F38C4A5EC834E98668035FC",
            "ED7CED84875773603AF90402E42C65F3"
            "B48A5E77F84ADC7A19E8F3E8D3101010"
            "22F552AEC70E9E1087B225930C1D260A",
            "1D0EC8C84EE9521E21F06774DE232367"
            "B64DE628474CB5B2E372B699A1F55AE3"
            "35CC37193EF823E33324DFD9A70738A6",
        ],
        "sha512": [
            "3615F80C9D293ED7402687F94B22D58E"
            "529B8CC7916F8FAC7FDDF7FBD5AF4CF7"
            "77D3D795A7A00A16BF7E7F3FB9561EE9"
            "BAAE480DA9FE7A18769E71886B03F315",
            "8EA77393A42AB8FA92500FB077A9509C"
            "C32BC95E72712EFA116EDAF2EDFAE34F"
            "BB682EFDD6C5DD13C117E08BD4AAEF71"
            "291D8AACE2F890273081D0677C16DF0F",
            "3831A6A6155E509DEE59A7F451EB3532"
            "4D8F8F2DF6E3708894740F98FDEE2388"
            "9F4DE5ADB0C5010DFB555CDA77C8AB5D"
            "C902094C52DE3278F35A75EBC25F093A",
        ],
        "blake2s": [
            "F73A5FBF881F89B814871F46E26AD3FA"
            "37CB2921C5E8561618639015B3CCBB71",
            "B792A0383FB9E7A189EC150686579532"
            "854E44B71AC394831DAED169BA85CCC5",
            "27988A0E51812297C77A433F63523334"
            "6AEE29A829DCF4F46E0F58F402C6CFCB",
        ],
        "blake3": [
            "FBC2B0516EE8744D293B980779178A35"
            "08850FDCFE965985782C39601B65794F",
            "BF73D18575A736E4037D45F9E316085B"
            "86C19BE6363DE6AA789E13DEAACC1C4E",
            "C8D11B9F7237E4034ADBCD2005735F9B"
            "C4C597C75AD89F4492BEC8F77D15F7EB",
        ],
    }

    digest_exprs = [
        f.digest(column("a"), literal(algorithm))
        for algorithm in expected_hex
    ]
    batches = df.select(
        *digest_exprs,
        f.sha224(column("a")),
        f.sha256(column("a")),
        f.sha384(column("a")),
        f.sha512(column("a")),
    ).collect()
    assert len(batches) == 1
    batch = batches[0]

    for position, hex_digests in enumerate(expected_hex.values()):
        expected = pa.array(
            [bytearray.fromhex(hex_digest) for hex_digest in hex_digests]
        )
        assert batch.column(position) == expected

    # Columns 7..10 (sha224, sha256, sha384, sha512) must equal the digest()
    # columns 1..4 computed for the same algorithms.
    for digest_position in range(1, 5):
        helper_position = digest_position + 6
        assert batch.column(helper_position) == batch.column(digest_position)
| |
| |
def test_temporal_functions(df):
    """Check date-part extraction, truncation, binning and timestamp parsing.

    Column ``d`` feeds the date_part/date_trunc/date_bin calls; the
    ``to_timestamp*`` variants all parse the same string literal.
    """

    def timestamps(moments, unit):
        """Return a timestamp array with the given unit."""
        return pa.array(moments, type=pa.timestamp(unit))

    stamp = column("d")
    text = literal("2023-09-07 05:06:14.523952")

    batches = df.select(
        f.date_part(literal("month"), stamp),
        f.datepart(literal("year"), stamp),
        f.date_trunc(literal("month"), stamp),
        f.datetrunc(literal("day"), stamp),
        f.date_bin(
            literal("15 minutes"),
            stamp,
            literal("2001-01-01 00:02:30"),
        ),
        f.from_unixtime(literal(1673383974)),
        f.to_timestamp(text),
        f.to_timestamp_seconds(text),
        f.to_timestamp_millis(text),
        f.to_timestamp_micros(text),
    ).collect()
    assert len(batches) == 1
    batch = batches[0]

    parsed = datetime(2023, 9, 7, 5, 6, 14, 523952)
    expected_columns = [
        pa.array([12, 6, 7], type=pa.float64()),
        pa.array([2022, 2027, 2020], type=pa.float64()),
        timestamps(
            [datetime(2022, 12, 1), datetime(2027, 6, 1), datetime(2020, 7, 1)],
            "us",
        ),
        timestamps(
            [
                datetime(2022, 12, 31),
                datetime(2027, 6, 26),
                datetime(2020, 7, 2),
            ],
            "us",
        ),
        timestamps(
            [
                datetime(2022, 12, 30, 23, 47, 30),
                datetime(2027, 6, 25, 23, 47, 30),
                datetime(2020, 7, 1, 23, 47, 30),
            ],
            "ns",
        ),
        timestamps([datetime(2023, 1, 10, 20, 52, 54)] * 3, "s"),
        timestamps([parsed] * 3, "ns"),
        # Second and millisecond variants are expected without the
        # finer-grained fraction.
        timestamps([parsed.replace(microsecond=0)] * 3, "s"),
        timestamps([parsed.replace(microsecond=523000)] * 3, "ms"),
        timestamps([parsed] * 3, "us"),
    ]
    for position, expected in enumerate(expected_columns):
        assert batch.column(position) == expected
| |
| |
def test_case(df):
    """Cover CASE expressions closed by ``otherwise`` and by a bare ``end``."""

    def translation():
        # Built fresh on every call so the two expressions derived from it
        # below do not share a builder object.
        return (
            f.case(column("a"))
            .when(literal("Hello"), literal("Hola"))
            .when(literal("World"), literal("Mundo"))
        )

    numeric = (
        f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8))
    )
    with_default = translation().otherwise(literal("!!"))
    without_default = translation().end()

    batch = df.select(numeric, with_default, without_default).collect()[0]

    assert batch.column(0) == pa.array([10, 8, 8])
    assert batch.column(1) == pa.array(["Hola", "Mundo", "!!"])
    # With no fallback branch the unmatched row is expected to be null.
    assert batch.column(2) == pa.array(["Hola", "Mundo", None])
| |
| |
def test_regr_funcs(df):
    """Run every ``regr_*`` aggregate over the single point (1, 1) via SQL.

    Test case based on
    https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330

    The ``df`` argument is not used; the query runs on a fresh context.
    """
    query = (
        "select regr_slope(1,1), regr_intercept(1,1), "
        "regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), "
        "regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), "
        "regr_sxy(1,1);"
    )
    batch = SessionContext().sql(query).collect()[0]

    # One expected value per selected aggregate, in query order.
    expected_values = [None, None, 1, None, 1, 1, 0, 0, 0]
    for position, value in enumerate(expected_values):
        assert batch.column(position) == pa.array([value], type=pa.float64())
| |
| |
def test_first_last_value(df):
    """``first_value``/``last_value`` over columns a, b and d, ungrouped."""
    tracked = ("a", "b", "d")
    firsts = [f.first_value(column(name)) for name in tracked]
    lasts = [f.last_value(column(name)) for name in tracked]

    # An empty group-by list aggregates the whole frame.
    batch = df.aggregate([], firsts + lasts).collect()[0]

    expected_columns = [
        pa.array(["Hello"]),
        pa.array([4]),
        pa.array([datetime(2022, 12, 31)]),
        pa.array(["!"]),
        pa.array([6]),
        pa.array([datetime(2020, 7, 2)]),
    ]
    for position, expected in enumerate(expected_columns):
        assert batch.column(position) == expected
| |
| |
def test_binary_string_functions(df):
    """Round-trip column ``a`` through base64 ``encode`` and ``decode``."""
    encoded = f.encode(column("a"), literal("base64"))
    round_trip = f.decode(
        f.encode(column("a"), literal("base64")), literal("base64")
    )

    batches = df.select(encoded, round_trip).collect()
    assert len(batches) == 1
    batch = batches[0]

    assert batch.column(0) == pa.array(["SGVsbG8", "V29ybGQ", "IQ"])

    # The decoded column is cast to string before it is compared with the
    # original text.
    decoded_text = pa.array(batch.column(1)).cast(pa.string())
    assert decoded_text == pa.array(["Hello", "World", "!"])