blob: 7ff00fbb9b627ed3bcaa477d5568d2485bca355e [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Direct unit tests for :class:`AggregateMergeFunction` and its
helper functions.
Drives the merge function with synthetic :class:`KeyValue` instances
so the contract is pinned down without going through the full read
pipeline. The end-to-end behaviour on real PK tables is exercised
separately in ``test_aggregation_e2e.py``.
"""
import unittest
from pypaimon.common.options.core_options import CoreOptions
from pypaimon.common.options.options import Options
from pypaimon.read.reader.aggregate import create_field_aggregator
from pypaimon.read.reader.aggregation_merge_function import (
AggregateMergeFunction,
build_field_aggregators,
resolve_agg_func_name,
)
from pypaimon.schema.data_types import AtomicType, DataField
from pypaimon.table.row.key_value import KeyValue
from pypaimon.table.row.row_kind import RowKind
def _kv(key, seq, row_kind, value):
"""Build a fresh KeyValue for a (key, sequence, row_kind, value)
tuple. Same shape as the helper in
``test_partial_update_merge_function.py`` so both test files read
consistently.
"""
kv = KeyValue(key_arity=len(key), value_arity=len(value))
kv.replace(tuple(key) + (seq, row_kind.value) + tuple(value))
return kv
def _result_value(kv):
return tuple(kv.value.get_field(i) for i in range(kv.value_arity))
def _result_key(kv):
return tuple(kv.key.get_field(i) for i in range(kv.key_arity))
def _make_agg(identifier, sql_type, field_name="f"):
return create_field_aggregator(
AtomicType(sql_type), field_name, identifier, options=None
)
class AggregateMergeFunctionTest(unittest.TestCase):
def test_single_insert_returns_aggregated_value(self):
mf = AggregateMergeFunction(
key_arity=1, value_arity=1,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
mf.reset()
mf.add(_kv((1,), 100, RowKind.INSERT, (10,)))
result = mf.get_result()
self.assertIsNotNone(result)
self.assertEqual(_result_key(result), (1,))
self.assertEqual(_result_value(result), (10,))
self.assertEqual(result.sequence_number, 100)
self.assertEqual(result.value_row_kind_byte, RowKind.INSERT.value)
def test_multi_row_aggregation_across_fields(self):
# field 0: sum, field 1: max, field 2: last_value
mf = AggregateMergeFunction(
key_arity=1, value_arity=3,
field_aggregators=[
_make_agg("sum", "BIGINT", "v_sum"),
_make_agg("max", "INT", "v_max"),
_make_agg("last_value", "VARCHAR", "v_last"),
],
)
mf.reset()
mf.add(_kv((1,), 100, RowKind.INSERT, (10, 5, "a")))
mf.add(_kv((1,), 101, RowKind.INSERT, (20, 3, "b")))
mf.add(_kv((1,), 102, RowKind.INSERT, (30, 9, "c")))
result = mf.get_result()
self.assertEqual(_result_value(result), (60, 9, "c"))
# latest sequence is propagated through.
self.assertEqual(result.sequence_number, 102)
def test_null_inputs_follow_aggregator_semantics(self):
# sum drops nulls; last_non_null_value drops nulls; last_value keeps them.
mf = AggregateMergeFunction(
key_arity=1, value_arity=3,
field_aggregators=[
_make_agg("sum", "BIGINT", "v_sum"),
_make_agg("last_non_null_value", "VARCHAR", "v_lnn"),
_make_agg("last_value", "VARCHAR", "v_last"),
],
)
mf.reset()
mf.add(_kv((1,), 100, RowKind.INSERT, (10, "x", "x")))
mf.add(_kv((1,), 101, RowKind.INSERT, (None, None, None)))
mf.add(_kv((1,), 102, RowKind.INSERT, (5, None, "z")))
result = mf.get_result()
# sum: 10 + 5 = 15 (nulls absorbed)
# last_non_null: 'x' (intermediate nulls preserved earlier value)
# last_value: 'z' (the very last value, including the prior None)
self.assertEqual(_result_value(result), (15, "x", "z"))
def test_update_after_treated_as_insert(self):
mf = AggregateMergeFunction(
key_arity=1, value_arity=1,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
mf.reset()
mf.add(_kv((1,), 100, RowKind.UPDATE_AFTER, (7,)))
result = mf.get_result()
self.assertEqual(_result_value(result), (7,))
def test_delete_row_raises_not_implemented(self):
mf = AggregateMergeFunction(
key_arity=1, value_arity=1,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
mf.reset()
with self.assertRaises(NotImplementedError) as ctx:
mf.add(_kv((1,), 100, RowKind.DELETE, (5,)))
self.assertIn("retract", str(ctx.exception))
self.assertIn("-D", str(ctx.exception))
def test_update_before_row_raises_not_implemented(self):
mf = AggregateMergeFunction(
key_arity=1, value_arity=1,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
mf.reset()
with self.assertRaises(NotImplementedError) as ctx:
mf.add(_kv((1,), 100, RowKind.UPDATE_BEFORE, (5,)))
self.assertIn("-U", str(ctx.exception))
def test_reset_between_keys_clears_state(self):
mf = AggregateMergeFunction(
key_arity=1, value_arity=2,
field_aggregators=[
_make_agg("sum", "BIGINT"),
_make_agg("first_value", "VARCHAR"),
],
)
# Key group 1.
mf.reset()
mf.add(_kv((1,), 100, RowKind.INSERT, (5, "a")))
mf.add(_kv((1,), 101, RowKind.INSERT, (3, "b")))
r1 = mf.get_result()
self.assertEqual(_result_value(r1), (8, "a"))
# Key group 2 — sum must restart from 0 and first_value must
# re-arm so the new group's first row wins.
mf.reset()
mf.add(_kv((2,), 200, RowKind.INSERT, (10, "x")))
mf.add(_kv((2,), 201, RowKind.INSERT, (20, "y")))
r2 = mf.get_result()
self.assertEqual(_result_key(r2), (2,))
self.assertEqual(_result_value(r2), (30, "x"))
def test_get_result_before_any_add_returns_none(self):
mf = AggregateMergeFunction(
key_arity=1, value_arity=1,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
mf.reset()
self.assertIsNone(mf.get_result())
def test_result_is_decoupled_from_input_kv(self):
"""Critical: upstream KeyValueWrapReader reuses one KeyValue and
rebinds its row_tuple between iterations. The merge function
must snapshot its output so the previously-returned result is
not silently mutated when the source advances.
"""
mf = AggregateMergeFunction(
key_arity=1, value_arity=1,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
mf.reset()
# Build one reusable kv and rebind it twice, like upstream does.
kv = KeyValue(key_arity=1, value_arity=1)
kv.replace((1, 100, RowKind.INSERT.value, 5))
mf.add(kv)
kv.replace((1, 101, RowKind.INSERT.value, 7))
mf.add(kv)
result = mf.get_result()
# Now mutate the source kv. The previously-captured result must
# NOT change.
kv.replace((999, 999, RowKind.INSERT.value, 99999))
self.assertEqual(_result_key(result), (1,))
self.assertEqual(_result_value(result), (12,)) # 5 + 7
self.assertEqual(result.sequence_number, 101)
def test_value_arity_mismatch_at_construction_raises(self):
with self.assertRaises(ValueError):
AggregateMergeFunction(
key_arity=1, value_arity=2,
field_aggregators=[_make_agg("sum", "BIGINT")],
)
class ResolveAggFuncNameTest(unittest.TestCase):
def test_primary_key_takes_precedence(self):
name = resolve_agg_func_name(
"id", primary_keys={"id"}, options_map={
"fields.id.aggregate-function": "sum",
}
)
self.assertEqual(name, "primary_key")
def test_field_level_override_wins_over_default(self):
name = resolve_agg_func_name(
"v", primary_keys=set(), options_map={
"fields.v.aggregate-function": "max",
"fields.default-aggregate-function": "sum",
}
)
self.assertEqual(name, "max")
def test_table_default_used_when_no_field_override(self):
name = resolve_agg_func_name(
"v", primary_keys=set(), options_map={
"fields.default-aggregate-function": "sum",
}
)
self.assertEqual(name, "sum")
def test_system_default_when_nothing_configured(self):
name = resolve_agg_func_name(
"v", primary_keys=set(), options_map={}
)
self.assertEqual(name, "last_non_null_value")
class BuildFieldAggregatorsTest(unittest.TestCase):
def _make_options(self, raw):
return CoreOptions(Options(raw))
def test_builds_aggregators_aligned_with_value_fields(self):
fields = [
DataField(0, "id", AtomicType("BIGINT")),
DataField(1, "amount", AtomicType("BIGINT")),
DataField(2, "name", AtomicType("VARCHAR")),
]
options = self._make_options({
"fields.amount.aggregate-function": "sum",
})
aggs = build_field_aggregators(
value_fields=fields,
primary_keys=["id"],
core_options=options,
)
self.assertEqual(len(aggs), 3)
self.assertEqual(aggs[0].name, "primary_key")
self.assertEqual(aggs[1].name, "sum")
# Falls through to the system default since no override and no
# fields.default-aggregate-function is set.
self.assertEqual(aggs[2].name, "last_non_null_value")
def test_unknown_aggregator_identifier_raises(self):
fields = [
DataField(0, "id", AtomicType("BIGINT")),
DataField(1, "v", AtomicType("BIGINT")),
]
options = self._make_options({
"fields.v.aggregate-function": "no_such_aggregator",
})
with self.assertRaises(ValueError):
build_field_aggregators(
value_fields=fields,
primary_keys=["id"],
core_options=options,
)
if __name__ == '__main__':
unittest.main()