# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from pathlib import PosixPath
import pyarrow as pa
import pytest
from datafusion import SessionContext
from pyarrow import Table as pa_table
from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference
from pyiceberg.expressions.literals import LongLiteral
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.table import UpsertResult
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.upsert_util import create_match_filter
from pyiceberg.types import IntegerType, NestedField, StringType, StructType
from tests.catalog.test_base import InMemoryCatalog, Table


@pytest.fixture
def catalog(tmp_path: PosixPath) -> InMemoryCatalog:
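    """Provide an in-memory catalog with a `default` namespace, backed by a temporary warehouse path."""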
catalog = InMemoryCatalog("test.in_memory.catalog", warehouse=tmp_path.absolute().as_posix())
catalog.create_namespace("default")
return catalog


def _drop_table(catalog: Catalog, identifier: str) -> None:
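    """Drop the table if it exists; ignore the error when it does not."""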
try:
catalog.drop_table(identifier)
except NoSuchTableError:
pass


def show_iceberg_table(table: Table, ctx: SessionContext) -> None:
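    """Debug helper: register the table scan as a DataFusion dataset and print the first five rows."""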
import pyarrow.dataset as ds
table_name = "target"
if ctx.table_exist(table_name):
ctx.deregister_table(table_name)
ctx.register_dataset(table_name, ds.dataset(table.scan().to_arrow()))
ctx.sql(f"SELECT * FROM {table_name} limit 5").show()


def show_df(df: pa_table, ctx: SessionContext) -> None:
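    """Debug helper: register the Arrow table as a DataFusion dataset and print the first ten rows."""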
import pyarrow.dataset as ds
ctx.register_dataset("df", ds.dataset(df))
ctx.sql("select * from df limit 10").show()


def gen_source_dataset(start_row: int, end_row: int, composite_key: bool, add_dup: bool, ctx: SessionContext) -> pa_table:
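    """Generate the source dataset with order_id values in [start_row, end_row] and order_type 'B'.

    When composite_key is True an additional order_line_id column is generated; when add_dup is True one duplicate row is appended.
    """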
additional_columns = ", t.order_id + 1000 as order_line_id" if composite_key else ""
dup_row = (
f"""
UNION ALL
(
SELECT t.order_id {additional_columns}
, date '2021-01-01' as order_date, 'B' as order_type
from t
limit 1
)
"""
if add_dup
else ""
)
sql = f"""
with t as (SELECT unnest(range({start_row},{end_row + 1})) as order_id)
SELECT t.order_id {additional_columns}
, date '2021-01-01' as order_date, 'B' as order_type
from t
{dup_row}
"""
df = ctx.sql(sql).to_arrow_table()
return df


def gen_target_iceberg_table(
start_row: int, end_row: int, composite_key: bool, ctx: SessionContext, catalog: InMemoryCatalog, identifier: str
) -> Table:
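    """Create and seed the target Iceberg table with order_id values in [start_row, end_row] and order_type 'A'.

    When composite_key is True an additional order_line_id column is generated.
    """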
additional_columns = ", t.order_id + 1000 as order_line_id" if composite_key else ""
df = ctx.sql(f"""
with t as (SELECT unnest(range({start_row},{end_row + 1})) as order_id)
SELECT t.order_id {additional_columns}
, date '2021-01-01' as order_date, 'A' as order_type
from t
""").to_arrow_table()
table = catalog.create_table(identifier, df.schema)
table.append(df)
return table


def assert_upsert_result(res: UpsertResult, expected_updated: int, expected_inserted: int) -> None:
assert res.rows_updated == expected_updated, f"rows updated should be {expected_updated}, but got {res.rows_updated}"
assert res.rows_inserted == expected_inserted, f"rows inserted should be {expected_inserted}, but got {res.rows_inserted}"


@pytest.mark.parametrize(
"join_cols, src_start_row, src_end_row, target_start_row, target_end_row, when_matched_update_all, when_not_matched_insert_all, expected_updated, expected_inserted",
[
(["order_id"], 1, 2, 2, 3, True, True, 1, 1), # single row
(["order_id"], 5001, 15000, 1, 10000, True, True, 5000, 5000), # 10k rows
(["order_id"], 501, 1500, 1, 1000, True, False, 500, 0), # update only
(["order_id"], 501, 1500, 1, 1000, False, True, 0, 500), # insert only
],
)
def test_merge_rows(
catalog: Catalog,
join_cols: list[str],
src_start_row: int,
src_end_row: int,
target_start_row: int,
target_end_row: int,
when_matched_update_all: bool,
when_not_matched_insert_all: bool,
expected_updated: int,
expected_inserted: int,
) -> None:
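    """Upsert a generated source dataset into a generated target table and verify the update/insert counts."""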
identifier = "default.test_merge_rows"
_drop_table(catalog, identifier)
ctx = SessionContext()
source_df = gen_source_dataset(src_start_row, src_end_row, False, False, ctx)
ice_table = gen_target_iceberg_table(target_start_row, target_end_row, False, ctx, catalog, identifier)
res = ice_table.upsert(
df=source_df,
join_cols=join_cols,
when_matched_update_all=when_matched_update_all,
when_not_matched_insert_all=when_not_matched_insert_all,
)
assert_upsert_result(res, expected_updated, expected_inserted)


def test_merge_scenario_skip_upd_row(catalog: Catalog) -> None:
    """
    Tests a single insert and a single update; a row that does not need to be updated is skipped.
    """
identifier = "default.test_merge_scenario_skip_upd_row"
_drop_table(catalog, identifier)
ctx = SessionContext()
df = ctx.sql("""
select 1 as order_id, date '2021-01-01' as order_date, 'A' as order_type
union all
select 2 as order_id, date '2021-01-01' as order_date, 'A' as order_type
""").to_arrow_table()
table = catalog.create_table(identifier, df.schema)
table.append(df)
source_df = ctx.sql("""
select 1 as order_id, date '2021-01-01' as order_date, 'A' as order_type
union all
select 2 as order_id, date '2021-01-01' as order_date, 'B' as order_type
union all
select 3 as order_id, date '2021-01-01' as order_date, 'A' as order_type
""").to_arrow_table()
res = table.upsert(df=source_df, join_cols=["order_id"])
expected_updated = 1
expected_inserted = 1
assert_upsert_result(res, expected_updated, expected_inserted)


def test_merge_scenario_date_as_key(catalog: Catalog) -> None:
    """
    Tests a single insert and a single update where the primary key is a date column.
    """
ctx = SessionContext()
identifier = "default.test_merge_scenario_date_as_key"
_drop_table(catalog, identifier)
df = ctx.sql("""
select date '2021-01-01' as order_date, 'A' as order_type
union all
select date '2021-01-02' as order_date, 'A' as order_type
""").to_arrow_table()
table = catalog.create_table(identifier, df.schema)
table.append(df)
source_df = ctx.sql("""
select date '2021-01-01' as order_date, 'A' as order_type
union all
select date '2021-01-02' as order_date, 'B' as order_type
union all
select date '2021-01-03' as order_date, 'A' as order_type
""").to_arrow_table()
res = table.upsert(df=source_df, join_cols=["order_date"])
expected_updated = 1
expected_inserted = 1
assert_upsert_result(res, expected_updated, expected_inserted)


def test_merge_scenario_string_as_key(catalog: Catalog) -> None:
    """
    Tests a single insert and a single update where the primary key is a string column.
    """
identifier = "default.test_merge_scenario_string_as_key"
_drop_table(catalog, identifier)
ctx = SessionContext()
df = ctx.sql("""
select 'abc' as order_id, 'A' as order_type
union all
select 'def' as order_id, 'A' as order_type
""").to_arrow_table()
table = catalog.create_table(identifier, df.schema)
table.append(df)
source_df = ctx.sql("""
select 'abc' as order_id, 'A' as order_type
union all
select 'def' as order_id, 'B' as order_type
union all
select 'ghi' as order_id, 'A' as order_type
""").to_arrow_table()
res = table.upsert(df=source_df, join_cols=["order_id"])
expected_updated = 1
expected_inserted = 1
assert_upsert_result(res, expected_updated, expected_inserted)


def test_merge_scenario_composite_key(catalog: Catalog) -> None:
    """
    Tests merging 200 source rows into a 200-row target using a composite key; 100 rows are updated and 100 inserted.
    """
identifier = "default.test_merge_scenario_composite_key"
_drop_table(catalog, identifier)
ctx = SessionContext()
table = gen_target_iceberg_table(1, 200, True, ctx, catalog, identifier)
source_df = gen_source_dataset(101, 300, True, False, ctx)
res = table.upsert(df=source_df, join_cols=["order_id", "order_line_id"])
expected_updated = 100
expected_inserted = 100
assert_upsert_result(res, expected_updated, expected_inserted)


def test_merge_source_dups(catalog: Catalog) -> None:
    """
    Tests that duplicate rows in the source dataset abort the upsert.
    """
identifier = "default.test_merge_source_dups"
_drop_table(catalog, identifier)
ctx = SessionContext()
table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier)
source_df = gen_source_dataset(5, 15, False, True, ctx)
with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. No upsert executed"):
table.upsert(df=source_df, join_cols=["order_id"])


def test_key_cols_misaligned(catalog: Catalog) -> None:
    """
    Tests that the upsert fails when a join column is missing from the source dataset.
    """
identifier = "default.test_key_cols_misaligned"
_drop_table(catalog, identifier)
ctx = SessionContext()
df = ctx.sql("select 1 as order_id, date '2021-01-01' as order_date, 'A' as order_type").to_arrow_table()
table = catalog.create_table(identifier, df.schema)
table.append(df)
df_src = ctx.sql("select 1 as item_id, date '2021-05-01' as order_date, 'B' as order_type").to_arrow_table()
with pytest.raises(Exception, match=r"""Field ".*" does not exist in schema"""):
table.upsert(df=df_src, join_cols=["order_id"])


def test_upsert_with_identifier_fields(catalog: Catalog) -> None:
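    """Upsert without join_cols, relying on the schema's identifier field; re-running the same upsert is a no-op."""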
identifier = "default.test_upsert_with_identifier_fields"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "population", IntegerType(), required=True),
        # Mark city as the identifier field, also known as the primary key
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
arrow_schema = pa.schema(
[
pa.field("city", pa.string(), nullable=False),
pa.field("population", pa.int32(), nullable=False),
]
)
# Write some data
df = pa.Table.from_pylist(
[
{"city": "Amsterdam", "population": 921402},
{"city": "San Francisco", "population": 808988},
{"city": "Drachten", "population": 45019},
{"city": "Paris", "population": 2103000},
],
schema=arrow_schema,
)
tbl.append(df)
df = pa.Table.from_pylist(
[
            # Will be updated: the population differs from the value in the table
{"city": "Drachten", "population": 45505},
# New row, will be inserted
{"city": "Berlin", "population": 3432000},
# Ignored, already exists in the table
{"city": "Paris", "population": 2103000},
],
schema=arrow_schema,
)
upd = tbl.upsert(df)
expected_operations = [Operation.APPEND, Operation.OVERWRITE, Operation.APPEND, Operation.APPEND]
assert upd.rows_updated == 1
assert upd.rows_inserted == 1
assert [snap.summary.operation for snap in tbl.snapshots() if snap.summary is not None] == expected_operations
# This should be a no-op
upd = tbl.upsert(df)
assert upd.rows_updated == 0
assert upd.rows_inserted == 0
assert [snap.summary.operation for snap in tbl.snapshots() if snap.summary is not None] == expected_operations


def test_upsert_into_empty_table(catalog: Catalog) -> None:
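    """Upsert into an empty table: every source row should be inserted and none updated."""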
identifier = "default.test_upsert_into_empty_table"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "inhabitants", IntegerType(), required=True),
        # Mark city as the identifier field, also known as the primary key
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
arrow_schema = pa.schema(
[
pa.field("city", pa.string(), nullable=False),
pa.field("inhabitants", pa.int32(), nullable=False),
]
)
# Write some data
df = pa.Table.from_pylist(
[
{"city": "Amsterdam", "inhabitants": 921402},
{"city": "San Francisco", "inhabitants": 808988},
{"city": "Drachten", "inhabitants": 45019},
{"city": "Paris", "inhabitants": 2103000},
],
schema=arrow_schema,
)
upd = tbl.upsert(df)
assert upd.rows_updated == 0
assert upd.rows_inserted == 4


def test_create_match_filter_single_condition() -> None:
"""
Test create_match_filter with a composite key where the source yields exactly one unique key.
Expected: The function returns the single And condition directly.
"""
data = [
{"order_id": 101, "order_line_id": 1, "extra": "x"},
{"order_id": 101, "order_line_id": 1, "extra": "x"}, # duplicate
]
schema = pa.schema([pa.field("order_id", pa.int32()), pa.field("order_line_id", pa.int32()), pa.field("extra", pa.string())])
table = pa.Table.from_pylist(data, schema=schema)
expr = create_match_filter(table, ["order_id", "order_line_id"])
assert expr == And(
EqualTo(term=Reference(name="order_id"), literal=LongLiteral(101)),
EqualTo(term=Reference(name="order_line_id"), literal=LongLiteral(1)),
)


def test_upsert_with_duplicate_rows_in_table(catalog: Catalog) -> None:
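    """Upsert against a target table that contains duplicate keys: the operation should abort with a ValueError."""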
identifier = "default.test_upsert_with_duplicate_rows_in_table"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "inhabitants", IntegerType(), required=True),
        # Mark city as the identifier field, also known as the primary key
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
arrow_schema = pa.schema(
[
pa.field("city", pa.string(), nullable=False),
pa.field("inhabitants", pa.int32(), nullable=False),
]
)
# Write some data
df = pa.Table.from_pylist(
[
{"city": "Drachten", "inhabitants": 45019},
{"city": "Drachten", "inhabitants": 45019},
],
schema=arrow_schema,
)
tbl.append(df)
df = pa.Table.from_pylist(
[
            # Will be updated: the inhabitants value differs from the value in the table
{"city": "Drachten", "inhabitants": 45505},
],
schema=arrow_schema,
)
with pytest.raises(ValueError, match="Target table has duplicate rows, aborting upsert"):
_ = tbl.upsert(df)


def test_upsert_without_identifier_fields(catalog: Catalog) -> None:
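    """Upsert without join_cols on a table without identifier fields: the join columns cannot be resolved and a ValueError is raised."""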
identifier = "default.test_upsert_without_identifier_fields"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "population", IntegerType(), required=True),
# No identifier field :o
identifier_field_ids=[],
)
tbl = catalog.create_table(identifier, schema=schema)
# Write some data
df = pa.Table.from_pylist(
[
{"city": "Amsterdam", "population": 921402},
{"city": "San Francisco", "population": 808988},
{"city": "Drachten", "population": 45019},
{"city": "Paris", "population": 2103000},
],
schema=schema_to_pyarrow(schema),
)
with pytest.raises(
ValueError, match="Join columns could not be found, please set identifier-field-ids or pass in explicitly."
):
tbl.upsert(df)


def test_upsert_with_struct_field_as_non_join_key(catalog: Catalog) -> None:
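    """Upsert on a table with a struct column outside the join key: rows with changed structs are updated, identical rows are skipped."""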
identifier = "default.test_upsert_struct_field_fails"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "id", IntegerType(), required=True),
NestedField(
2,
"nested_type",
StructType(
NestedField(3, "sub1", StringType(), required=True),
NestedField(4, "sub2", StringType(), required=True),
),
required=False,
),
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
arrow_schema = pa.schema(
[
pa.field("id", pa.int32(), nullable=False),
pa.field(
"nested_type",
pa.struct(
[
pa.field("sub1", pa.large_string(), nullable=False),
pa.field("sub2", pa.large_string(), nullable=False),
]
),
nullable=True,
),
]
)
initial_data = pa.Table.from_pylist(
[
{
"id": 1,
"nested_type": {"sub1": "bla1", "sub2": "bla"},
}
],
schema=arrow_schema,
)
tbl.append(initial_data)
update_data = pa.Table.from_pylist(
[
{
"id": 2,
"nested_type": {"sub1": "bla1", "sub2": "bla"},
},
{
"id": 1,
"nested_type": {"sub1": "bla1", "sub2": "bla2"},
},
],
schema=arrow_schema,
)
res = tbl.upsert(update_data, join_cols=["id"])
expected_updated = 1
expected_inserted = 1
assert_upsert_result(res, expected_updated, expected_inserted)
update_data = pa.Table.from_pylist(
[
{
"id": 2,
"nested_type": {"sub1": "bla1", "sub2": "bla"},
},
{
"id": 1,
"nested_type": {"sub1": "bla1", "sub2": "bla2"},
},
],
schema=arrow_schema,
)
res = tbl.upsert(update_data, join_cols=["id"])
expected_updated = 0
expected_inserted = 0
assert_upsert_result(res, expected_updated, expected_inserted)


def test_upsert_with_struct_field_as_join_key(catalog: Catalog) -> None:
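    """Using a struct column as the join key is not supported and raises ArrowNotImplementedError."""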
identifier = "default.test_upsert_with_struct_field_as_join_key"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "id", IntegerType(), required=True),
NestedField(
2,
"nested_type",
StructType(
NestedField(3, "sub1", StringType(), required=True),
NestedField(4, "sub2", StringType(), required=True),
),
required=False,
),
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
arrow_schema = pa.schema(
[
pa.field("id", pa.int32(), nullable=False),
pa.field(
"nested_type",
pa.struct(
[
pa.field("sub1", pa.large_string(), nullable=False),
pa.field("sub2", pa.large_string(), nullable=False),
]
),
nullable=True,
),
]
)
initial_data = pa.Table.from_pylist(
[
{
"id": 1,
"nested_type": {"sub1": "bla1", "sub2": "bla"},
}
],
schema=arrow_schema,
)
tbl.append(initial_data)
update_data = pa.Table.from_pylist(
[
{
"id": 2,
"nested_type": {"sub1": "bla1", "sub2": "bla"},
},
{
"id": 1,
"nested_type": {"sub1": "bla1", "sub2": "bla"},
},
],
schema=arrow_schema,
)
with pytest.raises(
pa.lib.ArrowNotImplementedError, match="Keys of type struct<sub1: large_string not null, sub2: large_string not null>"
):
_ = tbl.upsert(update_data, join_cols=["nested_type"])


def test_upsert_with_nulls(catalog: Catalog) -> None:
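    """Upsert a non-null value over a row that currently holds a null, leaving the other rows untouched."""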
identifier = "default.test_upsert_with_nulls"
_drop_table(catalog, identifier)
schema = pa.schema(
[
("foo", pa.string()),
("bar", pa.int32()),
("baz", pa.bool_()),
]
)
# create table with null value
table = catalog.create_table(identifier, schema)
data_with_null = pa.Table.from_pylist(
[
{"foo": "apple", "bar": None, "baz": False},
{"foo": "banana", "bar": None, "baz": False},
],
schema=schema,
)
table.append(data_with_null)
assert table.scan().to_arrow()["bar"].is_null()
# upsert table with non-null value
data_without_null = pa.Table.from_pylist(
[
{"foo": "apple", "bar": 7, "baz": False},
],
schema=schema,
)
upd = table.upsert(data_without_null, join_cols=["foo"])
assert upd.rows_updated == 1
assert upd.rows_inserted == 0
assert table.scan().to_arrow() == pa.Table.from_pylist(
[
{"foo": "apple", "bar": 7, "baz": False},
{"foo": "banana", "bar": None, "baz": False},
],
schema=schema,
)


def test_transaction(catalog: Catalog) -> None:
"""Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is
rolled back."""
identifier = "default.test_merge_source_dups"
_drop_table(catalog, identifier)
ctx = SessionContext()
table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier)
df_before_transaction = table.scan().to_arrow()
source_df = gen_source_dataset(5, 15, False, True, ctx)
with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. No upsert executed"):
with table.transaction() as tx:
tx.delete(delete_filter=AlwaysTrue())
tx.upsert(df=source_df, join_cols=["order_id"])
df = table.scan().to_arrow()
assert df_before_transaction == df


def test_transaction_multiple_upserts(catalog: Catalog) -> None:
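    """Run a delete, an append, and an upsert inside a single transaction; the upsert must see the uncommitted changes."""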
identifier = "default.test_multi_upsert"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "id", IntegerType(), required=True),
NestedField(2, "name", StringType(), required=True),
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
# Define exact schema: required int32 and required string
arrow_schema = pa.schema(
[
pa.field("id", pa.int32(), nullable=False),
pa.field("name", pa.string(), nullable=False),
]
)
tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema))
df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema)
with tbl.transaction() as txn:
txn.delete(delete_filter="id = 1")
txn.append(df)
# This should read the uncommitted changes
txn.upsert(df, join_cols=["id"])
result = tbl.scan().to_arrow().to_pylist()
assert sorted(result, key=lambda x: x["id"]) == [
{"id": 1, "name": "Alicia"},
{"id": 2, "name": "Bob"},
]


def test_stage_only_upsert(catalog: Catalog) -> None:
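    """Upsert with branch=None stages a snapshot without advancing the main branch; a later append still builds on the original snapshot."""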
identifier = "default.test_stage_only_dynamic_partition_overwrite_files"
_drop_table(catalog, identifier)
schema = Schema(
NestedField(1, "city", StringType(), required=True),
NestedField(2, "inhabitants", IntegerType(), required=True),
        # Mark city as the identifier field, also known as the primary key
identifier_field_ids=[1],
)
tbl = catalog.create_table(identifier, schema=schema)
arrow_schema = pa.schema(
[
pa.field("city", pa.string(), nullable=False),
pa.field("inhabitants", pa.int32(), nullable=False),
]
)
# Write some data
df = pa.Table.from_pylist(
[
{"city": "Amsterdam", "inhabitants": 921402},
{"city": "San Francisco", "inhabitants": 808988},
{"city": "Drachten", "inhabitants": 45019},
{"city": "Paris", "inhabitants": 2103000},
],
schema=arrow_schema,
)
tbl.append(df.slice(0, 1))
current_snapshot = tbl.metadata.current_snapshot_id
assert current_snapshot is not None
original_count = len(tbl.scan().to_arrow())
assert original_count == 1
    # Write to a staged snapshot (branch=None); the main branch should not advance
upd = tbl.upsert(df, branch=None)
assert upd.rows_updated == 0
assert upd.rows_inserted == 3
assert current_snapshot == tbl.metadata.current_snapshot_id
assert len(tbl.scan().to_arrow()) == original_count
snapshots = tbl.snapshots()
assert len(snapshots) == 2
# Write to main ref
tbl.append(df.slice(1, 1))
# Main ref has changed
assert current_snapshot != tbl.metadata.current_snapshot_id
assert len(tbl.scan().to_arrow()) == 2
snapshots = tbl.snapshots()
assert len(snapshots) == 3
sorted_snapshots = sorted(tbl.snapshots(), key=lambda s: s.timestamp_ms)
operations = [snapshot.summary.operation.value if snapshot.summary else None for snapshot in sorted_snapshots]
parent_snapshot_id = [snapshot.parent_snapshot_id for snapshot in sorted_snapshots]
assert operations == ["append", "append", "append"]
    # both subsequent snapshots should have the first snapshot as their parent
assert parent_snapshot_id == [None, current_snapshot, current_snapshot]