| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """ |
| Three-layer correctness tests for predicate-driven bucket pruning. |
| |
| Mirrors Java's ``BucketSelectConverter`` contract: PK Equal/In queries on |
| HASH_FIXED tables must touch only the bucket(s) the writer would have |
| placed those keys in. Two correctness obligations: |
| |
| 1. Sound: the bucket set the selector retains is always a superset |
| of the buckets that actually hold matching rows. Buckets that DO |
| contain matching rows are NEVER dropped (false-negative-free). |
| 2. Hash-consistent with writers: ``RowKeyExtractor`` (writer) and |
| ``BucketSelectConverter`` (reader) must agree on every literal. |
| This is what makes ``pk = 'X'`` read the bucket holding 'X'. |
| |
| Layered: |
| * Unit — direct calls to ``create_bucket_selector`` with crafted |
| predicates, asserting selector behaviour. |
| * Integration — real PK tables with multiple buckets; queries; assert |
| (a) result correctness, (b) bucket pruning happened. |
| * Property — randomly-seeded PK tables, random Equal/In predicates, |
| result == oracle. No ``hypothesis`` dependency (keeps |
| Python 3.6 compatibility). |
| """ |
| |
| import os |
| import random |
| import shutil |
| import tempfile |
| import unittest |
| from typing import Any, Dict, List, Optional |
| |
| import pyarrow as pa |
| |
| from pypaimon import CatalogFactory, Schema |
| from pypaimon.common.predicate_builder import PredicateBuilder |
| from pypaimon.read.scanner.bucket_select_converter import ( |
| MAX_VALUES, create_bucket_selector, replace_partition_predicate) |
| from pypaimon.schema.data_types import AtomicType, DataField |
| from pypaimon.write.row_key_extractor import (FixedBucketRowKeyExtractor, |
| _bucket_from_hash, |
| _hash_bytes_by_words) |
| from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer |
| from pypaimon.table.row.internal_row import RowKind |
| |
| |
| def _bigint_field(idx: int, name: str) -> DataField: |
| return DataField(idx, name, AtomicType('BIGINT', nullable=False)) |
| |
| |
| def _field(idx: int, name: str, type_name: str) -> DataField: |
| return DataField(idx, name, AtomicType(type_name, nullable=False)) |
| |
| |
| def _hash_bucket(values: List[Any], fields: List[DataField], total: int) -> int: |
| """Re-implement the writer's hash so unit tests can compute the |
| expected bucket without spinning up a real table.""" |
| row = GenericRow(values, fields, RowKind.INSERT) |
| serialized = GenericRowSerializer.to_bytes(row) |
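| # Hash only the payload: skip the serializer's 4-byte prefix so the |
| # digest covers the same bytes the writer's bucket hash covers |
| # (assumption: the prefix is header bytes the writer also excludes). |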
| h = _hash_bytes_by_words(serialized[4:]) |
| return _bucket_from_hash(h, total) |
| |
| |
| # --------------------------------------------------------------------------- |
| # Layer 1 — Unit: drive ``create_bucket_selector`` with crafted predicates. |
| # --------------------------------------------------------------------------- |
| class BucketSelectConverterUnitTest(unittest.TestCase): |
| |
| @classmethod |
| def setUpClass(cls): |
| cls.id_field = _bigint_field(0, 'id') |
| cls.val_field = _bigint_field(1, 'val') |
| cls.k1 = _bigint_field(0, 'k1') |
| cls.k2 = _bigint_field(1, 'k2') |
| cls.pb_id_val = PredicateBuilder([cls.id_field, cls.val_field]) |
| cls.pb_k1_k2 = PredicateBuilder([cls.k1, cls.k2]) |
| |
| # -- Equal / In on single bucket key --------------------------------- |
| def test_equal_on_single_bucket_key_yields_single_bucket(self): |
| sel = create_bucket_selector( |
| self.pb_id_val.equal('id', 42), [self.id_field]) |
| self.assertIsNotNone(sel, "PK Equal must produce a selector") |
| expected = _hash_bucket([42], [self.id_field], total=8) |
| for b in range(8): |
| self.assertEqual( |
| sel(b, 8), b == expected, |
| "only bucket {} must be kept (got {})".format(expected, b)) |
| |
| def test_in_on_single_bucket_key_unions_buckets(self): |
| sel = create_bucket_selector( |
| self.pb_id_val.is_in('id', [1, 2, 3, 100]), [self.id_field]) |
| expected = {_hash_bucket([v], [self.id_field], 8) |
| for v in (1, 2, 3, 100)} |
| for b in range(8): |
| self.assertEqual(sel(b, 8), b in expected) |
| |
| def test_or_of_equals_on_same_field_unions_buckets(self): |
| # ``id = 1 OR id = 2`` must prune identically to ``id IN (1, 2)``. |
| pred = PredicateBuilder.or_predicates([ |
| self.pb_id_val.equal('id', 1), |
| self.pb_id_val.equal('id', 2), |
| ]) |
| sel = create_bucket_selector(pred, [self.id_field]) |
| expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)} |
| for b in range(8): |
| self.assertEqual(sel(b, 8), b in expected) |
| |
| # -- Composite bucket keys ------------------------------------------ |
| def test_composite_bucket_key_intersects_via_cartesian(self): |
| pred = PredicateBuilder.and_predicates([ |
| self.pb_k1_k2.is_in('k1', [1, 2]), |
| self.pb_k1_k2.equal('k2', 99), |
| ]) |
| sel = create_bucket_selector(pred, [self.k1, self.k2]) |
| expected = { |
| _hash_bucket([k1, 99], [self.k1, self.k2], 4) |
| for k1 in (1, 2) |
| } |
| for b in range(4): |
| self.assertEqual(sel(b, 4), b in expected) |
| |
| def test_composite_bucket_key_missing_one_field_returns_none(self): |
| pred = self.pb_k1_k2.equal('k1', 1) # k2 unconstrained |
| sel = create_bucket_selector(pred, [self.k1, self.k2]) |
| self.assertIsNone(sel, |
| "an unconstrained bucket key must fall back to no selector") |
| |
| # -- Predicates that can't be reduced ------------------------------- |
| def test_non_bucket_key_predicate_returns_none(self): |
| sel = create_bucket_selector( |
| self.pb_id_val.equal('val', 5), [self.id_field]) |
| self.assertIsNone(sel, "predicate not on bucket key -> no selector") |
| |
| def test_range_predicate_on_bucket_key_returns_none(self): |
| sel = create_bucket_selector( |
| self.pb_id_val.greater_than('id', 100), [self.id_field]) |
| self.assertIsNone(sel, "ranges can't be turned into a finite bucket set") |
| |
| def test_or_with_non_bucket_key_returns_none(self): |
| # ``id = 1 OR val = 5`` — ``val`` isn't a bucket key, so the OR |
| # is not a pure bucket-key constraint. |
| pred = PredicateBuilder.or_predicates([ |
| self.pb_id_val.equal('id', 1), |
| self.pb_id_val.equal('val', 5), |
| ]) |
| sel = create_bucket_selector(pred, [self.id_field]) |
| self.assertIsNone(sel) |
| |
| def test_repeated_equal_on_same_key_with_empty_intersection_returns_none(self): |
| # ``id = 1 AND id = 2``: literal sets {1} and {2} intersect to |
| # empty; Java's ``retainAll`` would also bail here, since the |
| # predicate is unsatisfiable. |
| pred = PredicateBuilder.and_predicates([ |
| self.pb_id_val.equal('id', 1), |
| self.pb_id_val.equal('id', 2), |
| ]) |
| sel = create_bucket_selector(pred, [self.id_field]) |
| self.assertIsNone(sel) |
| |
| def test_repeated_in_on_same_key_intersects_literals(self): |
| # ``id IN (1,2,3) AND id IN (2,3,4)`` should now keep the |
| # intersection {2, 3} and prune to those buckets only. Used to |
| # bail with no selector before the Java parity fix. |
| pred = PredicateBuilder.and_predicates([ |
| self.pb_id_val.is_in('id', [1, 2, 3]), |
| self.pb_id_val.is_in('id', [2, 3, 4]), |
| ]) |
| sel = create_bucket_selector(pred, [self.id_field]) |
| self.assertIsNotNone(sel) |
| expected = {_hash_bucket([v], [self.id_field], 8) for v in (2, 3)} |
| for b in range(8): |
| self.assertEqual(sel(b, 8), b in expected) |
| |
| def test_and_with_unrelated_clause_is_unaffected(self): |
| # ``id = 7 AND val > 100`` — the ``val > 100`` part doesn't |
| # constrain buckets, but mustn't disqualify the ``id = 7`` part. |
| pred = PredicateBuilder.and_predicates([ |
| self.pb_id_val.equal('id', 7), |
| self.pb_id_val.greater_than('val', 100), |
| ]) |
| sel = create_bucket_selector(pred, [self.id_field]) |
| self.assertIsNotNone(sel) |
| expected = _hash_bucket([7], [self.id_field], 4) |
| for b in range(4): |
| self.assertEqual(sel(b, 4), b == expected) |
| |
| # -- Cap & degenerate edge cases ------------------------------------ |
| def test_cartesian_above_max_values_returns_none(self): |
| # Two columns of size > sqrt(MAX_VALUES) → product > MAX_VALUES. |
| size = 33 # 33 * 33 = 1089 > 1000 |
| pred = PredicateBuilder.and_predicates([ |
| self.pb_k1_k2.is_in('k1', list(range(size))), |
| self.pb_k1_k2.is_in('k2', list(range(size))), |
| ]) |
| self.assertGreater(size * size, MAX_VALUES) |
| sel = create_bucket_selector(pred, [self.k1, self.k2]) |
| self.assertIsNone(sel) |
| |
| def test_null_only_literal_drops_everything(self): |
| # ``id IN (NULL)`` after null-stripping has zero literals; the |
| # cartesian product is empty → selector matches no buckets. Same |
| # behaviour as Java. |
| pred = self.pb_id_val.is_in('id', [None]) |
| sel = create_bucket_selector(pred, [self.id_field]) |
| self.assertIsNotNone(sel) |
| for b in range(4): |
| self.assertFalse(sel(b, 4), |
| "all-null literal collapses bucket set to empty") |
| |
| def test_no_predicate_returns_none(self): |
| self.assertIsNone(create_bucket_selector(None, [self.id_field])) |
| |
| def test_no_bucket_keys_returns_none(self): |
| self.assertIsNone( |
| create_bucket_selector(self.pb_id_val.equal('id', 1), [])) |
| |
| # -- Selector cache + rescale ------------------------------------- |
| def test_selector_caches_per_total_buckets(self): |
| """Selector must answer correctly when the same query applies to |
| different ``total_buckets`` values (the rescale scenario).""" |
| sel = create_bucket_selector( |
| self.pb_id_val.equal('id', 42), [self.id_field]) |
| for total in (4, 8, 16, 32): |
| expected = _hash_bucket([42], [self.id_field], total) |
| self.assertTrue(sel(expected, total)) |
| other = (expected + 1) % total |
| self.assertFalse(sel(other, total)) |
| |
| def test_non_positive_total_buckets_fails_open(self): |
| """Manifest entries can carry ``total_buckets <= 0`` for legacy / |
| special bucket modes. Pruning MUST fail open — returning False |
| would silently drop rows the writer placed in those entries. |
| This is correctness, not performance: the soundness contract |
| forbids false-negatives.""" |
| sel = create_bucket_selector( |
| self.pb_id_val.equal('id', 1), [self.id_field]) |
| for total in (0, -1, -2): |
| self.assertTrue(sel(0, total), |
| "total_buckets={} must be kept (fail open)".format(total)) |
| self.assertTrue(sel(-1, total)) |
| self.assertTrue(sel(99, total)) |
| |
| # -- Bucket-key column types beyond BIGINT -------------------------- |
| def test_string_bucket_key_yields_correct_bucket(self): |
| """STRING uses a different ``GenericRowSerializer`` path (utf-8 |
| encode + variable-part offset) — verify writer/reader agree on |
| its byte layout independent of the BIGINT happy path.""" |
| sf = _field(0, 'sk', 'STRING') |
| vf = _bigint_field(1, 'val') |
| pb = PredicateBuilder([sf, vf]) |
| sel = create_bucket_selector(pb.equal('sk', 'hello'), [sf]) |
| self.assertIsNotNone(sel) |
| expected = _hash_bucket(['hello'], [sf], total=8) |
| for b in range(8): |
| self.assertEqual(sel(b, 8), b == expected) |
| |
| def test_int_bucket_key_yields_correct_bucket(self): |
| """INT (32-bit) and BIGINT (64-bit) hit different struct.pack |
| paths in the serializer — guard the smaller width.""" |
| intf = _field(0, 'i', 'INT') |
| vf = _bigint_field(1, 'val') |
| pb = PredicateBuilder([intf, vf]) |
| sel = create_bucket_selector(pb.equal('i', 7), [intf]) |
| self.assertIsNotNone(sel) |
| expected = _hash_bucket([7], [intf], total=4) |
| for b in range(4): |
| self.assertEqual(sel(b, 4), b == expected) |
| |
| # -- Hash-divergence-prone types refuse to build a selector -------- |
| def test_decimal_bucket_key_disables_pruning(self): |
| """DECIMAL columns risk silent hash divergence between writer |
| (Decimal) and reader-supplied ``float`` literals. Soundness |
| contract demands fail-open: refuse to build a selector at all.""" |
| df = _field(0, 'd', 'DECIMAL(10, 2)') |
| vf = _bigint_field(1, 'val') |
| pb = PredicateBuilder([df, vf]) |
| from decimal import Decimal |
| sel = create_bucket_selector(pb.equal('d', Decimal('1.50')), [df]) |
| self.assertIsNone( |
| sel, "DECIMAL bucket-key column must disable pruning") |
| |
| def test_array_bucket_key_disables_pruning(self): |
| """Composite types (ARRAY/MAP/ROW/MULTISET/VARIANT/BLOB) have no |
| cross-validated byte alignment with Java's ``BinaryRow`` — until |
| that exists, refuse to prune on them.""" |
| # Hand-roll a DataField whose AtomicType reports an ARRAY type |
| # name; the converter inspects ``field.type.type`` only. |
| af = DataField(0, 'arr', AtomicType('ARRAY<BIGINT>')) |
| vf = _bigint_field(1, 'val') |
| pb = PredicateBuilder([af, vf]) |
| sel = create_bucket_selector(pb.equal('arr', [1]), [af]) |
| self.assertIsNone( |
| sel, "ARRAY bucket-key column must disable pruning") |
| |
| def test_timestamp_bucket_key_disables_pruning(self): |
| """TIMESTAMP columns serialise via ``value.timestamp()`` whose |
| result depends on the process timezone for naive datetimes — |
| writer and reader running in different TZs would disagree.""" |
| tf = _field(0, 't', 'TIMESTAMP(3)') |
| vf = _bigint_field(1, 'val') |
| pb = PredicateBuilder([tf, vf]) |
| from datetime import datetime |
| sel = create_bucket_selector( |
| pb.equal('t', datetime(2026, 1, 1)), [tf]) |
| self.assertIsNone( |
| sel, "TIMESTAMP bucket-key column must disable pruning") |
| |
| def test_type_mismatched_literal_fails_open_not_crash(self): |
| """If the user constructs a predicate whose literal type doesn't |
| match the bucket-key column's atomic type — e.g. a STRING literal |
| on a BIGINT column — ``GenericRowSerializer`` raises during the |
| deferred hash inside ``_Selector``. The selector MUST swallow the |
| exception and fail open (return True for every bucket) rather |
| than propagate it. Crashing the entire scan with an opaque |
| ``struct.error`` is a worse user experience than silently |
| skipping bucket pruning, and the soundness contract still |
| forbids false-negatives.""" |
| sel = create_bucket_selector( |
| self.pb_id_val.equal('id', 'not-an-int'), [self.id_field]) |
| # Construction itself succeeds (no eager hashing). |
| self.assertIsNotNone(sel) |
| # Calling the selector must NOT raise; instead it returns True |
| # for every (bucket, total_buckets), preserving soundness. |
| for total in (4, 8): |
| for b in range(total): |
| self.assertTrue(sel(b, total), |
| "type-mismatched literal must fail open, " |
| "not crash (bucket={}, total={})".format(b, total)) |
| |
| |
| class PartitionAwareBucketSelectorUnitTest(unittest.TestCase): |
| """Unit tests for the per-partition predicate specialisation path. |
| |
| Covers ``replace_partition_predicate`` (the AND/OR fold walker) and |
| the partition-aware ``_Selector.__call__(partition, bucket, |
| total_buckets)`` 3-arg form that ``FileScanner._filter_manifest_entry`` |
| will use after wiring.""" |
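| |
| # Fold rules exercised below (my reading of the walker's contract |
| # from these cases, not a quoted spec): |
| #   AND: a child specialised to True is dropped; any False child |
| #        collapses the whole AND to False. |
| #   OR:  a child specialised to False is dropped; a lone surviving |
| #        child unwraps to that child. |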
| |
| @classmethod |
| def setUpClass(cls): |
| cls.id_field = _bigint_field(0, 'id') |
| cls.part_field = DataField(2, 'part', AtomicType('STRING')) |
| cls.pb = PredicateBuilder([cls.id_field, cls.part_field]) |
| |
| # ----- replace_partition_predicate -------------------------------- |
| |
| def test_replace_partition_leaf_to_true_drops_constraint(self): |
| # ``part = 'a' AND id = 1`` against partition {part: 'a'} → |
| # part leaf becomes True → AND fold removes it → only ``id = 1`` |
| pred = PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]) |
| result = replace_partition_predicate( |
| pred, {'part'}, {'part': 'a'}) |
| self.assertTrue(isinstance(result, type(pred)), |
| "AND with a single surviving child unwraps to that leaf") |
| self.assertEqual(result.method, 'equal') |
| self.assertEqual(result.field, 'id') |
| |
| def test_replace_partition_leaf_to_false_collapses_and(self): |
| # ``part = 'a' AND id = 1`` against partition {part: 'b'} → |
| # part leaf becomes False → AND collapses to AlwaysFalse (False). |
| pred = PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]) |
| result = replace_partition_predicate( |
| pred, {'part'}, {'part': 'b'}) |
| self.assertIs(result, False) |
| |
| def test_replace_partition_leaf_in_or_keeps_other_branch(self): |
| # ``(part='a' AND id=1) OR (part='b' AND id=2)`` against |
| # partition {part: 'a'} → first OR child becomes ``id=1``, second |
| # collapses to AlwaysFalse and is dropped. Result is just ``id=1``. |
| pred = PredicateBuilder.or_predicates([ |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]), |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'b'), |
| self.pb.equal('id', 2), |
| ]), |
| ]) |
| result = replace_partition_predicate( |
| pred, {'part'}, {'part': 'a'}) |
| # OR with a single surviving child unwraps to that child. |
| self.assertEqual(result.method, 'equal') |
| self.assertEqual(result.field, 'id') |
| self.assertEqual(result.literals, [1]) |
| |
| def test_replace_partition_leaf_in_or_other_partition(self): |
| # Same predicate, partition {part: 'b'} → second branch survives |
| # as ``id=2``. |
| pred = PredicateBuilder.or_predicates([ |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]), |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'b'), |
| self.pb.equal('id', 2), |
| ]), |
| ]) |
| result = replace_partition_predicate( |
| pred, {'part'}, {'part': 'b'}) |
| self.assertEqual(result.method, 'equal') |
| self.assertEqual(result.field, 'id') |
| self.assertEqual(result.literals, [2]) |
| |
| def test_replace_partition_leaf_unrelated_predicate_unchanged(self): |
| # No partition leaf → predicate returned as-is. |
| pred = self.pb.equal('id', 42) |
| result = replace_partition_predicate( |
| pred, {'part'}, {'part': 'a'}) |
| self.assertIs(result, pred) |
| |
| # ----- _Selector partition-aware path ----------------------------- |
| |
| def test_selector_3arg_specialises_per_partition(self): |
| # ``(part='a' AND id=1) OR (part='b' AND id=2)`` should hit |
| # bucket(1) only when partition='a' and bucket(2) only when |
| # partition='b'. Master without this fix returns "all buckets". |
| pred = PredicateBuilder.or_predicates([ |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]), |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'b'), |
| self.pb.equal('id', 2), |
| ]), |
| ]) |
| sel = create_bucket_selector( |
| pred, [self.id_field], partition_fields=[self.part_field]) |
| self.assertIsNotNone(sel) |
| bucket_for_1 = _hash_bucket([1], [self.id_field], total=8) |
| bucket_for_2 = _hash_bucket([2], [self.id_field], total=8) |
| |
| part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT) |
| part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT) |
| |
| for b in range(8): |
| self.assertEqual(sel(part_a, b, 8), b == bucket_for_1, |
| "partition a only keeps bucket %d" % bucket_for_1) |
| self.assertEqual(sel(part_b, b, 8), b == bucket_for_2, |
| "partition b only keeps bucket %d" % bucket_for_2) |
| |
| def test_selector_falls_through_when_partition_unknown(self): |
| """Early manifest filter passes ``partition=None`` (or uses the |
| 2-arg form) — no specialisation runs, bucket set falls back to a |
| sound over-approximation: all buckets accept.""" |
| pred = PredicateBuilder.or_predicates([ |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]), |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'b'), |
| self.pb.equal('id', 2), |
| ]), |
| ]) |
| sel = create_bucket_selector( |
| pred, [self.id_field], partition_fields=[self.part_field]) |
| self.assertIsNotNone(sel) |
| # 2-arg form (legacy callsite) — partition unknown, all buckets keep. |
| for b in range(8): |
| self.assertTrue(sel(b, 8), |
| "partition-unknown call must accept all buckets") |
| # 3-arg form with partition=None has the same semantics. |
| for b in range(8): |
| self.assertTrue(sel(None, b, 8)) |
| |
| def test_selector_partition_not_matching_returns_empty_bucket_set(self): |
| # ``part = 'a' AND id = 1`` on partition {part: 'c'} simplifies to |
| # AlwaysFalse — the selector returns False for every bucket since |
| # no row in this partition can possibly match. Sound: dropping a |
| # partition that *can't* contain matches doesn't lose data. |
| pred = PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]) |
| sel = create_bucket_selector( |
| pred, [self.id_field], partition_fields=[self.part_field]) |
| self.assertIsNotNone(sel) |
| part_c = GenericRow(['c'], [self.part_field], RowKind.INSERT) |
| for b in range(8): |
| self.assertFalse(sel(part_c, b, 8), |
| "partition c can't satisfy part='a', " |
| "drop every bucket (b=%d)" % b) |
| |
| def test_selector_partition_only_constraint_drops_partition(self): |
| # ``part='a' AND id IN (1,2)`` — same partition value 'a' |
| # specialises ``part='a'`` to True, leaving ``id IN (1,2)``. |
| pred = PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.is_in('id', [1, 2]), |
| ]) |
| sel = create_bucket_selector( |
| pred, [self.id_field], partition_fields=[self.part_field]) |
| self.assertIsNotNone(sel) |
| part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT) |
| expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)} |
| for b in range(8): |
| self.assertEqual(sel(part_a, b, 8), b in expected) |
| |
| def test_selector_caches_per_partition(self): |
| pred = PredicateBuilder.or_predicates([ |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'a'), |
| self.pb.equal('id', 1), |
| ]), |
| PredicateBuilder.and_predicates([ |
| self.pb.equal('part', 'b'), |
| self.pb.equal('id', 2), |
| ]), |
| ]) |
| sel = create_bucket_selector( |
| pred, [self.id_field], partition_fields=[self.part_field]) |
| part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT) |
| part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT) |
| # Drive the cache. |
| for _ in range(5): |
| sel(part_a, 0, 8) |
| sel(part_b, 0, 8) |
| # Cache keyed by (partition tuple, total_buckets); two distinct |
| # partitions × one total → exactly two entries. |
| self.assertEqual(len(sel._cache), 2) |
| |
| |
| # --------------------------------------------------------------------------- |
| # Layer 2 — Integration: real tables, public API, assert correctness AND |
| # that pruning actually fired (otherwise we're not testing the optimisation, |
| # only that we didn't break full-scan). |
| # --------------------------------------------------------------------------- |
| class BucketPruningIntegrationTest(unittest.TestCase): |
| |
| NUM_BUCKETS = 8 |
| |
| @classmethod |
| def setUpClass(cls): |
| cls.tempdir = tempfile.mkdtemp() |
| cls.warehouse = os.path.join(cls.tempdir, 'warehouse') |
| cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) |
| cls.catalog.create_database('default', False) |
| |
| @classmethod |
| def tearDownClass(cls): |
| shutil.rmtree(cls.tempdir, ignore_errors=True) |
| |
| def _create_pk_table(self, name: str, num_buckets: int = NUM_BUCKETS, |
| bucket_key: Optional[str] = None) -> Any: |
| opts = {'bucket': str(num_buckets), 'file.format': 'parquet'} |
| if bucket_key is not None: |
| opts['bucket-key'] = bucket_key |
| pa_schema = pa.schema([ |
| pa.field('id', pa.int64(), nullable=False), |
| ('val', pa.int64()), |
| ]) |
| schema = Schema.from_pyarrow_schema( |
| pa_schema, primary_keys=['id'], options=opts) |
| full = 'default.{}'.format(name) |
| self.catalog.create_table(full, schema, False) |
| return self.catalog.get_table(full) |
| |
| def _write(self, table, rows: List[Dict]): |
| pa_schema = pa.schema([ |
| pa.field('id', pa.int64(), nullable=False), |
| ('val', pa.int64()), |
| ]) |
| wb = table.new_batch_write_builder() |
| w = wb.new_write() |
| c = wb.new_commit() |
| try: |
| w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) |
| c.commit(w.prepare_commit()) |
| finally: |
| w.close() |
| c.close() |
| |
| def _read_with(self, table, predicate=None): |
| rb = table.new_read_builder() |
| if predicate is not None: |
| rb = rb.with_filter(predicate) |
| splits = rb.new_scan().plan().splits() |
| if not splits: |
| return [], splits |
| return rb.new_read().to_arrow(splits).to_pylist(), splits |
| |
| @staticmethod |
| def _split_buckets(splits) -> set: |
| """Collect the distinct bucket numbers actually returned in a plan.""" |
| return {s.bucket for s in splits} |
| |
| @staticmethod |
| def _expected_buckets(table, ids, value_field='val') -> set: |
| """Use the writer's RowKeyExtractor to compute the bucket(s) the |
| rows for ``ids`` were written into. Cross-check against the |
| reader's selector — divergence indicates read/write hash drift.""" |
| ext = FixedBucketRowKeyExtractor(table.table_schema) |
| pa_schema = pa.schema([ |
| pa.field('id', pa.int64(), nullable=False), |
| (value_field, pa.int64()), |
| ]) |
| out = set() |
| for i in ids: |
| arr = pa.RecordBatch.from_pylist( |
| [{'id': i, value_field: 0}], schema=pa_schema) |
| out.update(ext._extract_buckets_batch(arr)) |
| return out |
| |
| # -- Equal on PK ----------------------------------------------------- |
| def test_pk_equal_only_reads_target_bucket(self): |
| table = self._create_pk_table('int_eq') |
| rows = [{'id': i, 'val': i * 11} for i in range(100)] |
| self._write(table, rows) |
| |
| target_id = 42 |
| pred = table.new_read_builder().new_predicate_builder().equal( |
| 'id', target_id) |
| got, splits = self._read_with(table, pred) |
| |
| # Correctness: row for id=42 returned (and only that). |
| self.assertEqual(got, [{'id': 42, 'val': 42 * 11}]) |
| |
| # Pruning effectiveness AND hash correctness: the touched bucket |
| # must equal the bucket the writer placed id=42 into. Asserting |
| # only ``len == 1`` would mask a hash drift that picks the wrong |
| # single bucket. |
| self.assertEqual(self._split_buckets(splits), |
| self._expected_buckets(table, [target_id]), |
| "PK equal must touch exactly the writer's bucket") |
| |
| def test_pk_in_reads_only_target_buckets(self): |
| table = self._create_pk_table('int_in') |
| rows = [{'id': i, 'val': i * 7} for i in range(200)] |
| self._write(table, rows) |
| |
| ids = [3, 17, 99, 150] |
| pred = table.new_read_builder().new_predicate_builder().is_in( |
| 'id', ids) |
| got, splits = self._read_with(table, pred) |
| got_sorted = sorted(got, key=lambda r: r['id']) |
| self.assertEqual(got_sorted, |
| [{'id': i, 'val': i * 7} for i in sorted(ids)]) |
| |
| actual = self._split_buckets(splits) |
| expected_buckets = self._expected_buckets(table, ids) |
| # Equality (not subset): under the single-commit setup every |
| # target bucket actually has a file, so the planner must produce |
| # exactly the writer's bucket set. ``issubset`` would mask a |
| # selector that's overly aggressive on a subset of the IN list. |
| self.assertEqual(actual, expected_buckets, |
| "got {}, expected {}".format(actual, expected_buckets)) |
| |
| # -- Predicates that should NOT prune ------------------------------- |
| def test_value_only_predicate_falls_back_to_full_scan(self): |
| """``val < X`` doesn't constrain the PK → selector must be None |
| and no bucket pruning may fire. Both checked: result correctness |
| AND the explicit "selector is None" property.""" |
| table = self._create_pk_table('val_only') |
| rows = [{'id': i, 'val': i} for i in range(100)] |
| self._write(table, rows) |
| |
| pred = table.new_read_builder().new_predicate_builder().less_than( |
| 'val', 30) |
| got, splits = self._read_with(table, pred) |
| self.assertEqual(sorted([r['id'] for r in got]), list(range(30))) |
| |
| # Inspect the scanner's bucket selector to prove pruning DIDN'T |
| # fire — without this assertion the test would also pass under a |
| # buggy selector that prunes wrongly but happens to keep the |
| # rows we picked. |
| rb = table.new_read_builder().with_filter(pred) |
| scan = rb.new_scan() |
| self.assertIsNone(scan.file_scanner._bucket_selector, |
| "value-only predicate must NOT produce a selector") |
| |
| def test_range_on_pk_falls_back_to_full_scan(self): |
| """``id > X`` is a range, not Equal/In, so cannot derive a bucket |
| set. Selector returns None — result must still be exact.""" |
| table = self._create_pk_table('pk_range') |
| rows = [{'id': i, 'val': i} for i in range(50)] |
| self._write(table, rows) |
| |
| pred = table.new_read_builder().new_predicate_builder().greater_or_equal( |
| 'id', 40) |
| got, _ = self._read_with(table, pred) |
| self.assertEqual(sorted([r['id'] for r in got]), list(range(40, 50))) |
| |
| # -- Mixed predicate: Equal on PK AND range on val ------------------ |
| def test_pk_equal_with_unrelated_value_predicate_still_prunes(self): |
| table = self._create_pk_table('int_eq_with_val') |
| rows = [{'id': i, 'val': i} for i in range(50)] |
| self._write(table, rows) |
| |
| pb = table.new_read_builder().new_predicate_builder() |
| pred = pb.and_predicates([ |
| pb.equal('id', 25), |
| pb.greater_than('val', 20), |
| ]) |
| got, splits = self._read_with(table, pred) |
| self.assertEqual(got, [{'id': 25, 'val': 25}]) |
| self.assertEqual(self._split_buckets(splits), |
| self._expected_buckets(table, [25]), |
| "Equal on PK still narrows to the writer's bucket " |
| "even when AND'd with a non-bucket-key predicate") |
| |
| def test_early_filter_skips_full_entry_decode_for_pruned_buckets(self): |
| """Entries the bucket selector rejects must never reach |
| ``GenericRowDeserializer.from_bytes`` for their partition / key |
| stats. Without the early filter the count would scale with the |
| manifest entry count; with it, only the surviving entries pay |
| the deserialisation cost.""" |
| from unittest import mock |
| |
| from pypaimon.table.row import generic_row |
| |
| table = self._create_pk_table('early_filter') |
| # 8 separate single-row commits → 8 manifest entries each touching |
| # a different bucket. ``pk = X`` should reach exactly one of them. |
| for i in range(self.NUM_BUCKETS): |
| self._write(table, [{'id': i, 'val': i * 11}]) |
| |
| pred = table.new_read_builder().new_predicate_builder().equal('id', 0) |
| rb = table.new_read_builder().with_filter(pred) |
| |
| real_from_bytes = generic_row.GenericRowDeserializer.from_bytes |
| calls = {'n': 0} |
| |
| def counting(*args, **kwargs): |
| calls['n'] += 1 |
| return real_from_bytes(*args, **kwargs) |
| |
| with mock.patch.object(generic_row.GenericRowDeserializer, |
| 'from_bytes', |
| side_effect=counting): |
| splits = rb.new_scan().plan().splits() |
| got = rb.new_read().to_arrow(splits).to_pylist() if splits else [] |
| |
| self.assertEqual(got, [{'id': 0, 'val': 0}]) |
| # Each surviving entry decodes partition + min_key + max_key |
| # (3 ``from_bytes`` calls). Allow slack in case the planner |
| # touches extras, but stay strictly below the 8 entries × 3 = 24 |
| # calls a non-filtering planner would pay. |
| self.assertLess( |
| calls['n'], 3 * self.NUM_BUCKETS, |
| "early filter should skip from_bytes for pruned entries; " |
| "got {} calls (would be {}+ without the filter)".format( |
| calls['n'], 3 * self.NUM_BUCKETS)) |
| |
| def test_init_bucket_selector_fails_open_when_bucket_keys_raises(self): |
| """``TableSchema.bucket_keys`` raises if ``bucket-key`` references |
| an unknown column. The pre-Java-alignment selector path used to |
| catch ``Exception`` from instantiating ``FixedBucketRowKeyExtractor`` |
| and silently skip pruning; that property must survive the move |
| of bucket-key resolution onto ``TableSchema``. Crashing the scan |
| on a misconfiguration would be worse than skipping the |
| optimisation.""" |
| table = self._create_pk_table('init_fails_open') |
| self._write(table, [{'id': 1, 'val': 1}]) |
| # Mutate the in-memory schema options to a broken value to |
| # simulate a corrupted/migrated catalog without rewriting it. |
| table.table_schema.options['bucket-key'] = 'nope_no_such_column' |
| |
| rb = table.new_read_builder().with_filter( |
| table.new_read_builder().new_predicate_builder().equal('id', 1)) |
| scanner = rb.new_scan().file_scanner |
| # Must NOT raise: the broken option falls back to "no pruning", |
| # and the scan still finds the row. |
| self.assertIsNone(scanner._init_bucket_selector()) |
| got, _ = self._read_with(table, scanner.predicate) |
| self.assertEqual(got, [{'id': 1, 'val': 1}]) |
| |
| # -- Explicit bucket-key option ------------------------------------ |
| def test_bucket_key_option_overrides_pk_for_pruning(self): |
| """When the ``bucket-key`` option is set explicitly, the bucket |
| derivation must use it — not the trimmed primary keys. This is |
| the path that catches read/write hash divergence if a refactor |
| forgets the option.""" |
| # PK = id, bucket-key = id explicitly (single key but exercises |
| # the explicit-config branch in ``_init_bucket_selector``). |
| table = self._create_pk_table('explicit_bk', bucket_key='id') |
| rows = [{'id': i, 'val': i * 3} for i in range(40)] |
| self._write(table, rows) |
| |
| pred = table.new_read_builder().new_predicate_builder().equal('id', 17) |
| got, splits = self._read_with(table, pred) |
| self.assertEqual(got, [{'id': 17, 'val': 51}]) |
| self.assertEqual(self._split_buckets(splits), |
| self._expected_buckets(table, [17])) |
| |
| def test_per_partition_pruning_with_mixed_or(self): |
| """``(part='a' AND id=1) OR (part='b' AND id=2)``: each partition |
| sees only the bucket for its own ``id`` literal. Without |
| per-partition predicate specialisation this query falls through |
| to "all buckets in both partitions".""" |
| opts = {'bucket': '4', 'file.format': 'parquet'} |
| pa_schema = pa.schema([ |
| pa.field('part', pa.string(), nullable=False), |
| pa.field('id', pa.int64(), nullable=False), |
| ('val', pa.int64()), |
| ]) |
| schema = Schema.from_pyarrow_schema( |
| pa_schema, primary_keys=['part', 'id'], |
| partition_keys=['part'], options=opts) |
| identifier = 'default.per_part_mixed_or' |
| self.catalog.create_table(identifier, schema, False) |
| table = self.catalog.get_table(identifier) |
| # Two partitions × three id values each → up to 6 (part, bucket) |
| # combinations after the writer hashes. |
| rows = [] |
| for p in ('a', 'b'): |
| for i in (1, 2, 3): |
| rows.append({'part': p, 'id': i, 'val': i * 7}) |
| wb = table.new_batch_write_builder() |
| w = wb.new_write() |
| c = wb.new_commit() |
| try: |
| w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) |
| c.commit(w.prepare_commit()) |
| finally: |
| w.close() |
| c.close() |
| |
| pb = table.new_read_builder().new_predicate_builder() |
| mixed = PredicateBuilder.or_predicates([ |
| PredicateBuilder.and_predicates([ |
| pb.equal('part', 'a'), |
| pb.equal('id', 1), |
| ]), |
| PredicateBuilder.and_predicates([ |
| pb.equal('part', 'b'), |
| pb.equal('id', 2), |
| ]), |
| ]) |
| got, splits = self._read_with(table, mixed) |
| # Correctness: only the two matching rows. |
| got_sorted = sorted(got, key=lambda r: (r['part'], r['id'])) |
| self.assertEqual( |
| got_sorted, |
| [{'part': 'a', 'id': 1, 'val': 7}, |
| {'part': 'b', 'id': 2, 'val': 14}]) |
| # Pruning effectiveness: across both partitions we should see at |
| # most two distinct (partition, bucket) splits — one per branch. |
| # Without per-partition pruning we'd see every (partition, bucket) |
| # combo that exists on disk for the predicate's id literals. |
| self.assertLessEqual(len(splits), 2, |
| "per-partition pruning should keep ≤ 2 splits, " |
| "got %d" % len(splits)) |
| |
| |
| # --------------------------------------------------------------------------- |
| # Layer 3 — Property: random PK tables, random Equal/In predicates, |
| # correctness vs oracle. |
| # --------------------------------------------------------------------------- |
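| # Oracle pattern: each trial recomputes the expected rows in plain |
| # Python and the expected buckets via the writer's extractor, then |
| # requires the planner's splits to match both exactly. |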
| class BucketPruningPropertyTest(unittest.TestCase): |
| |
| SEED = 0xB0CC |
| TRIALS = 30 |
| |
| @classmethod |
| def setUpClass(cls): |
| cls.tempdir = tempfile.mkdtemp() |
| cls.warehouse = os.path.join(cls.tempdir, 'warehouse') |
| cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse}) |
| cls.catalog.create_database('default', False) |
| cls.rnd = random.Random(cls.SEED) |
| |
| @classmethod |
| def tearDownClass(cls): |
| shutil.rmtree(cls.tempdir, ignore_errors=True) |
| |
| def _make_table(self, idx: int, num_buckets: int): |
| pa_schema = pa.schema([ |
| pa.field('k', pa.int64(), nullable=False), |
| ('v', pa.int64()), |
| ]) |
| schema = Schema.from_pyarrow_schema( |
| pa_schema, |
| primary_keys=['k'], |
| options={'bucket': str(num_buckets), 'file.format': 'parquet'}, |
| ) |
| name = 'default.bp_{}'.format(idx) |
| self.catalog.create_table(name, schema, False) |
| return self.catalog.get_table(name) |
| |
| def _write(self, table, rows): |
| pa_schema = pa.schema([ |
| pa.field('k', pa.int64(), nullable=False), |
| ('v', pa.int64()), |
| ]) |
| wb = table.new_batch_write_builder() |
| w = wb.new_write() |
| c = wb.new_commit() |
| try: |
| w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) |
| c.commit(w.prepare_commit()) |
| finally: |
| w.close() |
| c.close() |
| |
| @staticmethod |
| def _expected_buckets(table, keys) -> set: |
| """Independent oracle: writer's bucket placement for the given keys.""" |
| ext = FixedBucketRowKeyExtractor(table.table_schema) |
| pa_schema = pa.schema([ |
| pa.field('k', pa.int64(), nullable=False), |
| ('v', pa.int64()), |
| ]) |
| out = set() |
| for k in keys: |
| arr = pa.RecordBatch.from_pylist( |
| [{'k': k, 'v': 0}], schema=pa_schema) |
| out.update(ext._extract_buckets_batch(arr)) |
| return out |
| |
| def test_property_pk_equal_correctness(self): |
| for trial in range(self.TRIALS): |
| num_buckets = self.rnd.choice([2, 4, 8, 16]) |
| table = self._make_table(trial, num_buckets) |
| keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100)) |
| rows = [{'k': k, 'v': k * 13} for k in keys] |
| self._write(table, rows) |
| |
| target = self.rnd.choice(keys) |
| pb = table.new_read_builder().new_predicate_builder() |
| pred = pb.equal('k', target) |
| rb = table.new_read_builder().with_filter(pred) |
| splits = rb.new_scan().plan().splits() |
| if splits: |
| got = rb.new_read().to_arrow(splits).to_pylist() |
| else: |
| got = [] |
| self.assertEqual(got, [{'k': target, 'v': target * 13}], |
| "trial {} buckets={} target={}: result mismatch" |
| .format(trial, num_buckets, target)) |
| # Pruning fired AND picked the writer's bucket. Without this |
| # cross-check a fail-open selector (i.e. no pruning) would |
| # still pass the result-equality assertion above. |
| self.assertEqual(self._split_buckets(splits), |
| self._expected_buckets(table, [target]), |
| "trial {}: bucket set != writer's placement" |
| .format(trial)) |
| |
| def test_property_pk_in_correctness(self): |
| for trial in range(self.TRIALS): |
| num_buckets = self.rnd.choice([2, 4, 8, 16]) |
| offset = self.TRIALS + trial # avoid name clash with prev test |
| table = self._make_table(offset, num_buckets) |
| keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100)) |
| rows = [{'k': k, 'v': k * 13} for k in keys] |
| self._write(table, rows) |
| |
| target_n = self.rnd.randint(1, min(10, len(keys))) |
| targets = self.rnd.sample(keys, target_n) |
| pb = table.new_read_builder().new_predicate_builder() |
| pred = pb.is_in('k', targets) |
| rb = table.new_read_builder().with_filter(pred) |
| splits = rb.new_scan().plan().splits() |
| if splits: |
| got = rb.new_read().to_arrow(splits).to_pylist() |
| else: |
| got = [] |
| got_sorted = sorted(got, key=lambda r: r['k']) |
| want = sorted( |
| [{'k': k, 'v': k * 13} for k in targets], |
| key=lambda r: r['k']) |
| self.assertEqual(got_sorted, want, |
| "trial {}: IN result mismatch".format(trial)) |
| self.assertEqual(self._split_buckets(splits), |
| self._expected_buckets(table, targets), |
| "trial {}: IN bucket set != writer's placement" |
| .format(trial)) |
| |
| @staticmethod |
| def _split_buckets(splits) -> set: |
| return {s.bucket for s in splits} |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |