blob: 4782b1dc9342d1e39dc55025cbac6ac01e231308 [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
import shutil
import tempfile
import unittest
import uuid
from unittest.mock import Mock, patch
import pyarrow as pa
import ray
from pypaimon import CatalogFactory, Schema
from pypaimon.ray import (
WhenMatched, WhenNotMatched, merge_into, read_paimon,
source_col, target_col, lit,
)
try:
import datafusion # noqa: F401
_HAS_DATAFUSION = True
except ImportError:
_HAS_DATAFUSION = False
_SKIP_CONDITION = not _HAS_DATAFUSION
_SKIP_REASON = "datafusion not installed"
_TEST_NUM_PARTITIONS = 2
class RayDataEvolutionMergeIntoTest(unittest.TestCase):
pa_schema = pa.schema([
('id', pa.int32()),
('name', pa.string()),
('age', pa.int32()),
])
de_options = {
'row-tracking.enabled': 'true',
'data-evolution.enabled': 'true',
}
@classmethod
def setUpClass(cls):
cls.tempdir = tempfile.mkdtemp()
cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
cls.catalog_options = {'warehouse': cls.warehouse}
cls.catalog = CatalogFactory.create(cls.catalog_options)
cls.catalog.create_database('default', True)
if not ray.is_initialized():
ray.init(ignore_reinit_error=True, num_cpus=2)
@classmethod
def tearDownClass(cls):
try:
if ray.is_initialized():
ray.shutdown()
except Exception:
pass
shutil.rmtree(cls.tempdir, ignore_errors=True)
def _create_table(self, options=None):
opts = options if options is not None else self.de_options
name = f'default.tbl_{uuid.uuid4().hex[:8]}'
s = Schema.from_pyarrow_schema(self.pa_schema, options=opts)
self.catalog.create_table(name, s, False)
return name
def _source(self, ids=(1,)):
return pa.Table.from_pydict(
{
'id': pa.array(list(ids), type=pa.int32()),
'name': ['x'] * len(ids),
'age': [10] * len(ids),
},
schema=self.pa_schema,
)
def _write(self, target, data):
table = self.catalog.get_table(target)
wb = table.new_batch_write_builder()
writer = wb.new_write()
writer.write_arrow(data)
wb.new_commit().commit(writer.prepare_commit())
writer.close()
def _read_sorted(self, target):
table = self.catalog.get_table(target)
rb = table.new_read_builder()
splits = rb.new_scan().plan().splits()
return rb.new_read().to_arrow(splits).sort_by('id').to_pydict()
def _snapshot_id(self, target):
table = self.catalog.get_table(target)
snap = table.snapshot_manager().get_latest_snapshot()
return snap.id if snap is not None else None
def test_paimon_source_table_pins_snapshot(self):
from pypaimon.ray import data_evolution_merge_into as m
target = self._create_table()
source = self._create_table()
self._write(source, self._source(ids=(1,)))
expected_snapshot_id = self._snapshot_id(source)
fake_ds = Mock()
fake_ds.schema.return_value = pa.schema([
('id', pa.int32()),
('name', pa.string()),
('age', pa.int32()),
])
with patch(
'pypaimon.ray.ray_paimon.read_paimon',
return_value=fake_ds,
) as mock_read_paimon:
m._prepare(
target, source, self.catalog_options,
[WhenMatched(update='*')], [], ['id'],
)
mock_read_paimon.assert_called_once_with(
source,
self.catalog_options,
snapshot_id=expected_snapshot_id,
)
def test_no_clause_raises(self):
target = self._create_table()
with self.assertRaises(ValueError):
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
num_partitions=_TEST_NUM_PARTITIONS,
)
def test_unconditional_non_last_matched_rejected(self):
target = self._create_table()
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*'),
WhenMatched(update={'age': 's.age'}, condition='s.age > 10'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('when_matched', str(ctx.exception))
self.assertIn('unreachable', str(ctx.exception))
def test_unconditional_non_last_not_matched_rejected(self):
target = self._create_table()
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert='*'),
WhenNotMatched(insert='*', condition='s.age > 10'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('when_not_matched', str(ctx.exception))
self.assertIn('unreachable', str(ctx.exception))
def test_non_de_table_rejected(self):
target = self._create_table(options={'row-tracking.enabled': 'true'})
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('data-evolution.enabled', str(ctx.exception))
def test_no_row_tracking_rejected(self):
target = self._create_table(options={'data-evolution.enabled': 'true'})
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('row-tracking.enabled', str(ctx.exception))
def test_source_missing_on_col_raises(self):
target = self._create_table()
bad_source = pa.Table.from_pydict(
{'name': ['x'], 'age': [10]},
schema=pa.schema([('name', pa.string()), ('age', pa.int32())]),
)
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=bad_source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn("'id'", str(ctx.exception))
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_not_matched_condition_rejects_target_refs(self):
target = self._create_table()
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert='*', condition='t.age > 10')
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('t.', str(ctx.exception))
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_condition_unknown_source_col_rejected(self):
target = self._create_table()
self._write(target, self._source())
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.nonexistent > 0')
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('nonexistent', str(ctx.exception))
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_condition_unknown_target_col_rejected(self):
target = self._create_table()
self._write(target, self._source())
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > t.nonexistent')
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('nonexistent', str(ctx.exception))
def test_matched_update_star(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([2, 3, 4], type=pa.int32()),
'name': ['b2', 'c2', 'd'],
'age': pa.array([22, 33, 40], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b2', 'c2'])
self.assertEqual(out['age'], [10, 22, 33])
def test_not_matched_insert_appends_unmatched(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([2, 3, 4], type=pa.int32()),
'name': ['b2', 'c2', 'd'],
'age': pa.array([22, 33, 40], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[WhenNotMatched(insert='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3, 4])
self.assertEqual(out['name'], ['a', 'b', 'c', 'd'])
self.assertEqual(out['age'], [10, 20, 30, 40])
def test_combined_update_and_insert(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([2, 3], type=pa.int32()),
'name': ['b2', 'c'],
'age': pa.array([22, 30], type=pa.int32()),
},
schema=self.pa_schema,
)
metrics = merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
when_not_matched=[WhenNotMatched(insert='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b2', 'c'])
self.assertEqual(out['age'], [10, 22, 30])
self.assertEqual(metrics, {
'num_matched': 1, 'num_inserted': 1, 'num_unchanged': 0,
})
def test_on_with_renamed_columns_star(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source_schema = pa.schema([
('uid', pa.int32()),
('name', pa.string()),
('age', pa.int32()),
])
source = pa.Table.from_pydict(
{
'uid': pa.array([2, 3], type=pa.int32()),
'name': ['b2', 'c'],
'age': pa.array([22, 30], type=pa.int32()),
},
schema=source_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on={'id': 'uid'},
when_matched=[WhenMatched(update='*')],
when_not_matched=[WhenNotMatched(insert='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b2', 'c'])
self.assertEqual(out['age'], [10, 22, 30])
def test_insert_into_empty_target(self):
target = self._create_table()
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[WhenNotMatched(insert='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b', 'c'])
self.assertEqual(out['age'], [10, 20, 30])
def test_multi_source_match_raises_by_default(self):
# One target row matched by several source rows: the winning value is
# undefined (Spark DE's checkCardinality=false), so we refuse by default.
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['a'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 1], type=pa.int32()),
'name': ['x', 'y'],
'age': pa.array([100, 200], type=pa.int32()),
},
schema=self.pa_schema,
)
with self.assertRaises(Exception) as ctx:
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn("multiple source rows", str(ctx.exception))
def test_blob_columns_excluded(self):
import types
from pypaimon.ray.data_evolution_merge_into import _blob_col_names
from pypaimon.schema.data_types import AtomicType, DataField
fake_table = types.SimpleNamespace(
table_schema=types.SimpleNamespace(
fields=[
DataField(0, 'id', AtomicType('INT')),
DataField(1, 'payload', AtomicType('BLOB')),
]
)
)
self.assertEqual({'payload'}, _blob_col_names(fake_table))
def test_blob_table_feature_update(self):
blob_schema = pa.schema([
('id', pa.int32()),
('payload', pa.large_binary()),
('feature', pa.int32()),
])
name = f'default.tbl_{uuid.uuid4().hex[:8]}'
schema = Schema.from_pyarrow_schema(
blob_schema, options=self.de_options)
self.catalog.create_table(name, schema, False)
self._write(
name,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'payload': [b'aa', b'bbb', b'cccc'],
'feature': pa.array([10, 20, 30], type=pa.int32()),
},
schema=blob_schema,
),
)
num_partitions = _TEST_NUM_PARTITIONS
records_to_process = ray.data.from_arrow(pa.Table.from_pydict({
'id': pa.array([1, 3], type=pa.int32()),
}))
target_rows = read_paimon(
name,
self.catalog_options,
projection=['id', 'payload'],
)
selected = records_to_process.join(
target_rows,
join_type='inner',
num_partitions=num_partitions,
on=['id'],
)
def compute_feature(batch):
payloads = batch['payload'].to_pylist()
return pa.Table.from_pydict({
'id': batch['id'],
'new_feature': pa.array(
[len(v) if v is not None else 0 for v in payloads],
type=pa.int32(),
),
})
updates = selected.map_batches(compute_feature, batch_format='pyarrow')
metrics = merge_into(
target=name,
source=updates,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update={'feature': source_col('new_feature')})
],
num_partitions=num_partitions,
)
table = self.catalog.get_table(name)
rb = table.new_read_builder()
splits = rb.new_scan().plan().splits()
out = rb.new_read().to_arrow(splits).sort_by('id').to_pydict()
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['feature'], [2, 20, 4])
self.assertEqual(out['payload'], [b'aa', b'bbb', b'cccc'])
self.assertEqual(metrics, {
'num_matched': 2, 'num_inserted': 0, 'num_unchanged': 0,
})
def test_blob_descriptor_resolve_and_merge(self):
from pypaimon.table.row.blob import BlobDescriptor, Blob
from pypaimon.common.uri_reader import UriReaderFactory
blob_schema = pa.schema([
('id', pa.int32()),
('payload', pa.large_binary()),
('feature', pa.int32()),
])
name = f'default.tbl_{uuid.uuid4().hex[:8]}'
schema = Schema.from_pyarrow_schema(
blob_schema, options=self.de_options)
self.catalog.create_table(name, schema, False)
self._write(
name,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'payload': [b'aa', b'bbb', b'cccc'],
'feature': pa.array([10, 20, 30], type=pa.int32()),
},
schema=blob_schema,
),
)
num_partitions = _TEST_NUM_PARTITIONS
input_ids = ray.data.from_arrow(pa.Table.from_pydict({
'id': pa.array([1, 3], type=pa.int32()),
}))
target_rows = read_paimon(
name,
self.catalog_options,
projection=['id', 'payload'],
dynamic_options={'blob-as-descriptor': 'true'},
)
matched = input_ids.join(
target_rows, join_type='inner',
num_partitions=num_partitions, on=['id'],
)
uri_factory = UriReaderFactory(self.catalog_options)
def resolve_and_compute(batch):
features = []
for desc_bytes in batch['payload'].to_pylist():
desc = BlobDescriptor.deserialize(desc_bytes)
reader = uri_factory.create(desc.uri)
data = Blob.from_descriptor(reader, desc).to_data()
features.append(len(data) * 100)
return pa.Table.from_pydict({
'id': batch['id'],
'new_feature': pa.array(features, type=pa.int32()),
})
updates = matched.map_batches(
resolve_and_compute, batch_format='pyarrow')
metrics = merge_into(
target=name,
source=updates,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update={'feature': source_col('new_feature')})
],
num_partitions=num_partitions,
)
table = self.catalog.get_table(name)
rb = table.new_read_builder()
splits = rb.new_scan().plan().splits()
out = rb.new_read().to_arrow(splits).sort_by('id').to_pydict()
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['feature'], [200, 20, 400])
self.assertEqual(out['payload'], [b'aa', b'bbb', b'cccc'])
self.assertEqual(metrics['num_matched'], 2)
def test_combined_writes_single_snapshot(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
before = self._snapshot_id(target)
source = pa.Table.from_pydict(
{
'id': pa.array([2, 3], type=pa.int32()),
'name': ['b2', 'c'],
'age': pa.array([22, 30], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
when_not_matched=[WhenNotMatched(insert='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
after = self._snapshot_id(target)
self.assertEqual(after, before + 1)
def test_empty_target_matched_update_is_noop(self):
target = self._create_table()
before = self._snapshot_id(target)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertEqual(self._snapshot_id(target), before)
def test_matched_on_partitioned_table(self):
pt_schema = pa.schema([
('pt', pa.string()),
('id', pa.int32()),
('name', pa.string()),
])
name = f'default.tbl_{uuid.uuid4().hex[:8]}'
s = Schema.from_pyarrow_schema(
pt_schema, partition_keys=['pt'], options=self.de_options,
)
self.catalog.create_table(name, s, False)
table = self.catalog.get_table(name)
wb = table.new_batch_write_builder()
writer = wb.new_write()
writer.write_arrow(pa.Table.from_pydict(
{
'pt': ['a', 'a'],
'id': pa.array([1, 2], type=pa.int32()),
'name': ['old_1', 'old_2'],
},
schema=pt_schema,
))
wb.new_commit().commit(writer.prepare_commit())
writer.close()
source = pa.Table.from_pydict(
{
'pt': ['a'],
'id': pa.array([1], type=pa.int32()),
'name': ['new_1'],
},
schema=pt_schema,
)
# Non-partition column update should succeed
merge_into(
target=name,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'name': source_col('name')})],
num_partitions=_TEST_NUM_PARTITIONS,
)
rb = table.new_read_builder()
splits = rb.new_scan().plan().splits()
out = rb.new_read().to_arrow(splits).sort_by('id').to_pydict()
self.assertEqual(out['name'], ['new_1', 'old_2'])
self.assertEqual(out['pt'], ['a', 'a'])
# Partition column update should be rejected
with self.assertRaises(ValueError) as ctx:
merge_into(
target=name,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'pt': source_col('pt')})],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('partition', str(ctx.exception))
def test_partitioned_insert_allowed(self):
pt_schema = pa.schema([
('pt', pa.string()),
('id', pa.int32()),
('name', pa.string()),
])
name = f'default.tbl_{uuid.uuid4().hex[:8]}'
s = Schema.from_pyarrow_schema(
pt_schema, partition_keys=['pt'], options=self.de_options,
)
self.catalog.create_table(name, s, False)
source = pa.Table.from_pydict(
{
'pt': ['a', 'b'],
'id': pa.array([1, 2], type=pa.int32()),
'name': ['x', 'y'],
},
schema=pt_schema,
)
merge_into(
target=name,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[WhenNotMatched(insert='*')],
num_partitions=_TEST_NUM_PARTITIONS,
)
table = self.catalog.get_table(name)
rb = table.new_read_builder()
splits = rb.new_scan().plan().splits()
out = rb.new_read().to_arrow(splits).sort_by('id').to_pydict()
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['pt'], ['a', 'b'])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_matched_update_with_condition(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a2', 'b2', 'c2'],
'age': pa.array([15, 25, 45], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*', condition='s.age > t.age + 10')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b', 'c2'])
self.assertEqual(out['age'], [10, 20, 45])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_matched_condition_with_source_on_key(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a2', 'b2', 'c2'],
'age': pa.array([15, 25, 35], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*', condition='s.id >= 2')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b2', 'c2'])
self.assertEqual(out['age'], [10, 25, 35])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_not_matched_insert_with_condition(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['a'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([2, 3, 4], type=pa.int32()),
'name': ['b', 'c', 'd'],
'age': pa.array([15, 25, 5], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert='*', condition='s.age >= 10')
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b', 'c'])
self.assertEqual(out['age'], [10, 15, 25])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_combined_with_conditions(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3, 4], type=pa.int32()),
'name': ['a2', 'b2', 'c', 'd'],
'age': pa.array([50, 5, 30, 8], type=pa.int32()),
},
schema=self.pa_schema,
)
metrics = merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*', condition='s.age > t.age')],
when_not_matched=[
WhenNotMatched(insert='*', condition='s.age > 10')
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a2', 'b', 'c'])
self.assertEqual(out['age'], [50, 20, 30])
self.assertEqual(metrics['num_matched'], 1)
self.assertEqual(metrics['num_inserted'], 1)
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_condition_no_rows_match_is_noop(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a2', 'b2'],
'age': pa.array([5, 5], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update='*', condition='s.age > t.age')],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['name'], ['a', 'b'])
self.assertEqual(out['age'], [10, 20])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_duplicate_source_filtered_by_condition(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['a'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 1], type=pa.int32()),
'name': ['x', 'y'],
'age': pa.array([5, 20], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > t.age')
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1])
self.assertEqual(out['name'], ['y'])
self.assertEqual(out['age'], [20])
def test_matched_partial_update(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a2', 'b2'],
'age': pa.array([99, 88], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'age': 's.age'})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['name'], ['a', 'b'])
self.assertEqual(out['age'], [99, 88])
def test_insert_partial_mapping(self):
target = self._create_table()
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert={'id': 's.id', 'name': 's.name'})
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['name'], ['a', 'b'])
self.assertEqual(out['age'], [None, None])
def test_update_with_literal(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['old'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['ignored'],
'age': pa.array([99], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'name': 'updated'})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['updated'])
self.assertEqual(out['age'], [10])
def test_invalid_target_column_rejected(self):
target = self._create_table()
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'nonexistent': 's.id'})],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('nonexistent', str(ctx.exception))
def test_invalid_target_ref_rejected(self):
target = self._create_table()
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'name': 't.nme'})],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('nme', str(ctx.exception))
def test_empty_mapping_rejected(self):
target = self._create_table()
with self.assertRaises(ValueError):
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={})],
num_partitions=_TEST_NUM_PARTITIONS,
)
def test_insert_target_ref_rejected(self):
target = self._create_table()
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert={'name': 't.name'})
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('t.', str(ctx.exception))
def test_matched_update_with_target_ref(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['old'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['ignored'],
'age': pa.array([99], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'age': 's.age', 'name': 't.name'})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['old'])
self.assertEqual(out['age'], [99])
def test_callable_value_rejected(self):
target = self._create_table()
with self.assertRaises(TypeError):
merge_into(
target=target,
source=self._source(),
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'name': lambda r: r})],
num_partitions=_TEST_NUM_PARTITIONS,
)
def test_source_missing_referenced_col(self):
target = self._create_table()
source = pa.Table.from_pydict(
{'id': pa.array([1], type=pa.int32())},
schema=pa.schema([('id', pa.int32())]),
)
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={'name': 's.name'})],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('name', str(ctx.exception))
def test_partial_insert_auto_fills_on_key(self):
target = self._create_table()
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert={'name': 's.name'})
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['name'], ['a', 'b'])
def test_partial_insert_renamed_on_key_auto_filled(self):
target = self._create_table()
source_schema = pa.schema([
('uid', pa.int32()),
('name', pa.string()),
('age', pa.int32()),
])
source = pa.Table.from_pydict(
{
'uid': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=source_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on={'id': 'uid'},
when_not_matched=[
WhenNotMatched(insert={'name': 's.name'})
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['name'], ['a', 'b'])
def test_explicit_source_ref_not_remapped_by_on_key(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['old'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source_schema = pa.schema([
('uid', pa.int32()),
('id', pa.int32()),
('name', pa.string()),
('age', pa.int32()),
])
source = pa.Table.from_pydict(
{
'uid': pa.array([1], type=pa.int32()),
'id': pa.array([42], type=pa.int32()),
'name': ['new'],
'age': pa.array([99], type=pa.int32()),
},
schema=source_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on={'id': 'uid'},
when_matched=[WhenMatched(update={
'age': source_col('id'),
})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['age'], [42])
self.assertEqual(out['name'], ['old'])
def test_renamed_on_key_missing_source_col_rejected(self):
target = self._create_table()
source_schema = pa.schema([
('uid', pa.int32()),
('name', pa.string()),
('age', pa.int32()),
])
source = pa.Table.from_pydict(
{
'uid': pa.array([1], type=pa.int32()),
'name': ['a'],
'age': pa.array([10], type=pa.int32()),
},
schema=source_schema,
)
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on={'id': 'uid'},
when_matched=[WhenMatched(update={
'id': source_col('id'),
})],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('id', str(ctx.exception))
def test_lit_prevents_column_ref_interpretation(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['old'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['ignored'],
'age': pa.array([99], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={
'name': lit('s.active'),
})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['s.active'])
self.assertEqual(out['age'], [10])
def test_source_col_helper(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['old'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['new'],
'age': pa.array([99], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={
'age': source_col('age'),
})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['old'])
self.assertEqual(out['age'], [99])
def test_target_col_helper(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['keep'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['ignored'],
'age': pa.array([99], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[WhenMatched(update={
'age': source_col('age'),
'name': target_col('name'),
})],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['keep'])
self.assertEqual(out['age'], [99])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_matched_clause_fall_through(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a2', 'b2', 'c2'],
'age': pa.array([99, 88, 77], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > 80'),
WhenMatched(update='*'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a2', 'b2', 'c2'])
self.assertEqual(out['age'], [99, 88, 77])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_not_matched_clause_fall_through(self):
target = self._create_table()
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([25, 15, 5], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert='*', condition='s.age >= 20'),
WhenNotMatched(insert='*'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_matched_null_falls_through(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a2', 'b2', 'c2'],
'age': pa.array([None, 50, 60], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > 40'),
WhenMatched(update='*'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a2', 'b2', 'c2'])
self.assertEqual(out['age'], [None, 50, 60])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_not_matched_null_falls_through(self):
target = self._create_table()
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([None, 25], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_not_matched=[
WhenNotMatched(insert='*', condition='s.age > 20'),
WhenNotMatched(insert='*'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2])
self.assertEqual(out['age'], [None, 25])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_clause_no_match_skipped(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a2', 'b2'],
'age': pa.array([5, 5], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > 50'),
WhenMatched(update='*', condition='s.age > 30'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['a', 'b'])
self.assertEqual(out['age'], [10, 20])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_clause_first_wins(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['old'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['first'],
'age': pa.array([99], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update={'name': 's.name'},
condition='s.age > 50'),
WhenMatched(update={'age': 's.age'},
condition='s.age > 10'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['first'])
self.assertEqual(out['age'], [10])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_clause_duplicate_source_one_actionable(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['a'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 1], type=pa.int32()),
'name': ['x', 'y'],
'age': pa.array([99, 5], type=pa.int32()),
},
schema=self.pa_schema,
)
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > 50'),
WhenMatched(update='*', condition='s.age > 80'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['x'])
self.assertEqual(out['age'], [99])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_multi_clause_duplicate_both_actionable_raises(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['a'],
'age': pa.array([10], type=pa.int32()),
},
schema=self.pa_schema,
),
)
source = pa.Table.from_pydict(
{
'id': pa.array([1, 1], type=pa.int32()),
'name': ['x', 'y'],
'age': pa.array([99, 50], type=pa.int32()),
},
schema=self.pa_schema,
)
with self.assertRaises(Exception) as ctx:
merge_into(
target=target,
source=source,
catalog_options=self.catalog_options,
on=['id'],
when_matched=[
WhenMatched(update='*', condition='s.age > 80'),
WhenMatched(update='*', condition='s.age > 30'),
],
num_partitions=_TEST_NUM_PARTITIONS,
)
self.assertIn('multiple source rows', str(ctx.exception))
def test_self_merge_update_literal(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(update={'age': lit(99)})],
)
self.assertEqual(result['num_matched'], 3)
out = self._read_sorted(target)
self.assertEqual(out['age'], [99, 99, 99])
self.assertEqual(out['name'], ['a', 'b', 'c'])
def test_self_merge_update_star(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(update='*')],
)
self.assertEqual(result['num_matched'], 3)
out = self._read_sorted(target)
self.assertEqual(out['id'], [1, 2, 3])
self.assertEqual(out['name'], ['a', 'b', 'c'])
self.assertEqual(out['age'], [10, 20, 30])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_with_condition(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(update={'age': lit(99)}, condition='t.age > 15')],
)
self.assertEqual(result['num_matched'], 2)
out = self._read_sorted(target)
self.assertEqual(out['age'], [10, 99, 99])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_with_source_condition(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(
update={'name': lit('updated')},
condition='s.age > 15',
)],
)
self.assertEqual(result['num_matched'], 2)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['a', 'updated', 'updated'])
self.assertEqual(out['age'], [10, 20, 30])
def test_self_merge_rejects_not_matched(self):
target = self._create_table()
self._write(target, self._source(ids=(1,)))
with self.assertRaises(ValueError) as ctx:
merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(update='*')],
when_not_matched=[WhenNotMatched(insert='*')],
)
self.assertIn('Self-merge', str(ctx.exception))
def test_self_merge_partial_set(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['old_a', 'old_b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(update={'name': lit('updated')})],
)
self.assertEqual(result['num_matched'], 2)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['updated', 'updated'])
self.assertEqual(out['age'], [10, 20])
def test_self_merge_source_col_row_id(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'age': pa.array([10, 20], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[WhenMatched(update={'name': source_col('_ROW_ID')})],
)
self.assertEqual(result['num_matched'], 2)
out = self._read_sorted(target)
for v in out['name']:
self.assertTrue(int(v) >= 0)
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_condition_on_row_id(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[
WhenMatched(
update={'age': lit(99)},
condition='s._ROW_ID >= 0',
),
],
)
self.assertEqual(result['num_matched'], 3)
out = self._read_sorted(target)
self.assertEqual(out['age'], [99, 99, 99])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_condition_on_target_row_id(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[
WhenMatched(
update={'age': lit(99)},
condition='t._ROW_ID >= 0',
),
],
)
self.assertEqual(result['num_matched'], 3)
out = self._read_sorted(target)
self.assertEqual(out['age'], [99, 99, 99])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_multi_clause_fall_through(self):
target = self._create_table()
self._write(
target,
pa.Table.from_pydict(
{
'id': pa.array([1, 2, 3], type=pa.int32()),
'name': ['a', 'b', 'c'],
'age': pa.array([10, 20, 30], type=pa.int32()),
},
schema=self.pa_schema,
),
)
result = merge_into(
target=target,
source=target,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[
WhenMatched(update={'name': lit('old')}, condition='s.age <= 10'),
WhenMatched(update={'name': lit('young')}, condition='s.age <= 20'),
WhenMatched(update={'name': lit('senior')}),
],
)
self.assertEqual(result['num_matched'], 3)
out = self._read_sorted(target)
self.assertEqual(out['name'], ['old', 'young', 'senior'])
self.assertEqual(out['age'], [10, 20, 30])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_blob_source_condition(self):
blob_schema = pa.schema([
('id', pa.int32()),
('name', pa.string()),
('picture', pa.large_binary()),
])
tbl_name = f'default.tbl_{uuid.uuid4().hex[:8]}'
s = Schema.from_pyarrow_schema(blob_schema, options=self.de_options)
self.catalog.create_table(tbl_name, s, False)
self._write(
tbl_name,
pa.Table.from_pydict(
{
'id': pa.array([1, 2], type=pa.int32()),
'name': ['a', 'b'],
'picture': [None, None],
},
schema=blob_schema,
),
)
result = merge_into(
target=tbl_name,
source=tbl_name,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[
WhenMatched(
update={'name': lit('updated')},
condition='s.picture IS NULL',
),
],
)
self.assertEqual(result['num_matched'], 2)
out = self._read_sorted(tbl_name)
self.assertEqual(out['name'], ['updated', 'updated'])
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_self_merge_blob_target_condition_rejected(self):
blob_schema = pa.schema([
('id', pa.int32()),
('name', pa.string()),
('picture', pa.large_binary()),
])
tbl_name = f'default.tbl_{uuid.uuid4().hex[:8]}'
s = Schema.from_pyarrow_schema(blob_schema, options=self.de_options)
self.catalog.create_table(tbl_name, s, False)
self._write(
tbl_name,
pa.Table.from_pydict(
{
'id': pa.array([1], type=pa.int32()),
'name': ['a'],
'picture': [None],
},
schema=blob_schema,
),
)
with self.assertRaises(ValueError) as ctx:
merge_into(
target=tbl_name,
source=tbl_name,
catalog_options=self.catalog_options,
on=['_ROW_ID'],
when_matched=[
WhenMatched(
update={'name': lit('x')},
condition='t.picture IS NOT NULL',
),
],
)
self.assertIn('blob', str(ctx.exception).lower())
class TargetProjectionTest(unittest.TestCase):
def _clause(self, spec, condition=None):
from pypaimon.ray import data_evolution_merge_into as m
return m._NormalizedClause(spec=spec, condition=condition)
def test_unconditional_set_excludes_target_update_col(self):
from pypaimon.ray import data_evolution_merge_into as m
cols = m._resolve_target_projection(
[self._clause({'feature': 's.feature'})],
['id'], ['feature'], ['id', 'feature', 'image'],
)
self.assertEqual(['id'], cols)
def test_condition_adds_referenced_target_cols(self):
from pypaimon.ray import data_evolution_merge_into as m
cols = m._resolve_target_projection(
[self._clause({'feature': 's.feature'}, condition='s.age > t.age')],
['id'], ['feature'], ['id', 'feature', 'age', 'image'],
)
self.assertIn('age', cols)
self.assertIn('id', cols)
class MergeConditionUnitTest(unittest.TestCase):
def test_rewrite_condition(self):
from pypaimon.ray.merge_condition import rewrite_condition
self.assertEqual(
rewrite_condition('s.age > t.age + 10'),
'"s.age" > "t.age" + 10',
)
def test_rewrite_condition_preserves_string_literals(self):
from pypaimon.ray.merge_condition import rewrite_condition
self.assertEqual(
rewrite_condition("s.status = 't.active' AND s.age > t.age"),
'"s.status" = \'t.active\' AND "s.age" > "t.age"',
)
def test_remap_source_on_keys(self):
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = rewrite_condition('s.id > 1 AND s.age > t.age')
remapped = remap_source_on_keys(rewritten, {'id': 'id'})
self.assertEqual(remapped, '"t.id" > 1 AND "s.age" > "t.age"')
def test_remap_source_on_keys_renamed(self):
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = rewrite_condition('s.uid > 1')
remapped = remap_source_on_keys(rewritten, {'uid': 'id'})
self.assertEqual(remapped, '"t.id" > 1')
def test_remap_preserves_string_literals(self):
from pypaimon.ray.merge_condition import (
remap_source_on_keys, rewrite_condition,
)
rewritten = rewrite_condition("s.note = '\"s.id\"' AND s.id = 1")
remapped = remap_source_on_keys(rewritten, {'id': 'id'})
self.assertEqual(
remapped,
'"s.note" = \'\"s.id\"\' AND "t.id" = 1',
)
def test_extract_target_columns(self):
from pypaimon.ray.merge_condition import extract_target_columns
self.assertEqual(
extract_target_columns('s.name = t.name AND s.age > t.age'),
{'name', 'age'},
)
def test_extract_target_columns_ignores_string_literals(self):
from pypaimon.ray.merge_condition import extract_target_columns
self.assertEqual(
extract_target_columns("s.name = 't.fake' AND s.age > t.age"),
{'age'},
)
def test_extract_columns(self):
from pypaimon.ray.merge_condition import extract_columns
self.assertEqual(
extract_columns('s.id = t.id AND s.age > t.age'),
{'s.id', 't.id', 's.age', 't.age'},
)
@unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
def test_filter_batch(self):
from pypaimon.ray.merge_condition import filter_batch
batch = pa.table({
's.id': pa.array([1, 2, 3], type=pa.int32()),
's.age': pa.array([10, 25, 30], type=pa.int32()),
't.age': pa.array([20, 20, 20], type=pa.int32()),
})
result = filter_batch(batch, 's.age > t.age')
self.assertEqual(result.column('s.id').to_pylist(), [2, 3])
if __name__ == '__main__':
unittest.main()