blob: 0d557799e9c2d0e4dfa8ada051494d5732c15f50 [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""End-to-end tests for the ``aggregation`` merge engine.
Each test creates a PK table with ``merge-engine=aggregation`` plus
per-field aggregator configuration, writes two or more commits against
the same PK, and reads back. The aggregation engine must reduce each
non-PK column independently using the configured aggregator (sum / max
/ last_value / ...). Disjoint PKs must remain unmerged. Default
behaviour when no aggregator is configured is ``last_non_null_value``.
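The second half of the file exercises the merge-engine-support guard:
tables that configure aggregation with options pypaimon does not yet
implement (retract opt-ins, sequence-group, out-of-scope aggregator
identifiers) must raise ``NotImplementedError`` at TableRead
construction rather than silently fall back to a wrong answer. Invalid
configurations, such as an explicit aggregator on the ``sequence.field``
column, raise ``ValueError`` at the same point instead. The file also
checks that a top-level ``sequence.field`` is honoured, and that a
partition column which is also a primary-key column is never aggregated.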
"""
import os
import shutil
import tempfile
import unittest
import pyarrow as pa
from pypaimon import CatalogFactory, Schema
class AggregationMergeEngineE2ETest(unittest.TestCase):
    """End-to-end tests for PK tables using ``merge-engine=aggregation``.

    A single temporary warehouse is created once for the whole class.
    Every test uses its own table name, so tests do not share data.
    """

    @classmethod
    def setUpClass(cls):
        cls.tempdir = tempfile.mkdtemp()
        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
        cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
        cls.catalog.create_database('default', True)
        # Shared schema for most tests: a non-nullable BIGINT PK plus
        # three value columns (two BIGINT, one STRING).
        cls.pa_schema = pa.schema([
            pa.field('id', pa.int64(), nullable=False),
            ('total', pa.int64()),
            ('max_score', pa.int64()),
            ('label', pa.string()),
        ])

    @classmethod
    def tearDownClass(cls):
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def _create_pk_table(self, table_name, field_aggs=None,
                         default_agg=None, extra_options=None):
        """Create ``default.<table_name>`` keyed on ``id`` with the
        aggregation merge engine, and return it as loaded from the catalog.

        :param table_name: table name inside the ``default`` database.
        :param field_aggs: optional ``{field: aggregator}`` mapping, turned
            into ``fields.<field>.aggregate-function`` options.
        :param default_agg: optional value for
            ``fields.default-aggregate-function``.
        :param extra_options: extra table options, applied last so they
            override anything set above.
        """
        # bucket=1 forces all rows for a given PK to land in the same
        # bucket, which routes reads through SortMergeReader where the
        # aggregation merge function lives. Without it, fresh
        # single-snapshot tables take the raw_convertible fast path and
        # bypass the merge function entirely.
        options = {
            'bucket': '1',
            'merge-engine': 'aggregation',
        }
        if field_aggs:
            for field_name, agg_func in field_aggs.items():
                options['fields.{}.aggregate-function'.format(field_name)] = agg_func
        if default_agg:
            options['fields.default-aggregate-function'] = default_agg
        if extra_options:
            options.update(extra_options)
        schema = Schema.from_pyarrow_schema(
            self.pa_schema,
            primary_keys=['id'],
            options=options,
        )
        full = 'default.{}'.format(table_name)
        self.catalog.create_table(full, schema, False)
        return self.catalog.get_table(full)

    def _write(self, table, rows, pa_schema=None):
        """Write ``rows`` to ``table`` as one batch commit (one snapshot).

        :param rows: list of row dicts.
        :param pa_schema: Arrow schema for ``rows``. Defaults to the shared
            ``self.pa_schema``; pass a custom one for tables created with a
            different schema.
        """
        schema = self.pa_schema if pa_schema is None else pa_schema
        wb = table.new_batch_write_builder()
        w = wb.new_write()
        c = wb.new_commit()
        try:
            w.write_arrow(pa.Table.from_pylist(rows, schema=schema))
            c.commit(w.prepare_commit())
        finally:
            # Nested so the commit is closed even if closing the writer
            # raises.
            try:
                w.close()
            finally:
                c.close()

    def _read(self, table):
        """Read every split of ``table`` and return the rows as dicts,
        sorted by ``id``. Returns ``[]`` when the scan plans no splits."""
        rb = table.new_read_builder()
        splits = rb.new_scan().plan().splits()
        if not splits:
            return []
        return sorted(
            rb.new_read().to_arrow(splits).to_pylist(),
            key=lambda r: r['id'],
        )

    # -- aggregation happy path -----------------------------------------

    def test_sum_aggregator_across_commits(self):
        table = self._create_pk_table(
            'agg_sum',
            field_aggs={'total': 'sum'},
        )
        self._write(table, [{'id': 1, 'total': 10, 'max_score': 5, 'label': 'a'}])
        self._write(table, [{'id': 1, 'total': 20, 'max_score': 3, 'label': 'b'}])
        self._write(table, [{'id': 1, 'total': 30, 'max_score': 8, 'label': 'c'}])
        rows = self._read(table)
        self.assertEqual(len(rows), 1)
        row = rows[0]
        self.assertEqual(row['id'], 1)
        # total: 10 + 20 + 30 = 60
        self.assertEqual(row['total'], 60)
        # max_score and label have no aggregator configured → default
        # last_non_null_value: latest non-null wins.
        self.assertEqual(row['max_score'], 8)
        self.assertEqual(row['label'], 'c')

    def test_multiple_aggregators_compose(self):
        table = self._create_pk_table(
            'agg_multi',
            field_aggs={
                'total': 'sum',
                'max_score': 'max',
                'label': 'last_value',
            },
        )
        self._write(table, [{'id': 1, 'total': 10, 'max_score': 5, 'label': 'a'}])
        self._write(table, [{'id': 1, 'total': 7, 'max_score': 12, 'label': 'b'}])
        self._write(table, [{'id': 1, 'total': 3, 'max_score': 1, 'label': 'c'}])
        row = self._read(table)[0]
        self.assertEqual(row['total'], 20)  # sum: 10+7+3
        self.assertEqual(row['max_score'], 12)  # max: max(5,12,1)
        self.assertEqual(row['label'], 'c')  # last_value

    def test_null_inputs_follow_aggregator_semantics(self):
        table = self._create_pk_table(
            'agg_nulls',
            field_aggs={
                'total': 'sum',
                'max_score': 'last_value',
            },
        )
        self._write(table, [{'id': 1, 'total': 5, 'max_score': 7, 'label': 'x'}])
        # null total is absorbed by sum; null max_score replaces under
        # last_value (last_value keeps the last input verbatim,
        # including None).
        self._write(table, [{'id': 1, 'total': None, 'max_score': None, 'label': None}])
        # Observe the intermediate state. Otherwise the third commit would
        # overwrite max_score and the None-replacement would go unchecked.
        self.assertEqual(
            self._read(table),
            [{'id': 1, 'total': 5, 'max_score': None, 'label': 'x'}],
        )
        self._write(table, [{'id': 1, 'total': 4, 'max_score': 9, 'label': 'y'}])
        row = self._read(table)[0]
        self.assertEqual(row['total'], 9)  # 5 + 4 (None absorbed)
        self.assertEqual(row['max_score'], 9)  # last_value's last input
        # label: default last_non_null_value, intermediate None ignored,
        # the final 'y' wins.
        self.assertEqual(row['label'], 'y')

    def test_disjoint_keys_remain_unmerged(self):
        table = self._create_pk_table(
            'agg_disjoint',
            field_aggs={'total': 'sum'},
        )
        self._write(table, [
            {'id': 1, 'total': 10, 'max_score': 1, 'label': 'a'},
            {'id': 2, 'total': 20, 'max_score': 2, 'label': 'b'},
            {'id': 3, 'total': 30, 'max_score': 3, 'label': 'c'},
        ])
        # Second commit only touches id=2.
        self._write(table, [{'id': 2, 'total': 5, 'max_score': 7, 'label': 'B'}])
        rows = self._read(table)
        self.assertEqual(rows, [
            {'id': 1, 'total': 10, 'max_score': 1, 'label': 'a'},
            {'id': 2, 'total': 25, 'max_score': 7, 'label': 'B'},
            {'id': 3, 'total': 30, 'max_score': 3, 'label': 'c'},
        ])

    def test_default_aggregator_applies_to_unconfigured_fields(self):
        table = self._create_pk_table(
            'agg_default',
            default_agg='max',
        )
        self._write(table, [{'id': 1, 'total': 3, 'max_score': 5, 'label': 'm'}])
        self._write(table, [{'id': 1, 'total': 7, 'max_score': 2, 'label': 'a'}])
        self._write(table, [{'id': 1, 'total': 1, 'max_score': 9, 'label': 'z'}])
        row = self._read(table)[0]
        # All non-PK fields fall through to fields.default-aggregate-function=max.
        self.assertEqual(row['total'], 7)
        self.assertEqual(row['max_score'], 9)
        self.assertEqual(row['label'], 'z')  # 'z' > 'm' > 'a' lexicographically

    def test_default_behavior_is_last_non_null_value(self):
        # No field-level or default aggregator configured → every non-PK
        # field uses the system default last_non_null_value.
        table = self._create_pk_table('agg_implicit_default')
        self._write(table, [{'id': 1, 'total': 5, 'max_score': 9, 'label': 'a'}])
        self._write(table, [{'id': 1, 'total': None, 'max_score': 3, 'label': None}])
        self._write(table, [{'id': 1, 'total': 7, 'max_score': None, 'label': 'b'}])
        row = self._read(table)[0]
        self.assertEqual(row['total'], 7)  # latest non-null
        self.assertEqual(row['max_score'], 3)  # latest non-null
        self.assertEqual(row['label'], 'b')  # latest non-null

    # -- unsupported-option guards --------------------------------------
    #
    # Tables that opt into behaviour AggregateMergeFunction doesn't
    # implement must surface a NotImplementedError at TableRead
    # construction, not silently produce wrong results.

    def _create_and_expect_unsupported(self, table_name, extra_options,
                                       expected_substring,
                                       error_type=NotImplementedError):
        """Create a table with ``extra_options``, write one row, and
        assert that building a TableRead raises ``error_type``.

        The exception message must contain ``expected_substring``. For
        ``NotImplementedError`` it must also mention ``aggregation``.
        """
        table = self._create_pk_table(
            table_name, extra_options=extra_options
        )
        # Writing is fine — the guard fires when a reader is built.
        self._write(table, [{'id': 1, 'total': 1, 'max_score': 1, 'label': 'a'}])
        rb = table.new_read_builder()
        with self.assertRaises(error_type) as cm:
            rb.new_read()
        msg = str(cm.exception)
        if error_type is NotImplementedError:
            self.assertIn('aggregation', msg)
        self.assertIn(expected_substring, msg)

    def test_remove_record_on_delete_rejected(self):
        self._create_and_expect_unsupported(
            'agg_reject_remove_on_delete',
            {'aggregation.remove-record-on-delete': 'true'},
            'aggregation.remove-record-on-delete',
        )

    def test_field_ignore_retract_rejected(self):
        self._create_and_expect_unsupported(
            'agg_reject_ignore_retract',
            {'fields.total.ignore-retract': 'true'},
            'fields.total.ignore-retract',
        )

    def test_sequence_field_supported(self):
        # Top-level sequence.field is honored by the aggregation engine:
        # aggregators fold in sequence-field order, not file order. Here
        # ``last_value`` must pick the value from the highest-``total`` row
        # even though it was written first.
        table = self._create_pk_table(
            'agg_sequence_field',
            field_aggs={'max_score': 'last_value', 'label': 'last_value'},
            extra_options={'sequence.field': 'total'},
        )
        self._write(table, [{'id': 1, 'total': 100, 'max_score': 9, 'label': 'hi'}])
        self._write(table, [{'id': 1, 'total': 50, 'max_score': 1, 'label': 'lo'}])
        self.assertEqual(
            self._read(table),
            [{'id': 1, 'total': 100, 'max_score': 9, 'label': 'hi'}],
        )

    def test_aggregate_function_on_sequence_field_rejected(self):
        # An explicit aggregator on the sequence column is invalid: Java
        # rejects fields.<seq>.aggregate-function in
        # SchemaValidation.validateSequenceField. Rather than silently
        # override 'sum' with last_value, the guard must reject it.
        self._create_and_expect_unsupported(
            'agg_reject_agg_on_seq',
            {'sequence.field': 'total',
             'fields.total.aggregate-function': 'sum'},
            'fields.total.aggregate-function',
            error_type=ValueError,
        )

    def test_field_sequence_group_rejected(self):
        self._create_and_expect_unsupported(
            'agg_reject_sequence_group',
            {'fields.max_score.sequence-group': 'label'},
            'fields.max_score.sequence-group',
        )

    def test_out_of_scope_field_aggregator_rejected(self):
        # collect is one of the aggregator identifiers this engine
        # doesn't support yet. The guard must reject the config rather
        # than let the per-field factory build a (silently wrong)
        # fallback.
        self._create_and_expect_unsupported(
            'agg_reject_collect',
            {'fields.label.aggregate-function': 'collect'},
            'fields.label.aggregate-function',
        )

    def test_out_of_scope_default_aggregator_rejected(self):
        # Same guard, but reached through the table-wide default
        # aggregator. 'product' is used here as the out-of-scope
        # identifier.
        self._create_and_expect_unsupported(
            'agg_reject_default_product',
            {'fields.default-aggregate-function': 'product'},
            'fields.default-aggregate-function',
        )

    def test_supported_field_aggregator_passes_guard(self):
        # Sanity check: setting one of the supported aggregators does
        # NOT trip the guard introduced for out-of-scope identifiers.
        table = self._create_pk_table(
            'agg_supported_passes',
            field_aggs={'total': 'sum'},
        )
        self._write(table, [{'id': 1, 'total': 1, 'max_score': 1, 'label': 'a'}])
        # If the guard wrongly flagged 'sum', new_read() would raise.
        # Touch it explicitly so the test fails loudly otherwise.
        table.new_read_builder().new_read()

    # -- partition column that is also part of the primary key ----------

    def test_partition_pk_overlap_not_aggregated_by_default(self):
        # When a partition column is also part of the primary key and a
        # table-wide ``fields.default-aggregate-function`` is configured,
        # the partition-PK column must be treated as PK (identity) and
        # not run through the default aggregator. Regression for the
        # split_read bug where the trimmed PK list (which drops
        # partition columns) was passed to ``build_field_aggregators``.
        pa_schema = pa.schema([
            pa.field('p', pa.int64(), nullable=False),
            pa.field('id', pa.int64(), nullable=False),
            pa.field('v', pa.int64()),
        ])
        schema = Schema.from_pyarrow_schema(
            pa_schema,
            primary_keys=['p', 'id'],
            partition_keys=['p'],
            options={
                'bucket': '1',
                'merge-engine': 'aggregation',
                'fields.default-aggregate-function': 'sum',
            },
        )
        self.catalog.create_table(
            'default.agg_partition_pk_overlap', schema, False)
        table = self.catalog.get_table('default.agg_partition_pk_overlap')
        # Reuse the shared helpers with this test's own schema rather than
        # duplicating the write/commit/close sequence.
        self._write(table, [{'p': 1, 'id': 1, 'v': 10}], pa_schema=pa_schema)
        self._write(table, [{'p': 1, 'id': 1, 'v': 20}], pa_schema=pa_schema)
        # p and id are key columns (identity); only v is summed: 10 + 20.
        self.assertEqual(self._read(table), [{'p': 1, 'id': 1, 'v': 30}])
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()