# blob: daa2f19672ebbd716eaa08e5fd95fd2183f13ff0 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import pyarrow as pa
import pytest
from datafusion import SessionContext, column
from datafusion import functions as f
from datafusion import literal
@pytest.fixture
def df():
    """
    Provide a three-row DataFrame with a string column "a" and an
    integer column "b", built from a single RecordBatch
    """
    column_names = ["a", "b"]
    column_data = [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])]
    batch = pa.RecordBatch.from_arrays(column_data, names=column_names)
    # The batch is handed over wrapped in a nested list
    return SessionContext().create_dataframe([[batch]])
def test_literal(df):
    """
    Test that literals of several Python types are repeated on every row
    """
    values = [1, "1", "OK", 3.14, True, b"hello world"]
    projected = df.select(*[literal(value) for value in values])

    batches = projected.collect()
    assert len(batches) == 1
    batch = batches[0]

    # The fixture frame has three rows, so each literal shows up three times
    for position, value in enumerate(values):
        assert batch.column(position) == pa.array([value] * 3)
def test_lit_arith(df):
    """
    Test adding a literal to a column and concatenating a literal string
    """
    incremented = literal(1) + column("b")
    exclaimed = f.concat(column("a"), literal("!"))

    batches = df.select(incremented, exclaimed).collect()
    assert len(batches) == 1
    batch = batches[0]

    expected_columns = [
        pa.array([5, 6, 7]),
        pa.array(["Hello!", "World!", "!!"]),
    ]
    for position, expected in enumerate(expected_columns):
        assert batch.column(position) == expected
def test_math_functions():
    """
    Compare DataFrame math functions against their NumPy counterparts
    """
    raw = [0.1, -0.7, 0.55]
    values = np.array(raw)

    # create a single-column DataFrame holding the same numbers
    batch = pa.RecordBatch.from_arrays([pa.array(raw)], names=["value"])
    frame = SessionContext().create_dataframe([[batch]])

    col_v = column("value")
    # Each entry pairs an expression with the NumPy reference result;
    # the list order is the column order of the selected output.
    # The logarithms are evaluated on value + 1 (np.log of values + 1.0).
    checks = [
        (f.abs(col_v), np.abs(values)),
        (f.sin(col_v), np.sin(values)),
        (f.cos(col_v), np.cos(values)),
        (f.tan(col_v), np.tan(values)),
        (f.asin(col_v), np.arcsin(values)),
        (f.acos(col_v), np.arccos(values)),
        (f.exp(col_v), np.exp(values)),
        (f.ln(col_v + literal(pa.scalar(1))), np.log(values + 1.0)),
        (f.log2(col_v + literal(pa.scalar(1))), np.log2(values + 1.0)),
        (f.log10(col_v + literal(pa.scalar(1))), np.log10(values + 1.0)),
    ]
    expressions = [expression for expression, _ in checks]

    # random() has no reference values, so it goes last and is only
    # checked against an upper bound below
    batches = frame.select(*expressions, f.random()).collect()
    assert len(batches) == 1
    result = batches[0]

    for position, (_, expected) in enumerate(checks):
        np.testing.assert_array_almost_equal(
            result.column(position), expected
        )

    np.testing.assert_array_less(
        result.column(len(checks)), np.ones_like(values)
    )
def test_string_functions(df):
    """
    Test the md5 and lower string functions on column "a"
    """
    hashed = f.md5(column("a"))
    lowered = f.lower(column("a"))

    batches = df.select(hashed, lowered).collect()
    assert len(batches) == 1
    batch = batches[0]

    # md5 output is compared as lowercase hex strings
    expected_md5 = pa.array(
        [
            "8b1a9953c4611296a827abf8c47804d7",
            "f5a7924e621e84c9280a9a27e1bcb7f6",
            "9033e0e305f247c0c3c80d0c7848c8b3",
        ]
    )
    assert batch.column(0) == expected_md5
    assert batch.column(1) == pa.array(["hello", "world", "!"])
def test_hash_functions(df):
    """
    Test digest() on column "a" with several hash algorithm names
    """
    # (algorithm name, expected hex digest per row); the list order is
    # the column order of the selected output.
    expectations = [
        (
            "md5",
            [
                "8B1A9953C4611296A827ABF8C47804D7",
                "F5A7924E621E84C9280A9A27E1BCB7F6",
                "9033E0E305F247C0C3C80D0C7848C8B3",
            ],
        ),
        (
            "sha256",
            [
                (
                    "185F8DB32271FE25F561A6FC938B2E26"
                    "4306EC304EDA518007D1764826381969"
                ),
                (
                    "78AE647DC5544D227130A0682A51E30B"
                    "C7777FBB6D8A8F17007463A3ECD1D524"
                ),
                (
                    "BB7208BC9B5D7C04F1236A82A0093A5E"
                    "33F40423D5BA8D4266F7092C3BA43B62"
                ),
            ],
        ),
        (
            "sha512",
            [
                (
                    "3615F80C9D293ED7402687F94B22D58E"
                    "529B8CC7916F8FAC7FDDF7FBD5AF4CF7"
                    "77D3D795A7A00A16BF7E7F3FB9561EE9"
                    "BAAE480DA9FE7A18769E71886B03F315"
                ),
                (
                    "8EA77393A42AB8FA92500FB077A9509C"
                    "C32BC95E72712EFA116EDAF2EDFAE34F"
                    "BB682EFDD6C5DD13C117E08BD4AAEF71"
                    "291D8AACE2F890273081D0677C16DF0F"
                ),
                (
                    "3831A6A6155E509DEE59A7F451EB3532"
                    "4D8F8F2DF6E3708894740F98FDEE2388"
                    "9F4DE5ADB0C5010DFB555CDA77C8AB5D"
                    "C902094C52DE3278F35A75EBC25F093A"
                ),
            ],
        ),
        (
            "blake2s",
            [
                (
                    "F73A5FBF881F89B814871F46E26AD3FA"
                    "37CB2921C5E8561618639015B3CCBB71"
                ),
                (
                    "B792A0383FB9E7A189EC150686579532"
                    "854E44B71AC394831DAED169BA85CCC5"
                ),
                (
                    "27988A0E51812297C77A433F63523334"
                    "6AEE29A829DCF4F46E0F58F402C6CFCB"
                ),
            ],
        ),
        (
            "blake3",
            [
                (
                    "FBC2B0516EE8744D293B980779178A35"
                    "08850FDCFE965985782C39601B65794F"
                ),
                (
                    "BF73D18575A736E4037D45F9E316085B"
                    "86C19BE6363DE6AA789E13DEAACC1C4E"
                ),
                (
                    "C8D11B9F7237E4034ADBCD2005735F9B"
                    "C4C597C75AD89F4492BEC8F77D15F7EB"
                ),
            ],
        ),
    ]

    expressions = [
        f.digest(column("a"), literal(algorithm))
        for algorithm, _ in expectations
    ]
    batches = df.select(*expressions).collect()
    assert len(batches) == 1
    batch = batches[0]

    # The digest columns are compared as raw bytes, so decode the hex first
    for position, (_, hex_digests) in enumerate(expectations):
        assert batch.column(position) == pa.array(
            [bytearray.fromhex(digest) for digest in hex_digests]
        )