| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # pylint:disable=redefined-outer-name |
| import math |
| import os |
| import random |
| import re |
| import time |
| import uuid |
| from datetime import date, datetime, timedelta |
| from decimal import Decimal |
| from pathlib import Path |
| from typing import Any |
| from urllib.parse import urlparse |
| |
| import fastavro |
| import pandas as pd |
| import pandas.testing |
| import pyarrow as pa |
| import pyarrow.compute as pc |
| import pyarrow.parquet as pq |
| import pytest |
| import pytz |
| from pyarrow.fs import S3FileSystem |
| from pydantic_core import ValidationError |
| from pyspark.sql import SparkSession |
| from pytest_lazy_fixtures import lf |
| from pytest_mock.plugin import MockerFixture |
| |
| from pyiceberg.catalog import Catalog, load_catalog |
| from pyiceberg.catalog.hive import HiveCatalog |
| from pyiceberg.catalog.sql import SqlCatalog |
| from pyiceberg.exceptions import CommitFailedException, NoSuchTableError |
| from pyiceberg.expressions import And, EqualTo, GreaterThanOrEqual, In, LessThan, Not |
| from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _dataframe_to_data_files |
| from pyiceberg.manifest import FileFormat |
| from pyiceberg.partitioning import PartitionField, PartitionSpec |
| from pyiceberg.schema import Schema |
| from pyiceberg.table import TableProperties |
| from pyiceberg.table.refs import MAIN_BRANCH |
| from pyiceberg.table.sorting import SortDirection, SortField, SortOrder |
| from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform, Transform |
| from pyiceberg.types import ( |
| DateType, |
| DecimalType, |
| DoubleType, |
| IntegerType, |
| ListType, |
| LongType, |
| NestedField, |
| StringType, |
| UUIDType, |
| ) |
| from utils import TABLE_SCHEMA, _create_table |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v1_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.arrow_table_v1_with_null" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) |
| assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v1_without_data(session_catalog: Catalog, arrow_table_without_data: pa.Table) -> None: |
| identifier = "default.arrow_table_v1_without_data" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_without_data]) |
| assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v1_with_only_nulls(session_catalog: Catalog, arrow_table_with_only_nulls: pa.Table) -> None: |
| identifier = "default.arrow_table_v1_with_only_nulls" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_only_nulls]) |
| assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v1_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.arrow_table_v1_appended_with_null" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, 2 * [arrow_table_with_null]) |
| assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v2_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.arrow_table_v2_with_null" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v2_without_data(session_catalog: Catalog, arrow_table_without_data: pa.Table) -> None: |
| identifier = "default.arrow_table_v2_without_data" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_without_data]) |
| assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v2_with_only_nulls(session_catalog: Catalog, arrow_table_with_only_nulls: pa.Table) -> None: |
| identifier = "default.arrow_table_v2_with_only_nulls" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_only_nulls]) |
| assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.arrow_table_v2_appended_with_null" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, 2 * [arrow_table_with_null]) |
| assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" |
| |
| |
| @pytest.fixture(scope="session", autouse=True) |
| def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.arrow_table_v1_v2_appended_with_null" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) |
| assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" |
| |
| with tbl.transaction() as tx: |
| tx.upgrade_table_version(format_version=2) |
| |
| tbl.append(arrow_table_with_null) |
| |
| assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_query_count(spark: SparkSession, format_version: int) -> None: |
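    """Spark should count the three rows PyIceberg wrote to the *_with_null fixture tables."""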
| df = spark.table(f"default.arrow_table_v{format_version}_with_null") |
| assert df.count() == 3, "Expected 3 rows" |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_query_filter_null(spark: SparkSession, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| identifier = f"default.arrow_table_v{format_version}_with_null" |
| df = spark.table(identifier) |
| for col in arrow_table_with_null.column_names: |
| assert df.where(f"{col} is null").count() == 1, f"Expected 1 row for {col}" |
| assert df.where(f"{col} is not null").count() == 2, f"Expected 2 rows for {col}" |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_query_filter_without_data(spark: SparkSession, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| identifier = f"default.arrow_table_v{format_version}_without_data" |
| df = spark.table(identifier) |
| for col in arrow_table_with_null.column_names: |
| assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}" |
| assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_query_filter_only_nulls(spark: SparkSession, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| identifier = f"default.arrow_table_v{format_version}_with_only_nulls" |
| df = spark.table(identifier) |
| for col in arrow_table_with_null.column_names: |
| assert df.where(f"{col} is null").count() == 2, f"Expected 2 rows for {col}" |
| assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_query_filter_appended_null(spark: SparkSession, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| identifier = f"default.arrow_table_v{format_version}_appended_with_null" |
| df = spark.table(identifier) |
| for col in arrow_table_with_null.column_names: |
| assert df.where(f"{col} is null").count() == 2, f"Expected 1 row for {col}" |
| assert df.where(f"{col} is not null").count() == 4, f"Expected 2 rows for {col}" |
| |
| |
| @pytest.mark.integration |
| def test_query_filter_v1_v2_append_null( |
| spark: SparkSession, |
| arrow_table_with_null: pa.Table, |
| ) -> None: |
| identifier = "default.arrow_table_v1_v2_appended_with_null" |
| df = spark.table(identifier) |
| for col in arrow_table_with_null.column_names: |
| assert df.where(f"{col} is null").count() == 2, f"Expected 1 row for {col}" |
| assert df.where(f"{col} is not null").count() == 4, f"Expected 2 rows for {col}" |
| |
| |
| @pytest.mark.integration |
| def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
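    """Verify the snapshot summaries Spark exposes: two appends from table creation, then an
    overwrite, which commits a delete snapshot followed by an append snapshot."""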
| identifier = "default.arrow_table_summaries" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, 2 * [arrow_table_with_null]) |
| tbl.overwrite(arrow_table_with_null) |
| |
| rows = spark.sql( |
| f""" |
| SELECT operation, summary |
| FROM {identifier}.snapshots |
| ORDER BY committed_at ASC |
| """ |
| ).collect() |
| |
| operations = [row.operation for row in rows] |
| assert operations == ["append", "append", "delete", "append"] |
| |
| summaries = [row.summary for row in rows] |
| |
| file_size = int(summaries[0]["added-files-size"]) |
| assert file_size > 0 |
| |
| # Append |
| assert summaries[0] == { |
| "added-data-files": "1", |
| "added-files-size": str(file_size), |
| "added-records": "3", |
| "total-data-files": "1", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": str(file_size), |
| "total-position-deletes": "0", |
| "total-records": "3", |
| } |
| |
| # Append |
| assert summaries[1] == { |
| "added-data-files": "1", |
| "added-files-size": str(file_size), |
| "added-records": "3", |
| "total-data-files": "2", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": str(file_size * 2), |
| "total-position-deletes": "0", |
| "total-records": "6", |
| } |
| |
| # Delete |
| assert summaries[2] == { |
| "deleted-data-files": "2", |
| "deleted-records": "6", |
| "removed-files-size": str(file_size * 2), |
| "total-data-files": "0", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": "0", |
| "total-position-deletes": "0", |
| "total-records": "0", |
| } |
| |
| # Append |
| assert summaries[3] == { |
| "added-data-files": "1", |
| "added-files-size": str(file_size), |
| "added-records": "3", |
| "total-data-files": "1", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": str(file_size), |
| "total-position-deletes": "0", |
| "total-records": "3", |
| } |
| |
| |
| @pytest.mark.integration |
| def test_summaries_partial_overwrite(spark: SparkSession, session_catalog: Catalog) -> None: |
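    """Delete rows that cover only part of one data file on an identity-partitioned table and
    check the resulting overwrite snapshot summary as reported by Spark."""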
| identifier = "default.test_summaries_partial_overwrite" |
| TEST_DATA = { |
| "id": [1, 2, 3, 1, 1], |
| "name": ["AB", "CD", "EF", "CD", "EF"], |
| } |
| pa_schema = pa.schema( |
| [ |
| pa.field("id", pa.int32()), |
| pa.field("name", pa.string()), |
| ] |
| ) |
| arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema) |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, schema=pa_schema) |
| with tbl.update_spec() as txn: |
| txn.add_identity("id") |
| tbl.append(arrow_table) |
| |
| assert len(tbl.inspect.data_files()) == 3 |
| |
    tbl.delete(delete_filter="id == 1 and name = 'AB'")  # partially overwrites data from one data file
| |
| rows = spark.sql( |
| f""" |
| SELECT operation, summary |
| FROM {identifier}.snapshots |
| ORDER BY committed_at ASC |
| """ |
| ).collect() |
| |
| operations = [row.operation for row in rows] |
| assert operations == ["append", "overwrite"] |
| |
| summaries = [row.summary for row in rows] |
| |
| file_size = int(summaries[0]["added-files-size"]) |
| assert file_size > 0 |
| |
| # APPEND |
| assert "added-files-size" in summaries[0] |
| assert "total-files-size" in summaries[0] |
| assert summaries[0] == { |
| "added-data-files": "3", |
| "added-files-size": summaries[0]["added-files-size"], |
| "added-records": "5", |
| "changed-partition-count": "3", |
| "total-data-files": "3", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": summaries[0]["total-files-size"], |
| "total-position-deletes": "0", |
| "total-records": "5", |
| } |
| # Java produces: |
| # { |
| # "added-data-files": "1", |
| # "added-files-size": "707", |
| # "added-records": "2", |
| # "app-id": "local-1743678304626", |
| # "changed-partition-count": "1", |
| # "deleted-data-files": "1", |
| # "deleted-records": "3", |
| # "engine-name": "spark", |
| # "engine-version": "3.5.5", |
| # "iceberg-version": "Apache Iceberg 1.8.1 (commit 9ce0fcf0af7becf25ad9fc996c3bad2afdcfd33d)", |
| # "removed-files-size": "693", |
| # "spark.app.id": "local-1743678304626", |
| # "total-data-files": "3", |
| # "total-delete-files": "0", |
| # "total-equality-deletes": "0", |
| # "total-files-size": "1993", |
| # "total-position-deletes": "0", |
| # "total-records": "4" |
| # } |
| files = tbl.inspect.data_files() |
| assert len(files) == 3 |
| assert "added-files-size" in summaries[1] |
| assert "removed-files-size" in summaries[1] |
| assert "total-files-size" in summaries[1] |
| assert summaries[1] == { |
| "added-data-files": "1", |
| "added-files-size": summaries[1]["added-files-size"], |
| "added-records": "2", |
| "changed-partition-count": "1", |
| "deleted-data-files": "1", |
| "deleted-records": "3", |
| "removed-files-size": summaries[1]["removed-files-size"], |
| "total-data-files": "3", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": summaries[1]["total-files-size"], |
| "total-position-deletes": "0", |
| "total-records": "4", |
| } |
| assert len(tbl.scan().to_pandas()) == 4 |
| |
| |
| @pytest.mark.integration |
| def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
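    """Append, overwrite, and append again, then verify the added/existing/deleted data file
    counts in the all_manifests metadata table as seen from Spark."""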
| identifier = "default.arrow_data_files" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) |
| |
| tbl.append(arrow_table_with_null) |
| # should produce a DELETE entry |
| tbl.overwrite(arrow_table_with_null) |
| # Since we don't rewrite, this should produce a new manifest with an ADDED entry |
| tbl.append(arrow_table_with_null) |
| |
| rows = spark.sql( |
| f""" |
| SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count |
| FROM {identifier}.all_manifests |
| """ |
| ).collect() |
| |
| assert [row.added_data_files_count for row in rows] == [1, 0, 1, 1, 1] |
| assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0] |
| assert [row.deleted_data_files_count for row in rows] == [0, 1, 0, 0, 0] |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_object_storage_data_files( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
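    """With object storage enabled, written file paths should include injected entropy
    directories made up of binary (0/1) characters."""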
| tbl = _create_table( |
| session_catalog=session_catalog, |
| identifier="default.object_stored", |
| properties={"format-version": format_version, TableProperties.OBJECT_STORE_ENABLED: True}, |
| data=[arrow_table_with_null], |
| ) |
| tbl.append(arrow_table_with_null) |
| |
| paths = tbl.inspect.data_files().to_pydict()["file_path"] |
| assert len(paths) == 2 |
| |
| for location in paths: |
| assert location.startswith("s3://warehouse/default/object_stored/data/") |
| parts = location.split("/") |
| assert len(parts) == 11 |
| |
| # Entropy binary directories should have been injected |
| for dir_name in parts[6:10]: |
| assert dir_name |
| assert all(c in "01" for c in dir_name) |
| |
| |
| @pytest.mark.integration |
| def test_python_writes_with_spark_snapshot_reads( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table |
| ) -> None: |
| identifier = "default.python_writes_with_spark_snapshot_reads" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) |
| |
| def get_current_snapshot_id(identifier: str) -> int: |
| return ( |
| spark.sql(f"SELECT snapshot_id FROM {identifier}.snapshots order by committed_at desc limit 1") |
| .collect()[0] |
| .snapshot_id |
| ) |
| |
| tbl.append(arrow_table_with_null) |
| assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore |
| tbl.overwrite(arrow_table_with_null) |
| assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore |
| tbl.append(arrow_table_with_null) |
| assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_python_writes_special_character_column_with_spark_reads( |
| spark: SparkSession, session_catalog: Catalog, format_version: int |
| ) -> None: |
| identifier = "default.python_writes_special_character_column_with_spark_reads" |
| column_name_with_special_character = "letter/abc" |
| TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN = { |
| column_name_with_special_character: ["a", None, "z"], |
| "id": [1, 2, 3], |
| "name": ["AB", "CD", "EF"], |
| "address": [ |
| {"street": "123", "city": "SFO", "zip": 12345, column_name_with_special_character: "a"}, |
| {"street": "456", "city": "SW", "zip": 67890, column_name_with_special_character: "b"}, |
| {"street": "789", "city": "Random", "zip": 10112, column_name_with_special_character: "c"}, |
| ], |
| } |
| pa_schema = pa.schema( |
| [ |
| pa.field(column_name_with_special_character, pa.string()), |
| pa.field("id", pa.int32()), |
| pa.field("name", pa.string()), |
| pa.field( |
| "address", |
| pa.struct( |
| [ |
| pa.field("street", pa.string()), |
| pa.field("city", pa.string()), |
| pa.field("zip", pa.int32()), |
| pa.field(column_name_with_special_character, pa.string()), |
| ] |
| ), |
| ), |
| ] |
| ) |
| arrow_table_with_special_character_column = pa.Table.from_pydict(TEST_DATA_WITH_SPECIAL_CHARACTER_COLUMN, schema=pa_schema) |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema) |
| |
| tbl.append(arrow_table_with_special_character_column) |
| spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas() |
| pyiceberg_df = tbl.scan().to_pandas() |
| assert spark_df.equals(pyiceberg_df) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_python_writes_dictionary_encoded_column_with_spark_reads( |
| spark: SparkSession, session_catalog: Catalog, format_version: int |
| ) -> None: |
| identifier = "default.python_writes_dictionary_encoded_column_with_spark_reads" |
| TEST_DATA = { |
| "id": [1, 2, 3, 1, 1], |
| "name": ["AB", "CD", "EF", "CD", "EF"], |
| } |
| pa_schema = pa.schema( |
| [ |
| pa.field("id", pa.dictionary(pa.int32(), pa.int32(), False)), |
| pa.field("name", pa.dictionary(pa.int32(), pa.string(), False)), |
| ] |
| ) |
| arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema) |
| |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema) |
| |
| tbl.append(arrow_table) |
| spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas() |
| pyiceberg_df = tbl.scan().to_pandas() |
| |
    # We're only interested in the content; PyIceberg actually makes a nice Categorical out of it:
| # E AssertionError: Attributes of DataFrame.iloc[:, 1] (column name="name") are different |
| # E |
| # E Attribute "dtype" are different |
| # E [left]: object |
| # E [right]: CategoricalDtype(categories=['AB', 'CD', 'EF'], ordered=False, categories_dtype=object) |
| pandas.testing.assert_frame_equal(spark_df, pyiceberg_df, check_dtype=False, check_categorical=False) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_python_writes_with_small_and_large_types_spark_reads( |
| spark: SparkSession, session_catalog: Catalog, format_version: int |
| ) -> None: |
| identifier = "default.python_writes_with_small_and_large_types_spark_reads" |
| TEST_DATA = { |
| "foo": ["a", None, "z"], |
| "id": [1, 2, 3], |
| "name": ["AB", "CD", "EF"], |
| "address": [ |
| {"street": "123", "city": "SFO", "zip": 12345, "bar": "a"}, |
| {"street": "456", "city": "SW", "zip": 67890, "bar": "b"}, |
| {"street": "789", "city": "Random", "zip": 10112, "bar": "c"}, |
| ], |
| } |
| pa_schema = pa.schema( |
| [ |
| pa.field("foo", pa.string()), |
| pa.field("id", pa.int32()), |
| pa.field("name", pa.string()), |
| pa.field( |
| "address", |
| pa.struct( |
| [ |
| pa.field("street", pa.string()), |
| pa.field("city", pa.string()), |
| pa.field("zip", pa.int32()), |
| pa.field("bar", pa.string()), |
| ] |
| ), |
| ), |
| ] |
| ) |
| arrow_table = pa.Table.from_pydict(TEST_DATA, schema=pa_schema) |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=pa_schema) |
| |
| tbl.append(arrow_table) |
| spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas() |
| pyiceberg_df = tbl.scan().to_pandas() |
| assert spark_df.equals(pyiceberg_df) |
| arrow_table_on_read = tbl.scan().to_arrow() |
| assert arrow_table_on_read.schema == pa.schema( |
| [ |
| pa.field("foo", pa.string()), |
| pa.field("id", pa.int32()), |
| pa.field("name", pa.string()), |
| pa.field( |
| "address", |
| pa.struct( |
| [ |
| pa.field("street", pa.string()), |
| pa.field("city", pa.string()), |
| pa.field("zip", pa.int32()), |
| pa.field("bar", pa.string()), |
| ] |
| ), |
| ), |
| ] |
| ) |
| |
| |
| @pytest.mark.integration |
| def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
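    """Writes are bin-packed against the target file size: one file while the input stays below
    the (default or overridden) target, more files once the input exceeds it."""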
| identifier = "default.write_bin_pack_data_files" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) |
| |
| def get_data_files_count(identifier: str) -> int: |
| return spark.sql( |
| f""" |
| SELECT * |
| FROM {identifier}.files |
| """ |
| ).count() |
| |
    # writes 1 data file since the table is smaller than the default target file size
| assert arrow_table_with_null.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT |
| tbl.append(arrow_table_with_null) |
| assert get_data_files_count(identifier) == 1 |
| |
    # writes 1 data file as long as the table is smaller than the default target file size
| bigger_arrow_tbl = pa.concat_tables([arrow_table_with_null] * 10) |
| assert bigger_arrow_tbl.nbytes < TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT |
| tbl.overwrite(bigger_arrow_tbl) |
| assert get_data_files_count(identifier) == 1 |
| |
    # writes multiple data files once the target file size is overridden
| target_file_size = arrow_table_with_null.nbytes |
| tbl = tbl.transaction().set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: target_file_size}).commit_transaction() |
| assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES) |
| assert target_file_size < bigger_arrow_tbl.nbytes |
| tbl.overwrite(bigger_arrow_tbl) |
| assert get_data_files_count(identifier) == 10 |
| |
    # writes half the number of data files when the target file size doubles
| target_file_size = arrow_table_with_null.nbytes * 2 |
| tbl = tbl.transaction().set_properties({TableProperties.WRITE_TARGET_FILE_SIZE_BYTES: target_file_size}).commit_transaction() |
| assert str(target_file_size) == tbl.properties.get(TableProperties.WRITE_TARGET_FILE_SIZE_BYTES) |
| assert target_file_size < bigger_arrow_tbl.nbytes |
| tbl.overwrite(bigger_arrow_tbl) |
| assert get_data_files_count(identifier) == 5 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| @pytest.mark.parametrize( |
| "properties, expected_compression_name", |
| [ |
| # REST catalog uses Zstandard by default: https://github.com/apache/iceberg/pull/8593 |
| ({}, "ZSTD"), |
| ({"write.parquet.compression-codec": "uncompressed"}, "UNCOMPRESSED"), |
| ({"write.parquet.compression-codec": "gzip", "write.parquet.compression-level": "1"}, "GZIP"), |
| ({"write.parquet.compression-codec": "zstd", "write.parquet.compression-level": "1"}, "ZSTD"), |
| ({"write.parquet.compression-codec": "snappy"}, "SNAPPY"), |
| ], |
| ) |
| def test_write_parquet_compression_properties( |
| spark: SparkSession, |
| session_catalog: Catalog, |
| arrow_table_with_null: pa.Table, |
| format_version: int, |
| properties: dict[str, Any], |
| expected_compression_name: str, |
| ) -> None: |
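    """The write.parquet.compression-codec table property should control the codec recorded in
    the Parquet footers of the written data files."""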
| identifier = "default.write_parquet_compression_properties" |
| |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version, **properties}, [arrow_table_with_null]) |
| |
| data_file_paths = [task.file.file_path for task in tbl.scan().plan_files()] |
| |
| fs = S3FileSystem( |
| endpoint_override=session_catalog.properties["s3.endpoint"], |
| access_key=session_catalog.properties["s3.access-key-id"], |
| secret_key=session_catalog.properties["s3.secret-access-key"], |
| ) |
| uri = urlparse(data_file_paths[0]) |
| with fs.open_input_file(f"{uri.netloc}{uri.path}") as f: |
| parquet_metadata = pq.read_metadata(f) |
| compression = parquet_metadata.row_group(0).column(0).compression |
| |
| assert compression == expected_compression_name |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize( |
| "properties, expected_kwargs", |
| [ |
| ({"write.parquet.page-size-bytes": "42"}, {"data_page_size": 42}), |
| ({"write.parquet.dict-size-bytes": "42"}, {"dictionary_pagesize_limit": 42}), |
| ], |
| ) |
| def test_write_parquet_other_properties( |
| mocker: MockerFixture, |
| spark: SparkSession, |
| session_catalog: Catalog, |
| arrow_table_with_null: pa.Table, |
| properties: dict[str, Any], |
| expected_kwargs: dict[str, Any], |
| ) -> None: |
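    """Page-size and dictionary-size table properties should be translated into the
    corresponding pyarrow ParquetWriter keyword arguments."""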
| identifier = "default.test_write_parquet_other_properties" |
| |
| # The properties we test cannot be checked on the resulting Parquet file, so we spy on the ParquetWriter call instead |
| ParquetWriter = mocker.spy(pq, "ParquetWriter") |
| _create_table(session_catalog, identifier, properties, [arrow_table_with_null]) |
| |
| call_kwargs = ParquetWriter.call_args[1] |
| for key, value in expected_kwargs.items(): |
| assert call_kwargs.get(key) == value |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize( |
| "properties", |
| [ |
| {"write.parquet.row-group-size-bytes": "42"}, |
| {"write.parquet.bloom-filter-enabled.column.bool": "42"}, |
| {"write.parquet.bloom-filter-max-bytes": "42"}, |
| ], |
| ) |
| def test_write_parquet_unsupported_properties( |
| spark: SparkSession, |
| session_catalog: Catalog, |
| arrow_table_with_null: pa.Table, |
| properties: dict[str, str], |
| ) -> None: |
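    """Unsupported write.parquet.* properties should only trigger a warning on write."""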
| identifier = "default.write_parquet_unsupported_properties" |
| |
| tbl = _create_table(session_catalog, identifier, properties, []) |
| with pytest.warns(UserWarning, match=r"Parquet writer option.*"): |
| tbl.append(arrow_table_with_null) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_spark_writes_orc_pyiceberg_reads(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: |
| """Test that ORC files written by Spark can be read by PyIceberg.""" |
| identifier = f"default.spark_writes_orc_pyiceberg_reads_v{format_version}" |
| |
| # Create test data |
| test_data = [ |
| (1, "Alice", 25, True), |
| (2, "Bob", 30, False), |
| (3, "Charlie", 35, True), |
| (4, "David", 28, True), |
| (5, "Eve", 32, False), |
| ] |
| |
| # Create Spark DataFrame |
| spark_df = spark.createDataFrame(test_data, ["id", "name", "age", "is_active"]) |
| |
| # Ensure a clean slate to avoid replacing a v2 table with v1 |
| spark.sql(f"DROP TABLE IF EXISTS {identifier}") |
| |
| # Create table with Spark using ORC format and desired format-version |
| spark_df.writeTo(identifier).using("iceberg").tableProperty("write.format.default", "orc").tableProperty( |
| "format-version", str(format_version) |
| ).createOrReplace() |
| |
| # Write data with ORC format using Spark |
| spark_df.writeTo(identifier).using("iceberg").append() |
| |
| # Read with PyIceberg - this is the main focus of our validation |
| tbl = session_catalog.load_table(identifier) |
| pyiceberg_df = tbl.scan().to_pandas() |
| |
| # Verify PyIceberg results have the expected number of rows |
| assert len(pyiceberg_df) == 10 # 5 rows from create + 5 rows from append |
| |
| # Verify PyIceberg column names |
| assert list(pyiceberg_df.columns) == ["id", "name", "age", "is_active"] |
| |
| # Verify PyIceberg data integrity - check the actual data values |
| expected_data = [ |
| (1, "Alice", 25, True), |
| (2, "Bob", 30, False), |
| (3, "Charlie", 35, True), |
| (4, "David", 28, True), |
| (5, "Eve", 32, False), |
| ] |
| |
| # Verify PyIceberg results contain the expected data (appears twice due to create + append) |
| pyiceberg_data = list( |
| zip(pyiceberg_df["id"], pyiceberg_df["name"], pyiceberg_df["age"], pyiceberg_df["is_active"], strict=True) |
| ) |
| assert pyiceberg_data == expected_data + expected_data # Data should appear twice |
| |
| # Verify PyIceberg data types are correct |
| assert pyiceberg_df["id"].dtype == "int64" |
| assert pyiceberg_df["name"].dtype == "object" # string |
| assert pyiceberg_df["age"].dtype == "int64" |
| assert pyiceberg_df["is_active"].dtype == "bool" |
| |
| # Cross-validate with Spark to ensure consistency (ensure deterministic ordering) |
| spark_result = spark.sql(f"SELECT * FROM {identifier}").toPandas() |
| sort_cols = ["id", "name", "age", "is_active"] |
| spark_result = spark_result.sort_values(by=sort_cols).reset_index(drop=True) |
| pyiceberg_df = pyiceberg_df.sort_values(by=sort_cols).reset_index(drop=True) |
| pandas.testing.assert_frame_equal(spark_result, pyiceberg_df, check_dtype=False) |
| |
| # Verify the files are actually ORC format |
| files = list(tbl.scan().plan_files()) |
| assert len(files) > 0 |
| for file_task in files: |
| assert file_task.file.file_format == FileFormat.ORC |
| |
| |
| @pytest.mark.integration |
| def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.arrow_data_files" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) |
| |
| with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"): |
| tbl.overwrite("not a df") |
| |
| with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"): |
| tbl.append("not a df") |
| |
| |
| @pytest.mark.integration |
| def test_summaries_with_only_nulls( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_without_data: pa.Table, arrow_table_with_only_nulls: pa.Table |
| ) -> None: |
| identifier = "default.arrow_table_summaries_with_only_nulls" |
| tbl = _create_table( |
| session_catalog, identifier, {"format-version": "1"}, [arrow_table_without_data, arrow_table_with_only_nulls] |
| ) |
| tbl.overwrite(arrow_table_without_data) |
| |
| rows = spark.sql( |
| f""" |
| SELECT operation, summary |
| FROM {identifier}.snapshots |
| ORDER BY committed_at ASC |
| """ |
| ).collect() |
| |
| operations = [row.operation for row in rows] |
| assert operations == ["append", "append", "delete", "append"] |
| |
| summaries = [row.summary for row in rows] |
| |
| file_size = int(summaries[1]["added-files-size"]) |
| assert file_size > 0 |
| |
| assert summaries[0] == { |
| "total-data-files": "0", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": "0", |
| "total-position-deletes": "0", |
| "total-records": "0", |
| } |
| |
| assert summaries[1] == { |
| "added-data-files": "1", |
| "added-files-size": str(file_size), |
| "added-records": "2", |
| "total-data-files": "1", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": str(file_size), |
| "total-position-deletes": "0", |
| "total-records": "2", |
| } |
| |
| assert summaries[2] == { |
| "deleted-data-files": "1", |
| "deleted-records": "2", |
| "removed-files-size": str(file_size), |
| "total-data-files": "0", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": "0", |
| "total-position-deletes": "0", |
| "total-records": "0", |
| } |
| |
| assert summaries[3] == { |
| "total-data-files": "0", |
| "total-delete-files": "0", |
| "total-equality-deletes": "0", |
| "total-files-size": "0", |
| "total-position-deletes": "0", |
| "total-records": "0", |
| } |
| |
| |
| @pytest.mark.integration |
| def test_duckdb_url_import(warehouse: Path, arrow_table_with_null: pa.Table) -> None: |
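    """A table written through a SQLite-backed catalog should be readable by DuckDB's
    iceberg_scan directly from the metadata file location."""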
| os.environ["TZ"] = "Etc/UTC" |
| time.tzset() |
| tz = pytz.timezone(os.environ["TZ"]) |
| |
| catalog = SqlCatalog("test_sql_catalog", uri="sqlite:///:memory:", warehouse=f"/{warehouse}") |
| catalog.create_namespace("default") |
| |
| identifier = "default.arrow_table_v1_with_null" |
| tbl = _create_table(catalog, identifier, {}, [arrow_table_with_null]) |
| location = tbl.metadata_location |
| |
| import duckdb |
| |
| duckdb.sql("INSTALL iceberg; LOAD iceberg;") |
| result = duckdb.sql( |
| f""" |
| SELECT * |
| FROM iceberg_scan('{location}') |
| """ |
| ).fetchall() |
| |
| assert result == [ |
| ( |
| False, |
| "a", |
| "aaaaaaaaaaaaaaaaaaaaaa", |
| 1, |
| 1, |
| 0.0, |
| 0.0, |
| datetime(2023, 1, 1, 19, 25), |
| datetime(2023, 1, 1, 19, 25, tzinfo=tz), |
| date(2023, 1, 1), |
| b"\x01", |
| b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", |
| ), |
| (None, None, None, None, None, None, None, None, None, None, None, None), |
| ( |
| True, |
| "z", |
| "zzzzzzzzzzzzzzzzzzzzzz", |
| 9, |
| 9, |
| 0.8999999761581421, |
| 0.9, |
| datetime(2023, 3, 1, 19, 25), |
| datetime(2023, 3, 1, 19, 25, tzinfo=tz), |
| date(2023, 3, 1), |
| b"\x12", |
| b"\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11", |
| ), |
| ] |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None: |
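    """Evolve the schema with union_by_name inside a transaction, then append, overwrite,
    and delete against the widened schema."""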
| identifier = f"default.arrow_write_data_and_evolve_schema_v{format_version}" |
| |
| try: |
| session_catalog.drop_table(identifier=identifier) |
| except NoSuchTableError: |
| pass |
| |
| pa_table = pa.Table.from_pydict( |
| { |
| "foo": ["a", None, "z"], |
| }, |
| schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]), |
| ) |
| |
| tbl = session_catalog.create_table( |
| identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)} |
| ) |
| |
| pa_table_with_column = pa.Table.from_pydict( |
| { |
| "foo": ["a", None, "z"], |
| "bar": [19, None, 25], |
| }, |
| schema=pa.schema( |
| [ |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=True), |
| ] |
| ), |
| ) |
| |
| with tbl.transaction() as txn: |
| with txn.update_schema() as schema_txn: |
| schema_txn.union_by_name(pa_table_with_column.schema) |
| |
| txn.append(pa_table_with_column) |
| txn.overwrite(pa_table_with_column) |
| txn.delete("foo = 'a'") |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| @pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) |
| def test_create_table_transaction(catalog: Catalog, format_version: int) -> None: |
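    """Create a table and stage two fast-append snapshots (with a schema update in between)
    within a single create_table_transaction."""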
| identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}" |
| |
| try: |
| catalog.drop_table(identifier=identifier) |
| except NoSuchTableError: |
| pass |
| |
| pa_table = pa.Table.from_pydict( |
| { |
| "foo": ["a", None, "z"], |
| }, |
| schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]), |
| ) |
| |
| pa_table_with_column = pa.Table.from_pydict( |
| { |
| "foo": ["a", None, "z"], |
| "bar": [19, None, 25], |
| }, |
| schema=pa.schema( |
| [ |
| pa.field("foo", pa.string(), nullable=True), |
| pa.field("bar", pa.int32(), nullable=True), |
| ] |
| ), |
| ) |
| |
| with catalog.create_table_transaction( |
| identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)} |
| ) as txn: |
| with txn.update_snapshot().fast_append() as snapshot_update: |
| for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io): |
| snapshot_update.append_data_file(data_file) |
| |
| with txn.update_schema() as schema_txn: |
| schema_txn.union_by_name(pa_table_with_column.schema) |
| |
| with txn.update_snapshot().fast_append() as snapshot_update: |
| for data_file in _dataframe_to_data_files( |
| table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io |
| ): |
| snapshot_update.append_data_file(data_file) |
| |
| tbl = catalog.load_table(identifier=identifier) |
| assert tbl.format_version == format_version |
| assert len(tbl.scan().to_arrow()) == 6 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| @pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) |
| def test_create_table_with_non_default_values(catalog: Catalog, table_schema_with_all_types: Schema, format_version: int) -> None: |
| identifier = f"default.arrow_create_table_transaction_with_non_default_values_{catalog.name}_{format_version}" |
| identifier_ref = f"default.arrow_create_table_transaction_with_non_default_values_ref_{catalog.name}_{format_version}" |
| |
| try: |
| catalog.drop_table(identifier=identifier) |
| except NoSuchTableError: |
| pass |
| |
| try: |
| catalog.drop_table(identifier=identifier_ref) |
| except NoSuchTableError: |
| pass |
| |
| iceberg_spec = PartitionSpec( |
| *[PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="integer_partition")] |
| ) |
| |
| sort_order = SortOrder(*[SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC)]) |
| |
| txn = catalog.create_table_transaction( |
| identifier=identifier, |
| schema=table_schema_with_all_types, |
| partition_spec=iceberg_spec, |
| sort_order=sort_order, |
| properties={"format-version": format_version}, |
| ) |
| txn.commit_transaction() |
| |
| tbl = catalog.load_table(identifier) |
| |
| tbl_ref = catalog.create_table( |
| identifier=identifier_ref, |
| schema=table_schema_with_all_types, |
| partition_spec=iceberg_spec, |
| sort_order=sort_order, |
| properties={"format-version": format_version}, |
| ) |
| |
| assert tbl.format_version == tbl_ref.format_version |
| assert tbl.schema() == tbl_ref.schema() |
| assert tbl.schemas() == tbl_ref.schemas() |
| assert tbl.spec() == tbl_ref.spec() |
| assert tbl.specs() == tbl_ref.specs() |
| assert tbl.sort_order() == tbl_ref.sort_order() |
| assert tbl.sort_orders() == tbl_ref.sort_orders() |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_table_properties_int_value( |
| session_catalog: Catalog, |
| arrow_table_with_null: pa.Table, |
| format_version: int, |
| ) -> None: |
    # table properties can be set to an int, but are still serialized to a string
| property_with_int = {"property_name": 42} |
| identifier = "default.test_table_properties_int_value" |
| |
| tbl = _create_table( |
| session_catalog, identifier, {"format-version": format_version, **property_with_int}, [arrow_table_with_null] |
| ) |
| assert isinstance(tbl.properties["property_name"], str) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_table_properties_raise_for_none_value( |
| session_catalog: Catalog, |
| arrow_table_with_null: pa.Table, |
| format_version: int, |
| ) -> None: |
| property_with_none = {"property_name": None} |
| identifier = "default.test_table_properties_raise_for_none_value" |
| |
| with pytest.raises(ValidationError) as exc_info: |
| _ = _create_table( |
| session_catalog, identifier, {"format-version": format_version, **property_with_none}, [arrow_table_with_null] |
| ) |
| assert "None type is not a supported value in properties: property_name" in str(exc_info.value) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_inspect_snapshots( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
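    """The inspect.snapshots() metadata table should match Spark's snapshots table, including
    parent/child snapshot lineage and per-snapshot summaries."""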
| identifier = "default.table_metadata_snapshots" |
| tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}) |
| |
| tbl.append(arrow_table_with_null) |
| # should produce a DELETE entry |
| tbl.overwrite(arrow_table_with_null) |
| # Since we don't rewrite, this should produce a new manifest with an ADDED entry |
| tbl.append(arrow_table_with_null) |
| |
| df = tbl.inspect.snapshots() |
| |
| assert df.column_names == [ |
| "committed_at", |
| "snapshot_id", |
| "parent_id", |
| "operation", |
| "manifest_list", |
| "summary", |
| ] |
| |
| for committed_at in df["committed_at"]: |
| assert isinstance(committed_at.as_py(), datetime) |
| |
| for snapshot_id in df["snapshot_id"]: |
| assert isinstance(snapshot_id.as_py(), int) |
| |
| assert df["parent_id"][0].as_py() is None |
| assert df["parent_id"][1:].to_pylist() == df["snapshot_id"][:-1].to_pylist() |
| |
| assert [operation.as_py() for operation in df["operation"]] == ["append", "delete", "append", "append"] |
| |
| for manifest_list in df["manifest_list"]: |
| assert manifest_list.as_py().startswith("s3://") |
| |
| file_size = int(next(value for key, value in df["summary"][0].as_py() if key == "added-files-size")) |
| assert file_size > 0 |
| |
| # Append |
| assert df["summary"][0].as_py() == [ |
| ("added-files-size", str(file_size)), |
| ("added-data-files", "1"), |
| ("added-records", "3"), |
| ("total-data-files", "1"), |
| ("total-delete-files", "0"), |
| ("total-records", "3"), |
| ("total-files-size", str(file_size)), |
| ("total-position-deletes", "0"), |
| ("total-equality-deletes", "0"), |
| ] |
| |
| # Delete |
| assert df["summary"][1].as_py() == [ |
| ("removed-files-size", str(file_size)), |
| ("deleted-data-files", "1"), |
| ("deleted-records", "3"), |
| ("total-data-files", "0"), |
| ("total-delete-files", "0"), |
| ("total-records", "0"), |
| ("total-files-size", "0"), |
| ("total-position-deletes", "0"), |
| ("total-equality-deletes", "0"), |
| ] |
| |
| lhs = spark.table(f"{identifier}.snapshots").toPandas() |
| rhs = df.to_pandas() |
| for column in df.column_names: |
| for left, right in zip(lhs[column].to_list(), rhs[column].to_list(), strict=True): |
| if column == "summary": |
| # Arrow returns a list of tuples, instead of a dict |
| right = dict(right) |
| |
| if isinstance(left, float) and math.isnan(left) and isinstance(right, float) and math.isnan(right): |
| # NaN != NaN in Python |
| continue |
| |
| assert left == right, f"Difference in column {column}: {left} != {right}" |
| |
| |
| @pytest.mark.integration |
| def test_write_within_transaction(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
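    """A transaction that sets properties and appends should commit once, adding a single
    metadata log entry, while separate commits each add their own entry."""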
| identifier = "default.write_in_open_transaction" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) |
| |
| def get_metadata_entries_count(identifier: str) -> int: |
| return spark.sql( |
| f""" |
| SELECT * |
| FROM {identifier}.metadata_log_entries |
| """ |
| ).count() |
| |
| # one metadata entry from table creation |
| assert get_metadata_entries_count(identifier) == 1 |
| |
| # one more metadata entry from transaction |
| with tbl.transaction() as tx: |
| tx.set_properties({"test": "1"}) |
| tx.append(arrow_table_with_null) |
| assert get_metadata_entries_count(identifier) == 2 |
| |
| # two more metadata entries added from two separate transactions |
| tbl.transaction().set_properties({"test": "2"}).commit_transaction() |
| tbl.append(arrow_table_with_null) |
| assert get_metadata_entries_count(identifier) == 4 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_hive_catalog_storage_descriptor( |
| session_catalog_hive: HiveCatalog, |
| pa_schema: pa.Schema, |
| arrow_table_with_null: pa.Table, |
| spark: SparkSession, |
| format_version: int, |
| ) -> None: |
| tbl = _create_table( |
| session_catalog_hive, "default.test_storage_descriptor", {"format-version": format_version}, [arrow_table_with_null] |
| ) |
| |
| # check if pyiceberg can read the table |
| assert len(tbl.scan().to_arrow()) == 3 |
| # check if spark can read the table |
| assert spark.sql("SELECT * FROM hive.default.test_storage_descriptor").count() == 3 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_hive_catalog_storage_descriptor_has_changed( |
| session_catalog_hive: HiveCatalog, |
| pa_schema: pa.Schema, |
| arrow_table_with_null: pa.Table, |
| spark: SparkSession, |
| format_version: int, |
| ) -> None: |
| tbl = _create_table( |
| session_catalog_hive, "default.test_storage_descriptor", {"format-version": format_version}, [arrow_table_with_null] |
| ) |
| |
| with tbl.transaction() as tx: |
| with tx.update_schema() as schema: |
| schema.update_column("string_long", doc="this is string_long") |
| schema.update_column("binary", doc="this is binary") |
| |
| with session_catalog_hive._client as open_client: |
| hive_table = session_catalog_hive._get_hive_table(open_client, "default", "test_storage_descriptor") |
| assert "this is string_long" in str(hive_table.sd) |
| assert "this is binary" in str(hive_table.sd) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("catalog", [lf("session_catalog_hive"), lf("session_catalog")]) |
| def test_sanitize_character_partitioned(catalog: Catalog) -> None: |
| table_name = "default.test_table_partitioned_sanitized_character" |
| try: |
| catalog.drop_table(table_name) |
| except NoSuchTableError: |
| pass |
| |
| tbl = _create_table( |
| session_catalog=catalog, |
| identifier=table_name, |
| schema=Schema(NestedField(field_id=1, name="some.id", type=IntegerType(), required=True)), |
| partition_spec=PartitionSpec( |
| PartitionField(source_id=1, field_id=1000, name="some.id_identity", transform=IdentityTransform()) |
| ), |
| data=[pa.Table.from_arrays([range(22)], schema=pa.schema([pa.field("some.id", pa.int32(), nullable=False)]))], |
| ) |
| |
| assert len(tbl.scan().to_arrow()) == 22 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("catalog", [lf("session_catalog")]) |
| def test_sanitize_character_partitioned_avro_bug(catalog: Catalog) -> None: |
| table_name = "default.test_table_partitioned_sanitized_character_avro" |
| try: |
| catalog.drop_table(table_name) |
| except NoSuchTableError: |
| pass |
| |
| schema = Schema( |
| NestedField(id=1, name="😎", field_type=StringType(), required=False), |
| ) |
| |
| partition_spec = PartitionSpec( |
| PartitionField( |
| source_id=1, |
| field_id=1001, |
| transform=IdentityTransform(), |
| name="😎", |
| ) |
| ) |
| |
| tbl = _create_table( |
| session_catalog=catalog, |
| identifier=table_name, |
| schema=schema, |
| partition_spec=partition_spec, |
| data=[ |
| pa.Table.from_arrays( |
| [pa.array([str(i) for i in range(22)])], schema=pa.schema([pa.field("😎", pa.string(), nullable=False)]) |
| ) |
| ], |
| ) |
| |
| assert len(tbl.scan().to_arrow()) == 22 |
| |
| # verify that we can read the table with DuckDB |
| import duckdb |
| |
| location = tbl.metadata_location |
| duckdb.sql("INSTALL iceberg; LOAD iceberg;") |
| # Configure S3 settings for DuckDB to match the catalog configuration |
| duckdb.sql("SET s3_endpoint='localhost:9000';") |
| duckdb.sql("SET s3_access_key_id='admin';") |
| duckdb.sql("SET s3_secret_access_key='password';") |
| duckdb.sql("SET s3_use_ssl=false;") |
| duckdb.sql("SET s3_url_style='path';") |
| result = duckdb.sql(f"SELECT * FROM iceberg_scan('{location}')").fetchall() |
| assert len(result) == 22 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_cross_platform_special_character_compatibility( |
| spark: SparkSession, session_catalog: Catalog, format_version: int |
| ) -> None: |
| """Test cross-platform compatibility with special characters in column names.""" |
| identifier = "default.test_cross_platform_special_characters" |
| |
| # Test various special characters that need sanitization |
| special_characters = [ |
| "😎", # emoji - Java produces _xD83D_xDE0E, Python produces _x1F60E |
| "a.b", # dot - both should produce a_x2Eb |
| "a#b", # hash - both should produce a_x23b |
| "9x", # starts with digit - both should produce _9x |
| "x_", # valid - should remain unchanged |
| "letter/abc", # slash - both should produce letter_x2Fabc |
| ] |
| |
| for i, special_char in enumerate(special_characters): |
| table_name = f"{identifier}_{format_version}_{i}" |
| pyiceberg_table_name = f"{identifier}_pyiceberg_{format_version}_{i}" |
| |
| try: |
| session_catalog.drop_table(table_name) |
| except Exception: |
| pass |
| try: |
| session_catalog.drop_table(pyiceberg_table_name) |
| except Exception: |
| pass |
| |
| try: |
| # Test 1: Spark writes, PyIceberg reads |
| spark_df = spark.createDataFrame([("test_value",)], [special_char]) |
| spark_df.writeTo(table_name).using("iceberg").createOrReplace() |
| |
| # Read with PyIceberg table scan |
| tbl = session_catalog.load_table(table_name) |
| pyiceberg_df = tbl.scan().to_pandas() |
| assert len(pyiceberg_df) == 1 |
| assert special_char in pyiceberg_df.columns |
| assert pyiceberg_df.iloc[0][special_char] == "test_value" |
| |
| # Test 2: PyIceberg writes, Spark reads |
| |
| schema = Schema(NestedField(field_id=1, name=special_char, field_type=StringType(), required=True)) |
| |
| tbl_pyiceberg = session_catalog.create_table( |
| identifier=pyiceberg_table_name, schema=schema, properties={"format-version": str(format_version)} |
| ) |
| |
| # Create PyArrow schema with required field to match Iceberg schema |
| pa_schema = pa.schema([pa.field(special_char, pa.string(), nullable=False)]) |
| data = pa.Table.from_pydict({special_char: ["pyiceberg_value"]}, schema=pa_schema) |
| tbl_pyiceberg.append(data) |
| |
| # Read with Spark |
| spark_df_read = spark.table(pyiceberg_table_name) |
| spark_result = spark_df_read.collect() |
| |
| # Verify data integrity |
| assert len(spark_result) == 1 |
| assert special_char in spark_df_read.columns |
| assert spark_result[0][special_char] == "pyiceberg_value" |
| |
| finally: |
| try: |
| session_catalog.drop_table(table_name) |
| except Exception: |
| pass |
| try: |
| session_catalog.drop_table(pyiceberg_table_name) |
| except Exception: |
| pass |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| identifier = "default.test_table_write_subset_of_schema" |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table_with_null]) |
| arrow_table_without_some_columns = arrow_table_with_null.combine_chunks().drop(arrow_table_with_null.column_names[0]) |
| assert len(arrow_table_without_some_columns.columns) < len(arrow_table_with_null.columns) |
| tbl.overwrite(arrow_table_without_some_columns) |
| tbl.append(arrow_table_without_some_columns) |
| # overwrite and then append should produce twice the data |
| assert len(tbl.scan().to_arrow()) == len(arrow_table_without_some_columns) * 2 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| @pytest.mark.filterwarnings("ignore:Delete operation did not match any records") |
| def test_table_write_out_of_order_schema(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| identifier = "default.test_table_write_out_of_order_schema" |
| # rotate the schema fields by 1 |
| fields = list(arrow_table_with_null.schema) |
| rotated_fields = fields[1:] + fields[:1] |
| rotated_schema = pa.schema(rotated_fields) |
| assert arrow_table_with_null.schema != rotated_schema |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=rotated_schema) |
| |
| tbl.overwrite(arrow_table_with_null) |
| |
| tbl.append(arrow_table_with_null) |
| # overwrite and then append should produce twice the data |
| assert len(tbl.scan().to_arrow()) == len(arrow_table_with_null) * 2 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_table_write_schema_with_valid_nullability_diff( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = "default.test_table_write_with_valid_nullability_diff" |
| table_schema = Schema( |
| NestedField(field_id=1, name="long", field_type=LongType(), required=False), |
| ) |
| other_schema = pa.schema( |
| ( |
| pa.field("long", pa.int64(), nullable=False), # can support writing required pyarrow field to optional Iceberg field |
| ) |
| ) |
| arrow_table = pa.Table.from_pydict( |
| { |
| "long": [1, 9], |
| }, |
| schema=other_schema, |
| ) |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [arrow_table], schema=table_schema) |
    # the table's long field should be cast to optional on read
| written_arrow_table = tbl.scan().to_arrow() |
| assert written_arrow_table == arrow_table.cast(pa.schema((pa.field("long", pa.int64(), nullable=True),))) |
| lhs = spark.table(f"{identifier}").toPandas() |
| rhs = written_arrow_table.to_pandas() |
| |
| for column in written_arrow_table.column_names: |
| for left, right in zip(lhs[column].to_list(), rhs[column].to_list(), strict=True): |
| assert left == right |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_table_write_schema_with_valid_upcast( |
| spark: SparkSession, |
| session_catalog: Catalog, |
| format_version: int, |
| table_schema_with_promoted_types: Schema, |
| pyarrow_schema_with_promoted_types: pa.Schema, |
| pyarrow_table_with_promoted_types: pa.Table, |
| ) -> None: |
| identifier = "default.test_table_write_with_valid_upcast" |
| |
| tbl = _create_table( |
| session_catalog, |
| identifier, |
| {"format-version": format_version}, |
| [pyarrow_table_with_promoted_types], |
| schema=table_schema_with_promoted_types, |
| ) |
    # the table's promoted fields should be cast to their wider types on read
| written_arrow_table = tbl.scan().to_arrow() |
| assert written_arrow_table == pyarrow_table_with_promoted_types.cast( |
| pa.schema( |
| ( |
| pa.field("long", pa.int64(), nullable=True), |
| pa.field("list", pa.list_(pa.int64()), nullable=False), |
| pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False), |
| pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double |
| pa.field("uuid", pa.uuid(), nullable=True), |
| ) |
| ) |
| ) |
| lhs = spark.table(f"{identifier}").toPandas() |
| rhs = written_arrow_table.to_pandas() |
| |
| for column in written_arrow_table.column_names: |
| for left, right in zip(lhs[column].to_list(), rhs[column].to_list(), strict=True): |
| if column == "map": |
| # Arrow returns a list of tuples, instead of a dict |
| right = dict(right) |
| if column == "list": |
| # Arrow returns an array, convert to list for equality check |
| left, right = list(left), list(right) |
| if column == "uuid": |
| # Spark Iceberg represents UUID as hex string like '715a78ef-4e53-4089-9bf9-3ad0ee9bf545' |
| # whereas PyIceberg represents UUID as bytes on read |
| left, right = left.replace("-", ""), right.hex() |
| assert left == right |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_write_all_timestamp_precision( |
| mocker: MockerFixture, |
| spark: SparkSession, |
| session_catalog: Catalog, |
| format_version: int, |
| arrow_table_schema_with_all_timestamp_precisions: pa.Schema, |
| arrow_table_with_all_timestamp_precisions: pa.Table, |
| arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, |
| ) -> None: |
| identifier = "default.table_all_timestamp_precision" |
| mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"}) |
| |
| tbl = _create_table( |
| session_catalog, |
| identifier, |
| {"format-version": format_version}, |
| data=[arrow_table_with_all_timestamp_precisions], |
| schema=arrow_table_schema_with_all_timestamp_precisions, |
| ) |
| tbl.overwrite(arrow_table_with_all_timestamp_precisions) |
| written_arrow_table = tbl.scan().to_arrow() |
| |
| assert written_arrow_table.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions |
| assert written_arrow_table == arrow_table_with_all_timestamp_precisions.cast( |
| arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False |
| ) |
| lhs = spark.table(f"{identifier}").toPandas() |
| rhs = written_arrow_table.to_pandas() |
| |
| for column in written_arrow_table.column_names: |
| for left, right in zip(lhs[column].to_list(), rhs[column].to_list(), strict=True): |
| if pd.isnull(left): |
| assert pd.isnull(right) |
| else: |
                # Compare only up to microsecond precision, since the Spark-loaded dtype is timezone-unaware
                # and supports at most microsecond precision
| assert left.timestamp() == right.timestamp(), f"Difference in column {column}: {left} != {right}" |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_merge_manifests(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| tbl_a = _create_table( |
| session_catalog, |
| "default.merge_manifest_a", |
| {"commit.manifest-merge.enabled": "true", "commit.manifest.min-count-to-merge": "1", "format-version": format_version}, |
| [], |
| ) |
| tbl_b = _create_table( |
| session_catalog, |
| "default.merge_manifest_b", |
| { |
| "commit.manifest-merge.enabled": "true", |
| "commit.manifest.min-count-to-merge": "1", |
| "commit.manifest.target-size-bytes": "1", |
| "format-version": format_version, |
| }, |
| [], |
| ) |
| tbl_c = _create_table( |
| session_catalog, |
| "default.merge_manifest_c", |
| {"commit.manifest.min-count-to-merge": "1", "format-version": format_version}, |
| [], |
| ) |
| |
| # tbl_a should merge all manifests into 1 |
| tbl_a.append(arrow_table_with_null) |
| tbl_a.append(arrow_table_with_null) |
| tbl_a.append(arrow_table_with_null) |
| |
| # tbl_b should not merge any manifests because the target size is too small |
| tbl_b.append(arrow_table_with_null) |
| tbl_b.append(arrow_table_with_null) |
| tbl_b.append(arrow_table_with_null) |
| |
| # tbl_c should not merge any manifests because merging is disabled |
| tbl_c.append(arrow_table_with_null) |
| tbl_c.append(arrow_table_with_null) |
| tbl_c.append(arrow_table_with_null) |
| |
| assert len(tbl_a.current_snapshot().manifests(tbl_a.io)) == 1 # type: ignore |
| assert len(tbl_b.current_snapshot().manifests(tbl_b.io)) == 3 # type: ignore |
| assert len(tbl_c.current_snapshot().manifests(tbl_c.io)) == 3 # type: ignore |
| |
| # tbl_a and tbl_c should contain the same data |
| assert tbl_a.scan().to_arrow().equals(tbl_c.scan().to_arrow()) |
| # tbl_b and tbl_c should contain the same data |
| assert tbl_b.scan().to_arrow().equals(tbl_c.scan().to_arrow()) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_merge_manifests_file_content(session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int) -> None: |
| tbl_a = _create_table( |
| session_catalog, |
| "default.merge_manifest_a", |
| {"commit.manifest-merge.enabled": "true", "commit.manifest.min-count-to-merge": "1", "format-version": format_version}, |
| [], |
| ) |
| |
| # tbl_a should merge all manifests into 1 |
| tbl_a.append(arrow_table_with_null) |
| |
| tbl_a_first_entries = tbl_a.inspect.entries().to_pydict() |
| first_snapshot_id = tbl_a_first_entries["snapshot_id"][0] |
| first_data_file_path = tbl_a_first_entries["data_file"][0]["file_path"] |
| |
| tbl_a.append(arrow_table_with_null) |
| tbl_a.append(arrow_table_with_null) |
| |
| assert len(tbl_a.current_snapshot().manifests(tbl_a.io)) == 1 # type: ignore |
| |
| # verify the sequence number of tbl_a's only manifest file |
| tbl_a_manifest = tbl_a.current_snapshot().manifests(tbl_a.io)[0] # type: ignore |
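    # three appends were merged, so the manifest spans data sequence numbers 1 through 3 (v1 tables always report 0)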
| assert tbl_a_manifest.sequence_number == (3 if format_version == 2 else 0) |
| assert tbl_a_manifest.min_sequence_number == (1 if format_version == 2 else 0) |
| |
| # verify the manifest entries of tbl_a, in which the manifests are merged |
| tbl_a_entries = tbl_a.inspect.entries().to_pydict() |
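    # status 1 = ADDED (entry from the latest append), 0 = EXISTING (entries carried over by the merge)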
| assert tbl_a_entries["status"] == [1, 0, 0] |
| assert tbl_a_entries["sequence_number"] == [3, 2, 1] if format_version == 2 else [0, 0, 0] |
| assert tbl_a_entries["file_sequence_number"] == [3, 2, 1] if format_version == 2 else [0, 0, 0] |
| for i in range(3): |
| tbl_a_data_file = tbl_a_entries["data_file"][i] |
| assert tbl_a_data_file["column_sizes"] == [ |
| (1, 51), |
| (2, 80), |
| (3, 130), |
| (4, 96), |
| (5, 120), |
| (6, 96), |
| (7, 120), |
| (8, 120), |
| (9, 120), |
| (10, 96), |
| (11, 80), |
| (12, 111), |
| ] |
| assert tbl_a_data_file["content"] == 0 |
| assert tbl_a_data_file["equality_ids"] is None |
| assert tbl_a_data_file["file_format"] == "PARQUET" |
| assert tbl_a_data_file["file_path"].startswith("s3://warehouse/default/merge_manifest_a/data/") |
| if tbl_a_data_file["file_path"] == first_data_file_path: |
            # verify that the recorded snapshot id is the one in which the file was added
| assert tbl_a_entries["snapshot_id"][i] == first_snapshot_id |
| assert tbl_a_data_file["key_metadata"] is None |
| assert tbl_a_data_file["lower_bounds"] == [ |
| (1, b"\x00"), |
| (2, b"a"), |
| (3, b"aaaaaaaaaaaaaaaa"), |
| (4, b"\x01\x00\x00\x00"), |
| (5, b"\x01\x00\x00\x00\x00\x00\x00\x00"), |
| (6, b"\x00\x00\x00\x80"), |
| (7, b"\x00\x00\x00\x00\x00\x00\x00\x80"), |
| (8, b"\x00\x9bj\xca8\xf1\x05\x00"), |
| (9, b"\x00\x9bj\xca8\xf1\x05\x00"), |
| (10, b"\x9eK\x00\x00"), |
| (11, b"\x01"), |
| (12, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"), |
| ] |
| assert tbl_a_data_file["nan_value_counts"] == [] |
| assert tbl_a_data_file["null_value_counts"] == [ |
| (1, 1), |
| (2, 1), |
| (3, 1), |
| (4, 1), |
| (5, 1), |
| (6, 1), |
| (7, 1), |
| (8, 1), |
| (9, 1), |
| (10, 1), |
| (11, 1), |
| (12, 1), |
| ] |
| assert tbl_a_data_file["partition"] == {} |
| assert tbl_a_data_file["record_count"] == 3 |
| assert tbl_a_data_file["sort_order_id"] is None |
| assert tbl_a_data_file["split_offsets"] == [4] |
| assert tbl_a_data_file["upper_bounds"] == [ |
| (1, b"\x01"), |
| (2, b"z"), |
| (3, b"zzzzzzzzzzzzzzz{"), |
| (4, b"\t\x00\x00\x00"), |
| (5, b"\t\x00\x00\x00\x00\x00\x00\x00"), |
| (6, b"fff?"), |
| (7, b"\xcd\xcc\xcc\xcc\xcc\xcc\xec?"), |
| (8, b"\x00\xbb\r\xab\xdb\xf5\x05\x00"), |
| (9, b"\x00\xbb\r\xab\xdb\xf5\x05\x00"), |
| (10, b"\xd9K\x00\x00"), |
| (11, b"\x12"), |
| (12, b"\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11\x11"), |
| ] |
| assert tbl_a_data_file["value_counts"] == [ |
| (1, 3), |
| (2, 3), |
| (3, 3), |
| (4, 3), |
| (5, 3), |
| (6, 3), |
| (7, 3), |
| (8, 3), |
| (9, 3), |
| (10, 3), |
| (11, 3), |
| (12, 3), |
| ] |
| |
| |
| @pytest.mark.integration |
| def test_rest_catalog_with_empty_catalog_name_append_data(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_rest_append" |
| test_catalog = load_catalog( |
| "", # intentionally empty |
| **session_catalog.properties, |
| ) |
| tbl = _create_table(test_catalog, identifier, data=[]) |
| tbl.append(arrow_table_with_null) |
| |
| |
| @pytest.mark.integration |
| def test_table_v1_with_null_nested_namespace(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.table_v1_with_null_nested_namespace" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, [arrow_table_with_null]) |
| assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" |
| |
| assert session_catalog.load_table(identifier) is not None |
| assert session_catalog.table_exists(identifier) |
| |
| # We expect no error here |
| session_catalog.drop_table(identifier) |
| |
| |
| @pytest.mark.integration |
| def test_view_exists( |
| spark: SparkSession, |
| session_catalog: Catalog, |
| ) -> None: |
| identifier = "default.some_view" |
| spark.sql( |
| f""" |
| CREATE VIEW {identifier} |
| AS |
| (SELECT 1 as some_col) |
| """ |
| ).collect() |
| assert session_catalog.view_exists(identifier) |
| session_catalog.drop_view(identifier) # clean up |
| |
| |
| @pytest.mark.integration |
| def test_overwrite_all_data_with_filter(session_catalog: Catalog) -> None: |
| schema = Schema( |
| NestedField(1, "id", StringType(), required=True), |
| NestedField(2, "name", StringType(), required=False), |
| identifier_field_ids=[1], |
| ) |
| |
| data = pa.Table.from_pylist( |
| [ |
| {"id": "1", "name": "Amsterdam"}, |
| {"id": "2", "name": "San Francisco"}, |
| {"id": "3", "name": "Drachten"}, |
| ], |
| schema=schema.as_arrow(), |
| ) |
| |
| identifier = "default.test_overwrite_all_data_with_filter" |
| tbl = _create_table(session_catalog, identifier, data=[data], schema=schema) |
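    # the filter matches every row, so the overwrite replaces the entire table and the count stays 3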
| tbl.overwrite(data, In("id", ["1", "2", "3"])) |
| |
| assert len(tbl.scan().to_arrow()) == 3 |
| |
| |
| @pytest.mark.integration |
| def test_delete_threshold(session_catalog: Catalog) -> None: |
| schema = Schema( |
| NestedField(field_id=101, name="id", field_type=LongType(), required=True), |
| NestedField(field_id=103, name="created_at", field_type=DateType(), required=False), |
| NestedField(field_id=104, name="relevancy_score", field_type=DoubleType(), required=False), |
| ) |
| |
| partition_spec = PartitionSpec(PartitionField(source_id=103, field_id=2000, transform=DayTransform(), name="created_at_day")) |
| |
| try: |
| session_catalog.drop_table( |
| identifier="default.scores", |
| ) |
| except NoSuchTableError: |
| pass |
| |
| session_catalog.create_table( |
| identifier="default.scores", |
| schema=schema, |
| partition_spec=partition_spec, |
| ) |
| |
| # Parameters |
| num_rows = 100 # Number of rows in the dataframe |
| id_min, id_max = 1, 10000 |
| date_start, date_end = date(2024, 1, 1), date(2024, 2, 1) |
| |
| # Generate the 'id' column |
| id_column = [random.randint(id_min, id_max) for _ in range(num_rows)] |
| |
| # Generate the 'created_at' column as dates only |
| date_range = pd.date_range(start=date_start, end=date_end, freq="D").to_list() # Daily frequency for dates |
    created_at_column = [random.choice(date_range) for _ in range(num_rows)]  # pick a random date for each row
| |
| # Generate the 'relevancy_score' column with a peak around 0.1 |
| relevancy_score_column = [random.betavariate(2, 20) for _ in range(num_rows)] # Adjusting parameters to peak around 0.1 |
| |
| # Create the dataframe |
| df = pd.DataFrame({"id": id_column, "created_at": created_at_column, "relevancy_score": relevancy_score_column}) |
| |
| iceberg_table = session_catalog.load_table("default.scores") |
| |
| # Convert the pandas DataFrame to a PyArrow Table with the Iceberg schema |
| arrow_schema = iceberg_table.schema().as_arrow() |
| docs_table = pa.Table.from_pandas(df, schema=arrow_schema) |
| |
| # Append the data to the Iceberg table |
| iceberg_table.append(docs_table) |
| |
    delete_condition = GreaterThanOrEqual("relevancy_score", 0.1)
    # count the rows that should survive the delete, i.e. those below the threshold
    lower_before = len(iceberg_table.scan(row_filter=Not(delete_condition)).to_arrow())
    iceberg_table.delete(delete_condition)
    assert len(iceberg_table.scan().to_arrow()) == lower_before
| |
| |
| @pytest.mark.integration |
| def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) -> None: |
| random.seed(876) |
| N = 1440 |
| d = { |
| "timestamp": pa.array([datetime(2023, 1, 1, 0, 0, 0) + timedelta(minutes=i) for i in range(N)]), |
| "category": pa.array([random.choice(["A", "B", "C"]) for _ in range(N)]), |
| "value": pa.array([random.gauss(0, 1) for _ in range(N)]), |
| } |
| data = pa.Table.from_pydict(d) |
| |
| try: |
| session_catalog.drop_table( |
| identifier="default.test_error_table", |
| ) |
| except NoSuchTableError: |
| pass |
| |
| table = session_catalog.create_table( |
| "default.test_error_table", |
| schema=data.schema, |
| ) |
| |
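    # evolve the partition spec before and after writing so the filtered overwrite spans multiple specs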
| with table.update_spec() as update: |
| update.add_field("timestamp", transform=HourTransform()) |
| |
| table.append(data) |
| |
| with table.update_spec() as update: |
| update.add_field("category", transform=IdentityTransform()) |
| |
| data_ = data.filter( |
| (pc.field("category") == "A") |
| & (pc.field("timestamp") >= datetime(2023, 1, 1, 0)) |
| & (pc.field("timestamp") < datetime(2023, 1, 1, 1)) |
| ) |
| |
| table.overwrite( |
| df=data_, |
| overwrite_filter=And( |
| And( |
| GreaterThanOrEqual("timestamp", datetime(2023, 1, 1, 0).isoformat()), |
| LessThan("timestamp", datetime(2023, 1, 1, 1).isoformat()), |
| ), |
| EqualTo("category", "A"), |
| ), |
| ) |
| |
| |
| @pytest.mark.integration |
| def test_writing_null_structs(session_catalog: Catalog) -> None: |
| schema = pa.schema( |
| [ |
| pa.field( |
| "struct_field_1", |
| pa.struct( |
| [ |
| pa.field("string_nested_1", pa.string()), |
| pa.field("int_item_2", pa.int32()), |
| pa.field("float_item_2", pa.float32()), |
| ] |
| ), |
| ), |
| ] |
| ) |
| |
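    # the second record omits struct_field_1 entirely; it should round-trip as a null struct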
| records = [ |
| { |
| "struct_field_1": { |
| "string_nested_1": "nest_1", |
| "int_item_2": 1234, |
| "float_item_2": 1.234, |
| }, |
| }, |
| {}, |
| ] |
| |
| try: |
| session_catalog.drop_table( |
| identifier="default.test_writing_null_structs", |
| ) |
| except NoSuchTableError: |
| pass |
| |
| table = session_catalog.create_table("default.test_writing_null_structs", schema) |
| |
| pyarrow_table: pa.Table = pa.Table.from_pylist(records, schema=schema) |
| table.append(pyarrow_table) |
| |
| assert pyarrow_table.to_pandas()["struct_field_1"].tolist() == table.scan().to_pandas()["struct_field_1"].tolist() |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_abort_table_transaction_on_exception( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = "default.table_test_abort_table_transaction_on_exception" |
| tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}) |
| |
| # Pre-populate some data |
| tbl.append(arrow_table_with_null) |
| table_size = len(arrow_table_with_null) |
| assert len(tbl.scan().to_pandas()) == table_size |
| |
    # try to commit a transaction that raises an exception midway through
| with pytest.raises(ValueError): |
| with tbl.transaction() as txn: |
| txn.append(arrow_table_with_null) |
| raise ValueError |
| txn.append(arrow_table_with_null) # type: ignore |
| |
| # Validate the transaction is aborted and no partial update is applied |
| assert len(tbl.scan().to_pandas()) == table_size # type: ignore |
| |
| |
| @pytest.mark.integration |
| def test_write_optional_list(session_catalog: Catalog) -> None: |
| identifier = "default.test_write_optional_list" |
| schema = Schema( |
| NestedField(field_id=1, name="name", field_type=StringType(), required=False), |
| NestedField( |
| field_id=3, |
| name="my_list", |
| field_type=ListType(element_id=45, element=StringType(), element_required=False), |
| required=False, |
| ), |
| ) |
| session_catalog.create_table_if_not_exists(identifier, schema) |
| |
| df_1 = pa.Table.from_pylist( |
| [ |
| {"name": "one", "my_list": ["test"]}, |
| {"name": "another", "my_list": ["test"]}, |
| ] |
| ) |
| session_catalog.load_table(identifier).append(df_1) |
| |
| assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 2 |
| |
| df_2 = pa.Table.from_pylist( |
| [ |
| {"name": "one"}, |
| {"name": "another"}, |
| ] |
| ) |
| session_catalog.load_table(identifier).append(df_2) |
| |
| assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_double_commit_transaction( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = "default.arrow_data_files" |
| tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, []) |
| |
| assert len(tbl.metadata.metadata_log) == 0 |
| |
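    # an explicit commit inside the context manager should not trigger a second commit on exit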
| with tbl.transaction() as tx: |
| tx.append(arrow_table_with_null) |
| tx.commit_transaction() |
| |
| assert len(tbl.metadata.metadata_log) == 1 |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_evolve_and_write( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = "default.test_evolve_and_write" |
| tbl = _create_table(session_catalog, identifier, properties={"format-version": format_version}, schema=Schema()) |
| other_table = session_catalog.load_table(identifier) |
| |
| numbers = pa.array([1, 2, 3, 4], type=pa.int32()) |
| |
| with tbl.update_schema() as upd: |
        # This schema change is not yet visible to other_table
| upd.add_column("id", IntegerType()) |
| |
| with other_table.transaction() as tx: |
        # Refresh the underlying metadata and schema
| other_table.refresh() |
| tx.append( |
| pa.Table.from_arrays( |
| [ |
| numbers, |
| ], |
| schema=pa.schema( |
| [ |
| pa.field("id", pa.int32(), nullable=True), |
| ] |
| ), |
| ) |
| ) |
| |
| assert session_catalog.load_table(identifier).scan().to_arrow().column(0).combine_chunks() == numbers |
| |
| |
| @pytest.mark.integration |
| def test_read_write_decimals(session_catalog: Catalog) -> None: |
| """Roundtrip decimal types to make sure that we correctly write them as ints""" |
| identifier = "default.test_read_write_decimals" |
| |
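    # precisions are chosen around Parquet's physical boundaries: <= 9 fits an int32, <= 18 fits an int64,
    # and anything wider is stored as fixed-length binary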
| arrow_table = pa.Table.from_pydict( |
| { |
| "decimal8": pa.array([Decimal("123.45"), Decimal("678.91")], pa.decimal128(8, 2)), |
| "decimal16": pa.array([Decimal("12345679.123456"), Decimal("67891234.678912")], pa.decimal128(16, 6)), |
| "decimal19": pa.array([Decimal("1234567890123.123456"), Decimal("9876543210703.654321")], pa.decimal128(19, 6)), |
| }, |
| ) |
| |
| tbl = _create_table( |
| session_catalog, |
| identifier, |
| properties={"format-version": 2}, |
| schema=Schema( |
| NestedField(1, "decimal8", DecimalType(8, 2)), |
| NestedField(2, "decimal16", DecimalType(16, 6)), |
| NestedField(3, "decimal19", DecimalType(19, 6)), |
| ), |
| ) |
| |
| tbl.append(arrow_table) |
| |
| assert tbl.scan().to_arrow() == arrow_table |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize( |
| "transform", |
| [ |
| IdentityTransform(), |
| # Bucket is disabled because of an issue in Iceberg Java: |
| # https://github.com/apache/iceberg/pull/13324 |
| # BucketTransform(32) |
| ], |
| ) |
| def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transform: Transform) -> None: # type: ignore |
| identifier = f"default.test_uuid_partitioning_{str(transform).replace('[32]', '')}" |
| |
| schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True)) |
| |
| try: |
| session_catalog.drop_table(identifier=identifier) |
| except NoSuchTableError: |
| pass |
| |
| partition_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=transform, name="uuid_identity")) |
| |
| arr_table = pa.Table.from_pydict( |
| { |
| "uuid": [ |
| uuid.UUID("00000000-0000-0000-0000-000000000000").bytes, |
| uuid.UUID("11111111-1111-1111-1111-111111111111").bytes, |
| ], |
| }, |
| schema=pa.schema( |
| [ |
                # The Arrow UUID extension type is not yet supported on write, so we stick with binary(16)
| # https://github.com/apache/arrow/issues/46468 |
| pa.field("uuid", pa.binary(16), nullable=False), |
| ] |
| ), |
| ) |
| |
| tbl = session_catalog.create_table( |
| identifier=identifier, |
| schema=schema, |
| partition_spec=partition_spec, |
| ) |
| |
| tbl.append(arr_table) |
| |
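    # Spark returns UUIDs as strings, so normalize the Arrow values via str() before comparing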
| lhs = [r[0] for r in spark.table(identifier).collect()] |
| rhs = [str(u.as_py()) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()] |
| assert lhs == rhs |
| |
| |
| @pytest.mark.integration |
| def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_avro_compression_codecs" |
| tbl = _create_table(session_catalog, identifier, schema=arrow_table_with_null.schema, data=[arrow_table_with_null]) |
| |
| current_snapshot = tbl.current_snapshot() |
| assert current_snapshot is not None |
| |
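    # the manifest list is written as Avro; deflate is the default compression codec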
| with tbl.io.new_input(current_snapshot.manifest_list).open() as f: |
| reader = fastavro.reader(f) |
| assert reader.codec == "deflate" |
| |
| with tbl.transaction() as tx: |
| tx.set_properties(**{TableProperties.WRITE_AVRO_COMPRESSION: "null"}) # type: ignore |
| |
| tbl.append(arrow_table_with_null) |
| |
| current_snapshot = tbl.current_snapshot() |
| assert current_snapshot is not None |
| |
| with tbl.io.new_input(current_snapshot.manifest_list).open() as f: |
| reader = fastavro.reader(f) |
| assert reader.codec == "null" |
| |
| |
| @pytest.mark.integration |
| def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_non_existing_branch" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, []) |
| with pytest.raises( |
| CommitFailedException, match=f"Table has no snapshots and can only be written to the {MAIN_BRANCH} BRANCH." |
| ): |
| tbl.append(arrow_table_with_null, branch="non_existing_branch") |
| |
| |
| @pytest.mark.integration |
| def test_append_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_existing_branch_append" |
| branch = "existing_branch" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| |
| assert tbl.metadata.current_snapshot_id is not None |
| |
| tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() |
| tbl.append(arrow_table_with_null, branch=branch) |
| |
| assert len(tbl.scan().use_ref(branch).to_arrow()) == 6 |
| assert len(tbl.scan().to_arrow()) == 3 |
| branch_snapshot = tbl.metadata.snapshot_by_name(branch) |
| assert branch_snapshot is not None |
| main_snapshot = tbl.metadata.snapshot_by_name("main") |
| assert main_snapshot is not None |
| assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id |
| |
| |
| @pytest.mark.integration |
| def test_delete_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_existing_branch_delete" |
| branch = "existing_branch" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| |
| assert tbl.metadata.current_snapshot_id is not None |
| |
| tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() |
| tbl.delete(delete_filter="int = 9", branch=branch) |
| |
| assert len(tbl.scan().use_ref(branch).to_arrow()) == 2 |
| assert len(tbl.scan().to_arrow()) == 3 |
| branch_snapshot = tbl.metadata.snapshot_by_name(branch) |
| assert branch_snapshot is not None |
| main_snapshot = tbl.metadata.snapshot_by_name("main") |
| assert main_snapshot is not None |
| assert branch_snapshot.parent_snapshot_id == main_snapshot.snapshot_id |
| |
| |
| @pytest.mark.integration |
| def test_overwrite_to_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_existing_branch_overwrite" |
| branch = "existing_branch" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| |
| assert tbl.metadata.current_snapshot_id is not None |
| |
| tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() |
| tbl.overwrite(arrow_table_with_null, branch=branch) |
| |
| assert len(tbl.scan().use_ref(branch).to_arrow()) == 3 |
| assert len(tbl.scan().to_arrow()) == 3 |
| branch_snapshot = tbl.metadata.snapshot_by_name(branch) |
| assert branch_snapshot is not None and branch_snapshot.parent_snapshot_id is not None |
| delete_snapshot = tbl.metadata.snapshot_by_id(branch_snapshot.parent_snapshot_id) |
| assert delete_snapshot is not None |
| main_snapshot = tbl.metadata.snapshot_by_name("main") |
| assert main_snapshot is not None |
| assert ( |
| delete_snapshot.parent_snapshot_id == main_snapshot.snapshot_id |
| ) # Currently overwrite is a delete followed by an append operation |
| |
| |
| @pytest.mark.integration |
| def test_intertwined_branch_writes(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: |
| identifier = "default.test_intertwined_branch_operations" |
| branch1 = "existing_branch_1" |
| branch2 = "existing_branch_2" |
| |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| |
| assert tbl.metadata.current_snapshot_id is not None |
| |
| tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch1).commit() |
| |
| tbl.delete("int = 9", branch=branch1) |
| |
| tbl.append(arrow_table_with_null) |
| |
| tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch2).commit() |
| |
| tbl.overwrite(arrow_table_with_null, branch=branch2) |
| |
| assert len(tbl.scan().use_ref(branch1).to_arrow()) == 2 |
| assert len(tbl.scan().use_ref(branch2).to_arrow()) == 3 |
| assert len(tbl.scan().to_arrow()) == 6 |
| |
| |
| @pytest.mark.integration |
| def test_branch_spark_write_py_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None: |
| # Initialize table with branch |
| identifier = "default.test_branch_spark_write_py_read" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| branch = "existing_spark_branch" |
| |
| # Create branch in Spark |
| spark.sql(f"ALTER TABLE {identifier} CREATE BRANCH {branch}") |
| |
| # Spark Write |
| spark.sql( |
| f""" |
| DELETE FROM {identifier}.branch_{branch} |
| WHERE int = 9 |
| """ |
| ) |
| |
| # Refresh table to get new refs |
| tbl.refresh() |
| |
| # Python Read |
| assert len(tbl.scan().to_arrow()) == 3 |
| assert len(tbl.scan().use_ref(branch).to_arrow()) == 2 |
| |
| |
| @pytest.mark.integration |
| def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table) -> None: |
| # Initialize table with branch |
| identifier = "default.test_branch_py_write_spark_read" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "2"}, [arrow_table_with_null]) |
| branch = "existing_py_branch" |
| |
| assert tbl.metadata.current_snapshot_id is not None |
| |
| # Create branch |
| tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit() |
| |
| # Python Write |
| tbl.delete("int = 9", branch=branch) |
| |
| # Spark Read |
| main_df = spark.sql( |
| f""" |
| SELECT * |
| FROM {identifier} |
| """ |
| ) |
| branch_df = spark.sql( |
| f""" |
| SELECT * |
| FROM {identifier}.branch_{branch} |
| """ |
| ) |
| assert main_df.count() == 3 |
| assert branch_df.count() == 2 |
| |
| |
| @pytest.mark.integration |
| def test_nanosecond_support_on_catalog( |
| session_catalog: Catalog, arrow_table_schema_with_all_timestamp_precisions: pa.Schema |
| ) -> None: |
| identifier = "default.test_nanosecond_support_on_catalog" |
| |
| catalog = load_catalog("default", type="in-memory") |
| catalog.create_namespace("ns") |
| |
| _create_table(session_catalog, identifier, {"format-version": "3"}, schema=arrow_table_schema_with_all_timestamp_precisions) |
| |
| with pytest.raises(NotImplementedError, match="Writing V3 is not yet supported"): |
| catalog.create_table( |
| "ns.table1", schema=arrow_table_schema_with_all_timestamp_precisions, properties={"format-version": "3"} |
| ) |
| |
| with pytest.raises( |
| UnsupportedPyArrowTypeException, match=re.escape("Column 'timestamp_ns' has an unsupported type: timestamp[ns]") |
| ): |
| _create_table( |
| session_catalog, identifier, {"format-version": "2"}, schema=arrow_table_schema_with_all_timestamp_precisions |
| ) |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_stage_only_delete( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = f"default.test_stage_only_delete_files_v{format_version}" |
    iceberg_spec = PartitionSpec(
        PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")
    )
| tbl = _create_table( |
| session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec |
| ) |
| |
| current_snapshot = tbl.metadata.current_snapshot_id |
| assert current_snapshot is not None |
| |
| original_count = len(tbl.scan().to_arrow()) |
| assert original_count == 3 |
| |
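    # branch=None stages the snapshot: it is committed to table metadata but no branch ref is updated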
| tbl.delete("int = 9", branch=None) |
| |
| # a new delete snapshot is added |
| snapshots = tbl.snapshots() |
| assert len(snapshots) == 2 |
    # the main branch ref has not changed
| assert current_snapshot == tbl.metadata.current_snapshot_id |
| assert len(tbl.scan().to_arrow()) == original_count |
| |
| # Write to main branch |
| tbl.append(arrow_table_with_null) |
| |
| # Main ref has changed |
| assert current_snapshot != tbl.metadata.current_snapshot_id |
| assert len(tbl.scan().to_arrow()) == 6 |
| snapshots = tbl.snapshots() |
| assert len(snapshots) == 3 |
| |
| rows = spark.sql( |
| f""" |
| SELECT operation, parent_id |
| FROM {identifier}.snapshots |
| ORDER BY committed_at ASC |
| """ |
| ).collect() |
| operations = [row.operation for row in rows] |
| parent_snapshot_id = [row.parent_id for row in rows] |
| assert operations == ["append", "delete", "append"] |
    # both subsequent parent ids should be the first snapshot id
| assert parent_snapshot_id == [None, current_snapshot, current_snapshot] |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_stage_only_append( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = f"default.test_stage_only_fast_append_files_v{format_version}" |
| tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) |
| |
| current_snapshot = tbl.metadata.current_snapshot_id |
| assert current_snapshot is not None |
| |
| original_count = len(tbl.scan().to_arrow()) |
| assert original_count == 3 |
| |
    # Stage-only append: branch=None commits a snapshot without updating any branch ref
| tbl.append(arrow_table_with_null, branch=None) |
| |
| # Main ref has not changed and data is not yet appended |
| assert current_snapshot == tbl.metadata.current_snapshot_id |
| assert len(tbl.scan().to_arrow()) == original_count |
| # There should be a new staged snapshot |
| snapshots = tbl.snapshots() |
| assert len(snapshots) == 2 |
| |
| # Write to main branch |
| tbl.append(arrow_table_with_null) |
| |
| # Main ref has changed |
| assert current_snapshot != tbl.metadata.current_snapshot_id |
| assert len(tbl.scan().to_arrow()) == 6 |
| snapshots = tbl.snapshots() |
| assert len(snapshots) == 3 |
| |
| rows = spark.sql( |
| f""" |
| SELECT operation, parent_id |
| FROM {identifier}.snapshots |
| ORDER BY committed_at ASC |
| """ |
| ).collect() |
| operations = [row.operation for row in rows] |
| parent_snapshot_id = [row.parent_id for row in rows] |
| assert operations == ["append", "append", "append"] |
    # both subsequent parent ids should be the first snapshot id
| assert parent_snapshot_id == [None, current_snapshot, current_snapshot] |
| |
| |
| @pytest.mark.integration |
| @pytest.mark.parametrize("format_version", [1, 2]) |
| def test_stage_only_overwrite_files( |
| spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int |
| ) -> None: |
| identifier = f"default.test_stage_only_overwrite_files_v{format_version}" |
| tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null]) |
| first_snapshot = tbl.metadata.current_snapshot_id |
| |
| # duplicate data with a new insert |
| tbl.append(arrow_table_with_null) |
| |
| second_snapshot = tbl.metadata.current_snapshot_id |
| assert second_snapshot is not None |
| original_count = len(tbl.scan().to_arrow()) |
| assert original_count == 6 |
| |
    # stage-only overwrite: snapshots are committed but no branch ref is updated
| tbl.overwrite(arrow_table_with_null, branch=None) |
| assert second_snapshot == tbl.metadata.current_snapshot_id |
| assert len(tbl.scan().to_arrow()) == original_count |
| snapshots = tbl.snapshots() |
    # an overwrite creates 2 snapshots: a delete followed by an append
| assert len(snapshots) == 4 |
| |
| # Write to main branch again |
| tbl.append(arrow_table_with_null) |
| |
| # Main ref has changed |
| assert second_snapshot != tbl.metadata.current_snapshot_id |
| assert len(tbl.scan().to_arrow()) == 9 |
| snapshots = tbl.snapshots() |
| assert len(snapshots) == 5 |
| |
| rows = spark.sql( |
| f""" |
| SELECT operation, parent_id, snapshot_id |
| FROM {identifier}.snapshots |
| ORDER BY committed_at ASC |
| """ |
| ).collect() |
| operations = [row.operation for row in rows] |
| parent_snapshot_id = [row.parent_id for row in rows] |
| assert operations == ["append", "append", "delete", "append", "append"] |
| |
| assert parent_snapshot_id == [None, first_snapshot, second_snapshot, second_snapshot, second_snapshot] |
| |
| |
| @pytest.mark.skip("V3 writer support is not enabled.") |
| @pytest.mark.integration |
| def test_v3_write_and_read_row_lineage(spark: SparkSession, session_catalog: Catalog) -> None: |
| """Test writing to a v3 table and reading with Spark.""" |
| identifier = "default.test_v3_write_and_read" |
| tbl = _create_table(session_catalog, identifier, {"format-version": "3"}) |
| assert tbl.format_version == 3, f"Expected v3, got: v{tbl.format_version}" |
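    # v3 tables track row lineage; next_row_id should advance by the number of appended rows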
| initial_next_row_id = tbl.metadata.next_row_id or 0 |
| |
| test_data = pa.Table.from_pydict( |
| { |
| "bool": [True, False, True], |
| "string": ["a", "b", "c"], |
| "string_long": ["a_long", "b_long", "c_long"], |
| "int": [1, 2, 3], |
| "long": [11, 22, 33], |
| "float": [1.1, 2.2, 3.3], |
| "double": [1.11, 2.22, 3.33], |
| "timestamp": [datetime(2023, 1, 1, 1, 1, 1), datetime(2023, 2, 2, 2, 2, 2), datetime(2023, 3, 3, 3, 3, 3)], |
| "timestamptz": [ |
| datetime(2023, 1, 1, 1, 1, 1, tzinfo=pytz.utc), |
| datetime(2023, 2, 2, 2, 2, 2, tzinfo=pytz.utc), |
| datetime(2023, 3, 3, 3, 3, 3, tzinfo=pytz.utc), |
| ], |
| "date": [date(2023, 1, 1), date(2023, 2, 2), date(2023, 3, 3)], |
| "binary": [b"\x01", b"\x02", b"\x03"], |
| "fixed": [b"1234567890123456", b"1234567890123456", b"1234567890123456"], |
| }, |
| schema=TABLE_SCHEMA.as_arrow(), |
| ) |
| |
| tbl.append(test_data) |
| |
| assert tbl.metadata.next_row_id == initial_next_row_id + len(test_data), ( |
| "Expected next_row_id to be incremented by the number of added rows" |
| ) |