| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| import datetime as dt |
| import gzip |
| import pathlib |
| |
| import pyarrow as pa |
| import pyarrow.dataset as ds |
| import pytest |
| from datafusion import ( |
| DataFrame, |
| RuntimeEnvBuilder, |
| SessionConfig, |
| SessionContext, |
| SQLOptions, |
| column, |
| literal, |
| ) |
| |
| |
| def test_create_context_no_args(): |
| SessionContext() |
| |
| |
| def test_create_context_session_config_only(): |
| SessionContext(config=SessionConfig()) |
| |
| |
| def test_create_context_runtime_config_only(): |
| SessionContext(runtime=RuntimeEnvBuilder()) |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_runtime_configs(tmp_path, path_to_str): |
| path1 = tmp_path / "dir1" |
| path2 = tmp_path / "dir2" |
| |
| path1 = str(path1) if path_to_str else path1 |
| path2 = str(path2) if path_to_str else path2 |
| |
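    # use explicit directories for the disk manager's temporary (spill) files
    # and override the default catalog/schema names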
| runtime = RuntimeEnvBuilder().with_disk_manager_specified(path1, path2) |
| config = SessionConfig().with_default_catalog_and_schema("foo", "bar") |
| ctx = SessionContext(config, runtime) |
| assert ctx is not None |
| |
| db = ctx.catalog("foo").schema("bar") |
| assert db is not None |
| |
| |
| @pytest.mark.parametrize("path_to_str", [True, False]) |
| def test_temporary_files(tmp_path, path_to_str): |
| path = str(tmp_path) if path_to_str else tmp_path |
| |
| runtime = RuntimeEnvBuilder().with_temp_file_path(path) |
| config = SessionConfig().with_default_catalog_and_schema("foo", "bar") |
| ctx = SessionContext(config, runtime) |
| assert ctx is not None |
| |
| db = ctx.catalog("foo").schema("bar") |
| assert db is not None |
| |
| |
| def test_create_context_with_all_valid_args(): |
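    # an OS-backed disk manager plus a fair spill pool of 10,000,000 bytes (~10 MB)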
| runtime = RuntimeEnvBuilder().with_disk_manager_os().with_fair_spill_pool(10000000) |
| config = ( |
| SessionConfig() |
| .with_create_default_catalog_and_schema(enabled=True) |
| .with_default_catalog_and_schema("foo", "bar") |
| .with_target_partitions(1) |
| .with_information_schema(enabled=True) |
| .with_repartition_joins(enabled=False) |
| .with_repartition_aggregations(enabled=False) |
| .with_repartition_windows(enabled=False) |
| .with_parquet_pruning(enabled=False) |
| ) |
| |
| ctx = SessionContext(config, runtime) |
| |
| # verify that at least some of the arguments worked |
| ctx.catalog("foo").schema("bar") |
| with pytest.raises(KeyError): |
| ctx.catalog("datafusion") |
| |
| |
| def test_register_record_batches(ctx): |
    # create a RecordBatch and register it as a memtable
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| ctx.register_record_batches("t", [[batch]]) |
| |
| assert ctx.catalog().schema().names() == {"t"} |
| |
| result = ctx.sql("SELECT a+b, a-b FROM t").collect() |
| |
| assert result[0].column(0) == pa.array([5, 7, 9]) |
| assert result[0].column(1) == pa.array([-3, -3, -3]) |
| |
| |
| def test_create_dataframe_registers_unique_table_name(ctx): |
    # create a RecordBatch and build a DataFrame from it
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| df = ctx.create_dataframe([[batch]]) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
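    # auto-generated table names have the form "c" + 32 hexadecimal characters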
| assert len(tables[0]) == 33 |
| assert tables[0].startswith("c") |
| # ensure that the rest of the table name contains |
| # only hexadecimal numbers |
| for c in tables[0][1:]: |
| assert c in "0123456789abcdef" |
| |
| |
| def test_create_dataframe_registers_with_defined_table_name(ctx): |
    # create a RecordBatch and build a DataFrame from it under an explicit name
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| |
| df = ctx.create_dataframe([[batch]], name="tbl") |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert tables[0] == "tbl" |
| |
| |
| def test_from_arrow_table(ctx): |
| # create a PyArrow table |
| data = {"a": [1, 2, 3], "b": [4, 5, 6]} |
| table = pa.Table.from_pydict(data) |
| |
| # convert to DataFrame |
| df = ctx.from_arrow(table) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert df.collect()[0].num_rows == 3 |
| |
| |
| def record_batch_generator(num_batches: int): |
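    """Yield ``num_batches`` RecordBatches with a fixed (a, b) int64 schema."""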
| schema = pa.schema([("a", pa.int64()), ("b", pa.int64())]) |
| for _i in range(num_batches): |
| yield pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema |
| ) |
| |
| |
| @pytest.mark.parametrize( |
| "source", |
| [ |
| # __arrow_c_array__ sources |
| pa.array([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]), |
| # __arrow_c_stream__ sources |
| pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}), |
| pa.RecordBatchReader.from_batches( |
| pa.schema([("a", pa.int64()), ("b", pa.int64())]), record_batch_generator(1) |
| ), |
| pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}), |
| ], |
| ) |
| def test_from_arrow_sources(ctx, source) -> None: |
| df = ctx.from_arrow(source) |
| assert df |
| assert isinstance(df, DataFrame) |
| assert df.schema().names == ["a", "b"] |
| assert df.count() == 3 |
| |
| |
| def test_from_arrow_table_with_name(ctx): |
| # create a PyArrow table |
| data = {"a": [1, 2, 3], "b": [4, 5, 6]} |
| table = pa.Table.from_pydict(data) |
| |
| # convert to DataFrame with optional name |
| df = ctx.from_arrow(table, name="tbl") |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert tables[0] == "tbl" |
| |
| |
| def test_from_arrow_table_empty(ctx): |
| data = {"a": [], "b": []} |
| schema = pa.schema([("a", pa.int32()), ("b", pa.string())]) |
| table = pa.Table.from_pydict(data, schema=schema) |
| |
| # convert to DataFrame |
| df = ctx.from_arrow(table) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert len(df.collect()) == 0 |
| |
| |
| def test_from_arrow_table_empty_no_schema(ctx): |
| data = {"a": [], "b": []} |
| table = pa.Table.from_pydict(data) |
| |
| # convert to DataFrame |
| df = ctx.from_arrow(table) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert len(df.collect()) == 0 |
| |
| |
| def test_from_pylist(ctx): |
    # create a dataframe from a Python list
| data = [ |
| {"a": 1, "b": 4}, |
| {"a": 2, "b": 5}, |
| {"a": 3, "b": 6}, |
| ] |
| |
| df = ctx.from_pylist(data) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert df.collect()[0].num_rows == 3 |
| |
| |
| def test_from_pydict(ctx): |
    # create a dataframe from a Python dictionary
| data = {"a": [1, 2, 3], "b": [4, 5, 6]} |
| |
| df = ctx.from_pydict(data) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert df.collect()[0].num_rows == 3 |
| |
| |
| def test_from_pandas(ctx): |
    # create a dataframe from a pandas DataFrame
| pd = pytest.importorskip("pandas") |
| data = {"a": [1, 2, 3], "b": [4, 5, 6]} |
| pandas_df = pd.DataFrame(data) |
| |
| df = ctx.from_pandas(pandas_df) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert df.collect()[0].num_rows == 3 |
| |
| |
def test_from_polars(ctx):
    # create a dataframe from a Polars DataFrame
    pl = pytest.importorskip("polars")
    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
    polars_df = pl.DataFrame(data)
| |
| df = ctx.from_polars(polars_df) |
| tables = list(ctx.catalog().schema().names()) |
| |
| assert df |
| assert len(tables) == 1 |
| assert isinstance(df, DataFrame) |
| assert set(df.schema().names) == {"a", "b"} |
| assert df.collect()[0].num_rows == 3 |
| |
| |
| def test_register_table(ctx, database): |
| default = ctx.catalog() |
| public = default.schema("public") |
| assert public.names() == {"csv", "csv1", "csv2"} |
| table = public.table("csv") |
| |
| ctx.register_table("csv3", table) |
| assert public.names() == {"csv", "csv1", "csv2", "csv3"} |
| |
| |
| def test_read_table(ctx, database): |
| default = ctx.catalog() |
| public = default.schema("public") |
| assert public.names() == {"csv", "csv1", "csv2"} |
| |
| table = public.table("csv") |
| table_df = ctx.read_table(table) |
| table_df.show() |
| |
| |
| def test_deregister_table(ctx, database): |
| default = ctx.catalog() |
| public = default.schema("public") |
| assert public.names() == {"csv", "csv1", "csv2"} |
| |
| ctx.deregister_table("csv") |
| assert public.names() == {"csv1", "csv2"} |
| |
| |
| def test_register_dataset(ctx): |
| # create a RecordBatch and register it as a pyarrow.dataset.Dataset |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset("t", dataset) |
| |
| assert ctx.catalog().schema().names() == {"t"} |
| |
| result = ctx.sql("SELECT a+b, a-b FROM t").collect() |
| |
| assert result[0].column(0) == pa.array([5, 7, 9]) |
| assert result[0].column(1) == pa.array([-3, -3, -3]) |
| |
| |
| def test_dataset_filter(ctx, capfd): |
| # create a RecordBatch and register it as a pyarrow.dataset.Dataset |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset("t", dataset) |
| |
| assert ctx.catalog().schema().names() == {"t"} |
| df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5") |
| |
| # Make sure the filter was pushed down in Physical Plan |
| df.explain() |
| captured = capfd.readouterr() |
| assert "filter_expr=(((a >= 2) and (a <= 3)) and (b > 5))" in captured.out |
| |
| result = df.collect() |
| |
| assert result[0].column(0) == pa.array([9]) |
| assert result[0].column(1) == pa.array([-3]) |
| |
| |
| def test_dataset_count(ctx): |
| # `datafusion-python` issue: https://github.com/apache/datafusion-python/issues/800 |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset("t", dataset) |
| |
| # Testing the dataframe API |
| df = ctx.table("t") |
| assert df.count() == 3 |
| |
| # Testing the SQL API |
| count = ctx.sql("SELECT COUNT(*) FROM t") |
| count = count.collect() |
| assert count[0].column(0) == pa.array([3]) |
| |
| |
| def test_pyarrow_predicate_pushdown_is_null(ctx, capfd): |
| """Ensure that pyarrow filter gets pushed down for `IsNull`""" |
| # create a RecordBatch and register it as a pyarrow.dataset.Dataset |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([7, None, 9])], |
| names=["a", "b", "c"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset("t", dataset) |
| # Make sure the filter was pushed down in Physical Plan |
| df = ctx.sql("SELECT a FROM t WHERE c is NULL") |
| df.explain() |
| captured = capfd.readouterr() |
| assert "filter_expr=is_null(c, {nan_is_null=false})" in captured.out |
| |
| result = df.collect() |
| assert result[0].column(0) == pa.array([2]) |
| |
| |
| def test_pyarrow_predicate_pushdown_timestamp(ctx, tmpdir, capfd): |
| """Ensure that pyarrow filter gets pushed down for timestamp""" |
| # Ref: https://github.com/apache/datafusion-python/issues/703 |
| |
| # create pyarrow dataset with no actual files |
| col_type = pa.timestamp("ns", "+00:00") |
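    # New Year's Day 2000 (UTC), used as the partition cutoff for the fragment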
| nyd_2000 = pa.scalar(dt.datetime(2000, 1, 1, tzinfo=dt.timezone.utc), col_type) |
| pa_dataset_fs = pa.fs.SubTreeFileSystem(str(tmpdir), pa.fs.LocalFileSystem()) |
    pa_dataset_format = ds.ParquetFileFormat()
    pa_dataset_partition = ds.field("a") <= nyd_2000
| fragments = [ |
| # NOTE: we never actually make this file. |
| # Working predicate pushdown means it never gets accessed |
| pa_dataset_format.make_fragment( |
| "1.parquet", |
| filesystem=pa_dataset_fs, |
| partition_expression=pa_dataset_partition, |
| ) |
| ] |
    pa_dataset = ds.FileSystemDataset(
| fragments, |
| pa.schema([pa.field("a", col_type)]), |
| pa_dataset_format, |
| pa_dataset_fs, |
| ) |
| |
| ctx.register_dataset("t", pa_dataset) |
| |
    # the partition for our only fragment covers a <= 2000-01-01,
    # so querying for a > 2024-01-01 should not touch any files
| df = ctx.sql("SELECT * FROM t WHERE a > '2024-01-01T00:00:00+00:00'") |
| assert df.collect() == [] |
| |
| |
| def test_dataset_filter_nested_data(ctx): |
| # create Arrow StructArrays to test nested data types |
| data = pa.StructArray.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| batch = pa.RecordBatch.from_arrays( |
| [data], |
| names=["nested_data"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset("t", dataset) |
| |
| assert ctx.catalog().schema().names() == {"t"} |
| |
| df = ctx.table("t") |
| |
| # This filter will not be pushed down to DatasetExec since it |
| # isn't supported |
| df = df.filter(column("nested_data")["b"] > literal(5)).select( |
| column("nested_data")["a"] + column("nested_data")["b"], |
| column("nested_data")["a"] - column("nested_data")["b"], |
| ) |
| |
| result = df.collect() |
| |
| assert result[0].column(0) == pa.array([9]) |
| assert result[0].column(1) == pa.array([-3]) |
| |
| |
| def test_table_exist(ctx): |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset("t", dataset) |
| |
| assert ctx.table_exist("t") is True |
| |
| |
| def test_table_not_found(ctx): |
| from uuid import uuid4 |
| |
| with pytest.raises(KeyError): |
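        # the random suffix guarantees the table name is not registered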
| ctx.table(f"not-found-{uuid4()}") |
| |
| |
| def test_read_json(ctx): |
| path = pathlib.Path(__file__).parent.resolve() |
| |
| # Default |
| test_data_path = path / "data_test_context" / "data.json" |
| df = ctx.read_json(test_data_path) |
| result = df.collect() |
| |
| assert result[0].column(0) == pa.array(["a", "b", "c"]) |
| assert result[0].column(1) == pa.array([1, 2, 3]) |
| |
| # Schema |
| schema = pa.schema( |
| [ |
| pa.field("A", pa.string(), nullable=True), |
| ] |
| ) |
| df = ctx.read_json(test_data_path, schema=schema) |
| result = df.collect() |
| |
| assert result[0].column(0) == pa.array(["a", "b", "c"]) |
| assert result[0].schema == schema |
| |
| # File extension |
| test_data_path = path / "data_test_context" / "data.json" |
| df = ctx.read_json(test_data_path, file_extension=".json") |
| result = df.collect() |
| |
| assert result[0].column(0) == pa.array(["a", "b", "c"]) |
| assert result[0].column(1) == pa.array([1, 2, 3]) |
| |
| |
| def test_read_json_compressed(ctx, tmp_path): |
| path = pathlib.Path(__file__).parent.resolve() |
| test_data_path = path / "data_test_context" / "data.json" |
| |
| # File compression type |
| gzip_path = tmp_path / "data.json.gz" |
| |
    with (
        pathlib.Path.open(test_data_path, "rb") as json_file,
        gzip.open(gzip_path, "wb") as gzipped_file,
    ):
        gzipped_file.writelines(json_file)
| |
| df = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz") |
| result = df.collect() |
| |
| assert result[0].column(0) == pa.array(["a", "b", "c"]) |
| assert result[0].column(1) == pa.array([1, 2, 3]) |
| |
| |
| def test_read_csv(ctx): |
| csv_df = ctx.read_csv(path="testing/data/csv/aggregate_test_100.csv") |
| csv_df.select(column("c1")).show() |
| |
| |
| def test_read_csv_list(ctx): |
| csv_df = ctx.read_csv(path=["testing/data/csv/aggregate_test_100.csv"]) |
| expected = csv_df.count() * 2 |
| |
| double_csv_df = ctx.read_csv( |
| path=[ |
| "testing/data/csv/aggregate_test_100.csv", |
| "testing/data/csv/aggregate_test_100.csv", |
| ] |
| ) |
| actual = double_csv_df.count() |
| |
| double_csv_df.select(column("c1")).show() |
| assert actual == expected |
| |
| |
| def test_read_csv_compressed(ctx, tmp_path): |
| test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv") |
| |
| # File compression type |
| gzip_path = tmp_path / "aggregate_test_100.csv.gz" |
| |
| with ( |
| pathlib.Path.open(test_data_path, "rb") as csv_file, |
| gzip.open(gzip_path, "wb") as gzipped_file, |
| ): |
| gzipped_file.writelines(csv_file) |
| |
| csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz") |
| csv_df.select(column("c1")).show() |
| |
| |
| def test_read_parquet(ctx): |
| parquet_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet") |
| parquet_df.show() |
| assert parquet_df is not None |
| |
| path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet" |
| parquet_df = ctx.read_parquet(path=path) |
| assert parquet_df is not None |
| |
| |
| def test_read_avro(ctx): |
| avro_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro") |
| avro_df.show() |
| assert avro_df is not None |
| |
| path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro" |
| avro_df = ctx.read_avro(path=path) |
| assert avro_df is not None |
| |
| |
| def test_create_sql_options(): |
| SQLOptions() |
| |
| |
| def test_sql_with_options_no_ddl(ctx): |
| sql = "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')" |
| ctx.sql(sql) |
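    # the same statement must be rejected once DDL is disallowed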
| options = SQLOptions().with_allow_ddl(allow=False) |
| with pytest.raises(Exception, match="DDL"): |
| ctx.sql_with_options(sql, options=options) |
| |
| |
| def test_sql_with_options_no_dml(ctx): |
| table_name = "t" |
| batch = pa.RecordBatch.from_arrays( |
| [pa.array([1, 2, 3]), pa.array([4, 5, 6])], |
| names=["a", "b"], |
| ) |
| dataset = ds.dataset([batch]) |
| ctx.register_dataset(table_name, dataset) |
| sql = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);' |
| ctx.sql(sql) |
| options = SQLOptions().with_allow_dml(allow=False) |
| with pytest.raises(Exception, match="DML"): |
| ctx.sql_with_options(sql, options=options) |
| |
| |
| def test_sql_with_options_no_statements(ctx): |
| sql = "SET time zone = 1;" |
| ctx.sql(sql) |
| options = SQLOptions().with_allow_statements(allow=False) |
| with pytest.raises(Exception, match="SetVariable"): |
| ctx.sql_with_options(sql, options=options) |
| |
| |
| @pytest.fixture |
| def batch(): |
| return pa.RecordBatch.from_arrays( |
| [pa.array([4, 5, 6])], |
| names=["a"], |
| ) |
| |
| |
| def test_create_dataframe_with_global_ctx(batch): |
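    # global_ctx() returns the shared, process-wide SessionContext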
| ctx = SessionContext.global_ctx() |
| |
| df = ctx.create_dataframe([[batch]]) |
| |
| result = df.collect()[0].column(0) |
| |
| assert result == pa.array([4, 5, 6]) |