# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Three-layer correctness tests for predicate-driven bucket pruning.
Mirrors Java's ``BucketSelectConverter`` contract: PK Equal/In queries on
HASH_FIXED tables must touch only the bucket(s) the writer would have
placed those keys in. Two correctness obligations:
1. Sound: the retained bucket set is at worst a superset of the
   buckets that actually hold matching rows. Buckets that DO contain
   matching rows are NEVER dropped — false-negative-free.
2. Hash-consistent with writers: ``RowKeyExtractor`` (writer) and
``BucketSelectConverter`` (reader) must agree on every literal.
This is what makes ``pk = 'X'`` read the bucket holding 'X'.
Layered:
* Unit — direct calls to ``create_bucket_selector`` with crafted
predicates, asserting selector behaviour.
* Integration — real PK tables with multiple buckets; queries; assert
(a) result correctness, (b) bucket pruning happened.
* Property — randomly-seeded PK tables, random Equal/In predicates,
result == oracle. No hypothesis dependency (keeps
Python 3.6 compat).
"""
import os
import random
import shutil
import tempfile
import unittest
from typing import Any, Dict, List, Optional
import pyarrow as pa
from pypaimon import CatalogFactory, Schema
from pypaimon.common.predicate_builder import PredicateBuilder
from pypaimon.read.scanner.bucket_select_converter import (
MAX_VALUES, create_bucket_selector)
from pypaimon.schema.data_types import AtomicType, DataField
from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer
from pypaimon.table.row.internal_row import RowKind
from pypaimon.write.row_key_extractor import (FixedBucketRowKeyExtractor,
                                              _bucket_from_hash,
                                              _hash_bytes_by_words)
def _bigint_field(idx: int, name: str) -> DataField:
return DataField(idx, name, AtomicType('BIGINT', nullable=False))
def _field(idx: int, name: str, type_name: str) -> DataField:
return DataField(idx, name, AtomicType(type_name, nullable=False))
def _hash_bucket(values: List[Any], fields: List[DataField], total: int) -> int:
"""Re-implement the writer's hash so unit tests can compute the
expected bucket without spinning up a real table."""
row = GenericRow(values, fields, RowKind.INSERT)
serialized = GenericRowSerializer.to_bytes(row)
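    # Hash everything after the serializer's 4-byte prefix, exactly as the
    # writer does; the prefix (presumably a header, not key payload) must
    # not influence bucket placement.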
h = _hash_bytes_by_words(serialized[4:])
return _bucket_from_hash(h, total)
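# Illustrative sketch (not a test): the selector contract the layers below
# exercise, using only names already imported in this module.
# ``create_bucket_selector`` returns ``None`` when it cannot prune (callers
# must fall back to a full scan) or a callable answering "may this bucket
# contain matching rows?". Partition-aware selectors (built with
# ``partition_fields``) additionally accept the 3-arg
# ``sel(partition_row, bucket, total_buckets)`` form shown in the tests.
def _example_selector_usage(total: int = 8) -> List[int]:
    fields = [_bigint_field(0, 'id'), _bigint_field(1, 'val')]
    pb = PredicateBuilder(fields)
    sel = create_bucket_selector(pb.equal('id', 42), [fields[0]])
    if sel is None:
        # No pruning possible for this predicate: scan every bucket.
        return list(range(total))
    # 2-arg form: partition-agnostic answer per (bucket, total_buckets).
    return [b for b in range(total) if sel(b, total)]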
# ---------------------------------------------------------------------------
# Layer 1 — Unit: drive ``create_bucket_selector`` with crafted predicates.
# ---------------------------------------------------------------------------
class BucketSelectConverterUnitTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.id_field = _bigint_field(0, 'id')
cls.val_field = _bigint_field(1, 'val')
cls.k1 = _bigint_field(0, 'k1')
cls.k2 = _bigint_field(1, 'k2')
cls.pb_id_val = PredicateBuilder([cls.id_field, cls.val_field])
cls.pb_k1_k2 = PredicateBuilder([cls.k1, cls.k2])
# -- Equal / In on single bucket key ---------------------------------
def test_equal_on_single_bucket_key_yields_single_bucket(self):
sel = create_bucket_selector(
self.pb_id_val.equal('id', 42), [self.id_field])
self.assertIsNotNone(sel, "PK Equal must produce a selector")
expected = _hash_bucket([42], [self.id_field], total=8)
for b in range(8):
self.assertEqual(
sel(b, 8), b == expected,
"only bucket {} must be kept (got {})".format(expected, b))
def test_in_on_single_bucket_key_unions_buckets(self):
sel = create_bucket_selector(
self.pb_id_val.is_in('id', [1, 2, 3, 100]), [self.id_field])
expected = {_hash_bucket([v], [self.id_field], 8)
for v in (1, 2, 3, 100)}
for b in range(8):
self.assertEqual(sel(b, 8), b in expected)
def test_or_of_equals_on_same_field_unions_buckets(self):
# ``id = 1 OR id = 2`` must equal ``id IN (1, 2)``.
pred = PredicateBuilder.or_predicates([
self.pb_id_val.equal('id', 1),
self.pb_id_val.equal('id', 2),
])
sel = create_bucket_selector(pred, [self.id_field])
expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)}
for b in range(8):
self.assertEqual(sel(b, 8), b in expected)
# -- Composite bucket keys ------------------------------------------
def test_composite_bucket_key_intersects_via_cartesian(self):
pred = PredicateBuilder.and_predicates([
self.pb_k1_k2.is_in('k1', [1, 2]),
self.pb_k1_k2.equal('k2', 99),
])
sel = create_bucket_selector(pred, [self.k1, self.k2])
expected = {
_hash_bucket([k1, 99], [self.k1, self.k2], 4)
for k1 in (1, 2)
}
for b in range(4):
self.assertEqual(sel(b, 4), b in expected)
def test_composite_bucket_key_missing_one_field_returns_none(self):
pred = self.pb_k1_k2.equal('k1', 1) # k2 unconstrained
sel = create_bucket_selector(pred, [self.k1, self.k2])
self.assertIsNone(sel,
"all bucket keys must be constrained or fall back")
# -- Predicates that can't be reduced -------------------------------
def test_non_bucket_key_predicate_returns_none(self):
sel = create_bucket_selector(
self.pb_id_val.equal('val', 5), [self.id_field])
self.assertIsNone(sel, "predicate not on bucket key -> no selector")
def test_range_predicate_on_bucket_key_returns_none(self):
sel = create_bucket_selector(
self.pb_id_val.greater_than('id', 100), [self.id_field])
self.assertIsNone(sel, "ranges can't be turned into a finite bucket set")
def test_or_with_non_bucket_key_returns_none(self):
# ``id = 1 OR val = 5`` — ``val`` isn't a bucket key, so the OR
# is not a pure bucket-key constraint.
pred = PredicateBuilder.or_predicates([
self.pb_id_val.equal('id', 1),
self.pb_id_val.equal('val', 5),
])
sel = create_bucket_selector(pred, [self.id_field])
self.assertIsNone(sel)
def test_repeated_equal_on_same_key_with_empty_intersection_returns_none(self):
# ``id = 1 AND id = 2``: literal sets {1} and {2} intersect to
# empty; Java's ``retainAll`` would also bail here, since the
# predicate is unsatisfiable.
pred = PredicateBuilder.and_predicates([
self.pb_id_val.equal('id', 1),
self.pb_id_val.equal('id', 2),
])
sel = create_bucket_selector(pred, [self.id_field])
self.assertIsNone(sel)
def test_repeated_in_on_same_key_intersects_literals(self):
# ``id IN (1,2,3) AND id IN (2,3,4)`` should now keep the
# intersection {2, 3} and prune to those buckets only. Used to
# bail with no selector before the Java parity fix.
pred = PredicateBuilder.and_predicates([
self.pb_id_val.is_in('id', [1, 2, 3]),
self.pb_id_val.is_in('id', [2, 3, 4]),
])
sel = create_bucket_selector(pred, [self.id_field])
self.assertIsNotNone(sel)
expected = {_hash_bucket([v], [self.id_field], 8) for v in (2, 3)}
for b in range(8):
self.assertEqual(sel(b, 8), b in expected)
def test_and_with_unrelated_clause_is_unaffected(self):
# ``id = 7 AND val > 100`` — the ``val > 100`` part doesn't
# constrain buckets, but mustn't disqualify the ``id = 7`` part.
pred = PredicateBuilder.and_predicates([
self.pb_id_val.equal('id', 7),
self.pb_id_val.greater_than('val', 100),
])
sel = create_bucket_selector(pred, [self.id_field])
self.assertIsNotNone(sel)
expected = _hash_bucket([7], [self.id_field], 4)
for b in range(4):
self.assertEqual(sel(b, 4), b == expected)
# -- Cap & degenerate edge cases ------------------------------------
def test_cartesian_above_max_values_returns_none(self):
# Two columns of size > sqrt(MAX_VALUES) → product > MAX_VALUES.
size = 33 # 33 * 33 = 1089 > 1000
pred = PredicateBuilder.and_predicates([
self.pb_k1_k2.is_in('k1', list(range(size))),
self.pb_k1_k2.is_in('k2', list(range(size))),
])
self.assertGreater(size * size, MAX_VALUES)
sel = create_bucket_selector(pred, [self.k1, self.k2])
self.assertIsNone(sel)
def test_null_only_literal_drops_everything(self):
# ``id IN (NULL)`` after null-stripping has zero literals; the
# cartesian product is empty → selector matches no buckets. Same
# behaviour as Java.
pred = self.pb_id_val.is_in('id', [None])
sel = create_bucket_selector(pred, [self.id_field])
self.assertIsNotNone(sel)
for b in range(4):
self.assertFalse(sel(b, 4),
"all-null literal collapses bucket set to empty")
def test_no_predicate_returns_none(self):
self.assertIsNone(create_bucket_selector(None, [self.id_field]))
def test_no_bucket_keys_returns_none(self):
self.assertIsNone(
create_bucket_selector(self.pb_id_val.equal('id', 1), []))
# -- Selector cache + rescale -------------------------------------
def test_selector_caches_per_total_buckets(self):
"""Selector must answer correctly when the same query applies to
different ``total_buckets`` values (the rescale scenario)."""
sel = create_bucket_selector(
self.pb_id_val.equal('id', 42), [self.id_field])
for total in (4, 8, 16, 32):
expected = _hash_bucket([42], [self.id_field], total)
self.assertTrue(sel(expected, total))
other = (expected + 1) % total
self.assertFalse(sel(other, total))
def test_non_positive_total_buckets_fails_open(self):
"""Manifest entries can carry ``total_buckets <= 0`` for legacy /
special bucket modes. Pruning MUST fail open — returning False
would silently drop rows the writer placed in those entries.
This is correctness, not performance: the soundness contract
forbids false-negatives."""
sel = create_bucket_selector(
self.pb_id_val.equal('id', 1), [self.id_field])
for total in (0, -1, -2):
self.assertTrue(sel(0, total),
"total_buckets={} must be kept (fail open)".format(total))
self.assertTrue(sel(-1, total))
self.assertTrue(sel(99, total))
# -- Bucket-key column types beyond BIGINT --------------------------
def test_string_bucket_key_yields_correct_bucket(self):
"""STRING uses a different ``GenericRowSerializer`` path (utf-8
encode + variable-part offset) — verify writer/reader agree on
its byte layout independent of the BIGINT happy path."""
sf = _field(0, 'sk', 'STRING')
vf = _bigint_field(1, 'val')
pb = PredicateBuilder([sf, vf])
sel = create_bucket_selector(pb.equal('sk', 'hello'), [sf])
self.assertIsNotNone(sel)
expected = _hash_bucket(['hello'], [sf], total=8)
for b in range(8):
self.assertEqual(sel(b, 8), b == expected)
def test_int_bucket_key_yields_correct_bucket(self):
"""INT (32-bit) and BIGINT (64-bit) hit different struct.pack
paths in the serializer — guard the smaller width."""
intf = _field(0, 'i', 'INT')
vf = _bigint_field(1, 'val')
pb = PredicateBuilder([intf, vf])
sel = create_bucket_selector(pb.equal('i', 7), [intf])
self.assertIsNotNone(sel)
expected = _hash_bucket([7], [intf], total=4)
for b in range(4):
self.assertEqual(sel(b, 4), b == expected)
# -- Hash-divergence-prone types refuse to build a selector --------
def test_decimal_bucket_key_disables_pruning(self):
"""DECIMAL columns risk silent hash divergence between writer
(Decimal) and reader-supplied ``float`` literals. Soundness
contract demands fail-open: refuse to build a selector at all."""
df = _field(0, 'd', 'DECIMAL(10, 2)')
vf = _bigint_field(1, 'val')
pb = PredicateBuilder([df, vf])
from decimal import Decimal
sel = create_bucket_selector(pb.equal('d', Decimal('1.50')), [df])
self.assertIsNone(
sel, "DECIMAL bucket-key column must disable pruning")
def test_array_bucket_key_disables_pruning(self):
"""Composite types (ARRAY/MAP/ROW/MULTISET/VARIANT/BLOB) have no
cross-validated byte alignment with Java's ``BinaryRow`` — until
that exists, refuse to prune on them."""
# Hand-roll a DataField whose AtomicType reports an ARRAY type
# name; the converter inspects ``field.type.type`` only.
af = DataField(0, 'arr', AtomicType('ARRAY<BIGINT>'))
vf = _bigint_field(1, 'val')
pb = PredicateBuilder([af, vf])
sel = create_bucket_selector(pb.equal('arr', [1]), [af])
self.assertIsNone(
sel, "ARRAY bucket-key column must disable pruning")
def test_timestamp_bucket_key_disables_pruning(self):
"""TIMESTAMP columns serialise via ``value.timestamp()`` whose
result depends on the process timezone for naive datetimes —
writer and reader running in different TZs would disagree."""
tf = _field(0, 't', 'TIMESTAMP(3)')
vf = _bigint_field(1, 'val')
pb = PredicateBuilder([tf, vf])
from datetime import datetime
sel = create_bucket_selector(
pb.equal('t', datetime(2026, 1, 1)), [tf])
self.assertIsNone(
sel, "TIMESTAMP bucket-key column must disable pruning")
def test_type_mismatched_literal_fails_open_not_crash(self):
"""If the user constructs a predicate whose literal type doesn't
match the bucket-key column's atomic type — e.g. a STRING literal
on a BIGINT column — ``GenericRowSerializer`` raises during the
deferred hash inside ``_Selector``. The selector MUST swallow the
exception and fail open (return True for every bucket) rather
than propagate it. Crashing the entire scan with an opaque
``struct.error`` is a worse user experience than silently
skipping bucket pruning, and the soundness contract still
forbids false-negatives."""
sel = create_bucket_selector(
self.pb_id_val.equal('id', 'not-an-int'), [self.id_field])
# Construction itself succeeds (no eager hashing).
self.assertIsNotNone(sel)
# Calling the selector must NOT raise; instead it returns True
# for every (bucket, total_buckets), preserving soundness.
for total in (4, 8):
for b in range(total):
self.assertTrue(sel(b, total),
"type-mismatched literal must fail open, "
"not crash (bucket={}, total={})".format(b, total))
class PartitionAwareBucketSelectorUnitTest(unittest.TestCase):
"""Unit tests for the per-partition predicate specialisation path.
Covers ``replace_partition_predicate`` (the AND/OR fold walker) and
the partition-aware ``_Selector.__call__(partition, bucket,
total_buckets)`` 3-arg form that ``FileScanner._filter_manifest_entry``
will use after wiring."""
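    # Fold semantics exercised below, informally (a sketch of the contract,
    # not of the implementation):
    #   partition leaf agrees with the partition row     -> True
    #   partition leaf contradicts the partition row     -> False
    #   AND: drop True children; any False collapses the AND to False
    #   OR:  drop False children; a single survivor unwraps to itself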
@classmethod
def setUpClass(cls):
cls.id_field = _bigint_field(0, 'id')
cls.part_field = DataField(2, 'part', AtomicType('STRING'))
cls.pb = PredicateBuilder([cls.id_field, cls.part_field])
# ----- replace_partition_predicate --------------------------------
def test_replace_partition_leaf_to_true_drops_constraint(self):
from pypaimon.read.scanner.bucket_select_converter import \
replace_partition_predicate
# ``part = 'a' AND id = 1`` against partition {part: 'a'} →
# part leaf becomes True → AND fold removes it → only ``id = 1``
pred = PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
])
result = replace_partition_predicate(
pred, {'part'}, {'part': 'a'})
self.assertTrue(isinstance(result, type(pred)),
"AND should fold to a remaining single leaf")
self.assertEqual(result.method, 'equal')
self.assertEqual(result.field, 'id')
def test_replace_partition_leaf_to_false_collapses_and(self):
from pypaimon.read.scanner.bucket_select_converter import \
replace_partition_predicate
# ``part = 'a' AND id = 1`` against partition {part: 'b'} →
# part leaf becomes False → AND collapses to AlwaysFalse (False).
pred = PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
])
result = replace_partition_predicate(
pred, {'part'}, {'part': 'b'})
self.assertIs(result, False)
def test_replace_partition_leaf_in_or_keeps_other_branch(self):
from pypaimon.read.scanner.bucket_select_converter import \
replace_partition_predicate
# ``(part='a' AND id=1) OR (part='b' AND id=2)`` against
# partition {part: 'a'} → first OR child becomes ``id=1``, second
# collapses to AlwaysFalse and is dropped. Result is just ``id=1``.
pred = PredicateBuilder.or_predicates([
PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
]),
PredicateBuilder.and_predicates([
self.pb.equal('part', 'b'),
self.pb.equal('id', 2),
]),
])
result = replace_partition_predicate(
pred, {'part'}, {'part': 'a'})
# OR with a single surviving child unwraps to that child.
self.assertEqual(result.method, 'equal')
self.assertEqual(result.field, 'id')
self.assertEqual(result.literals, [1])
def test_replace_partition_leaf_in_or_other_partition(self):
from pypaimon.read.scanner.bucket_select_converter import \
replace_partition_predicate
# Same predicate, partition {part: 'b'} → second branch survives
# as ``id=2``.
pred = PredicateBuilder.or_predicates([
PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
]),
PredicateBuilder.and_predicates([
self.pb.equal('part', 'b'),
self.pb.equal('id', 2),
]),
])
result = replace_partition_predicate(
pred, {'part'}, {'part': 'b'})
self.assertEqual(result.method, 'equal')
self.assertEqual(result.field, 'id')
self.assertEqual(result.literals, [2])
def test_replace_partition_leaf_unrelated_predicate_unchanged(self):
from pypaimon.read.scanner.bucket_select_converter import \
replace_partition_predicate
# No partition leaf → predicate returned as-is.
pred = self.pb.equal('id', 42)
result = replace_partition_predicate(
pred, {'part'}, {'part': 'a'})
self.assertIs(result, pred)
# ----- _Selector partition-aware path -----------------------------
def test_selector_3arg_specialises_per_partition(self):
# ``(part='a' AND id=1) OR (part='b' AND id=2)`` should hit
# bucket(1) only when partition='a' and bucket(2) only when
# partition='b'. Master without this fix returns "all buckets".
pred = PredicateBuilder.or_predicates([
PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
]),
PredicateBuilder.and_predicates([
self.pb.equal('part', 'b'),
self.pb.equal('id', 2),
]),
])
sel = create_bucket_selector(
pred, [self.id_field], partition_fields=[self.part_field])
self.assertIsNotNone(sel)
bucket_for_1 = _hash_bucket([1], [self.id_field], total=8)
bucket_for_2 = _hash_bucket([2], [self.id_field], total=8)
part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT)
for b in range(8):
self.assertEqual(sel(part_a, b, 8), b == bucket_for_1,
"partition a only keeps bucket %d" % bucket_for_1)
self.assertEqual(sel(part_b, b, 8), b == bucket_for_2,
"partition b only keeps bucket %d" % bucket_for_2)
def test_selector_falls_through_when_partition_unknown(self):
"""Early manifest filter passes ``partition=None`` (or uses the
2-arg form) — no specialisation runs, bucket set falls back to a
sound over-approximation: all buckets accept."""
pred = PredicateBuilder.or_predicates([
PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
]),
PredicateBuilder.and_predicates([
self.pb.equal('part', 'b'),
self.pb.equal('id', 2),
]),
])
sel = create_bucket_selector(
pred, [self.id_field], partition_fields=[self.part_field])
self.assertIsNotNone(sel)
# 2-arg form (legacy callsite) — partition unknown, all buckets keep.
for b in range(8):
self.assertTrue(sel(b, 8),
"partition-unknown call must accept all buckets")
# 3-arg form with partition=None has the same semantics.
for b in range(8):
self.assertTrue(sel(None, b, 8))
def test_selector_partition_not_matching_returns_empty_bucket_set(self):
# ``part = 'a' AND id = 1`` on partition {part: 'c'} simplifies to
# AlwaysFalse — the selector returns False for every bucket since
# no row in this partition can possibly match. Sound: dropping a
# partition that *can't* contain matches doesn't lose data.
pred = PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
])
sel = create_bucket_selector(
pred, [self.id_field], partition_fields=[self.part_field])
self.assertIsNotNone(sel)
part_c = GenericRow(['c'], [self.part_field], RowKind.INSERT)
for b in range(8):
self.assertFalse(sel(part_c, b, 8),
"partition c can't satisfy part='a', "
"drop every bucket (b=%d)" % b)
def test_selector_partition_only_constraint_drops_partition(self):
# ``part='a' AND id IN (1,2)`` — same partition value 'a'
# specialises ``part='a'`` to True, leaving ``id IN (1,2)``.
pred = PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.is_in('id', [1, 2]),
])
sel = create_bucket_selector(
pred, [self.id_field], partition_fields=[self.part_field])
self.assertIsNotNone(sel)
part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)}
for b in range(8):
self.assertEqual(sel(part_a, b, 8), b in expected)
def test_selector_caches_per_partition(self):
pred = PredicateBuilder.or_predicates([
PredicateBuilder.and_predicates([
self.pb.equal('part', 'a'),
self.pb.equal('id', 1),
]),
PredicateBuilder.and_predicates([
self.pb.equal('part', 'b'),
self.pb.equal('id', 2),
]),
])
sel = create_bucket_selector(
pred, [self.id_field], partition_fields=[self.part_field])
part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT)
part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT)
# Drive the cache.
for _ in range(5):
sel(part_a, 0, 8)
sel(part_b, 0, 8)
# Cache keyed by (partition tuple, total_buckets); two distinct
# partitions × one total → exactly two entries.
self.assertEqual(len(sel._cache), 2)
# ---------------------------------------------------------------------------
# Layer 2 — Integration: real tables, public API, assert correctness AND
# that pruning actually fired (otherwise we're not testing the optimisation,
# only that we didn't break full-scan).
# ---------------------------------------------------------------------------
class BucketPruningIntegrationTest(unittest.TestCase):
NUM_BUCKETS = 8
@classmethod
def setUpClass(cls):
cls.tempdir = tempfile.mkdtemp()
cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
cls.catalog.create_database('default', False)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tempdir, ignore_errors=True)
    def _create_pk_table(self, name: str, num_buckets: int = NUM_BUCKETS,
                         bucket_key: Optional[str] = None) -> Any:
opts = {'bucket': str(num_buckets), 'file.format': 'parquet'}
if bucket_key is not None:
opts['bucket-key'] = bucket_key
pa_schema = pa.schema([
pa.field('id', pa.int64(), nullable=False),
('val', pa.int64()),
])
schema = Schema.from_pyarrow_schema(
pa_schema, primary_keys=['id'], options=opts)
full = 'default.{}'.format(name)
self.catalog.create_table(full, schema, False)
return self.catalog.get_table(full)
def _write(self, table, rows: List[Dict]):
pa_schema = pa.schema([
pa.field('id', pa.int64(), nullable=False),
('val', pa.int64()),
])
wb = table.new_batch_write_builder()
w = wb.new_write()
c = wb.new_commit()
try:
w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema))
c.commit(w.prepare_commit())
finally:
w.close()
c.close()
def _read_with(self, table, predicate=None):
rb = table.new_read_builder()
if predicate is not None:
rb = rb.with_filter(predicate)
splits = rb.new_scan().plan().splits()
if not splits:
return [], splits
return rb.new_read().to_arrow(splits).to_pylist(), splits
@staticmethod
def _split_buckets(splits) -> set:
"""Collect the distinct bucket numbers actually returned in a plan."""
return {s.bucket for s in splits}
@staticmethod
def _expected_buckets(table, ids, value_field='val') -> set:
"""Use the writer's RowKeyExtractor to compute the bucket(s) the
rows for ``ids`` were written into. Cross-check against the
reader's selector — divergence indicates read/write hash drift."""
ext = FixedBucketRowKeyExtractor(table.table_schema)
pa_schema = pa.schema([
pa.field('id', pa.int64(), nullable=False),
(value_field, pa.int64()),
])
out = set()
for i in ids:
arr = pa.RecordBatch.from_pylist(
[{'id': i, value_field: 0}], schema=pa_schema)
out.update(ext._extract_buckets_batch(arr))
return out
# -- Equal on PK -----------------------------------------------------
def test_pk_equal_only_reads_target_bucket(self):
table = self._create_pk_table('int_eq')
rows = [{'id': i, 'val': i * 11} for i in range(100)]
self._write(table, rows)
target_id = 42
pred = table.new_read_builder().new_predicate_builder().equal(
'id', target_id)
got, splits = self._read_with(table, pred)
# Correctness: row for id=42 returned (and only that).
self.assertEqual(got, [{'id': 42, 'val': 42 * 11}])
# Pruning effectiveness AND hash correctness: the touched bucket
# must equal the bucket the writer placed id=42 into. Asserting
# only ``len == 1`` would mask a hash drift that picks the wrong
# single bucket.
self.assertEqual(self._split_buckets(splits),
self._expected_buckets(table, [target_id]),
"PK equal must touch exactly the writer's bucket")
def test_pk_in_reads_only_target_buckets(self):
table = self._create_pk_table('int_in')
rows = [{'id': i, 'val': i * 7} for i in range(200)]
self._write(table, rows)
ids = [3, 17, 99, 150]
pred = table.new_read_builder().new_predicate_builder().is_in(
'id', ids)
got, splits = self._read_with(table, pred)
got_sorted = sorted(got, key=lambda r: r['id'])
self.assertEqual(got_sorted,
[{'id': i, 'val': i * 7} for i in sorted(ids)])
actual = self._split_buckets(splits)
expected_buckets = self._expected_buckets(table, ids)
# Equality (not subset): under the single-commit setup every
# target bucket actually has a file, so the planner must produce
# exactly the writer's bucket set. ``issubset`` would mask a
# selector that's overly aggressive on a subset of the IN list.
self.assertEqual(actual, expected_buckets,
"got {}, expected {}".format(actual, expected_buckets))
# -- Predicates that should NOT prune -------------------------------
def test_value_only_predicate_falls_back_to_full_scan(self):
"""``val < X`` doesn't constrain the PK → selector must be None
and no bucket pruning may fire. Both checked: result correctness
AND the explicit "selector is None" property."""
table = self._create_pk_table('val_only')
rows = [{'id': i, 'val': i} for i in range(100)]
self._write(table, rows)
pred = table.new_read_builder().new_predicate_builder().less_than(
'val', 30)
got, splits = self._read_with(table, pred)
self.assertEqual(sorted([r['id'] for r in got]), list(range(30)))
# Inspect the scanner's bucket selector to prove pruning DIDN'T
# fire — without this assertion the test would also pass under a
# buggy selector that prunes wrongly but happens to keep the
# rows we picked.
rb = table.new_read_builder().with_filter(pred)
scan = rb.new_scan()
self.assertIsNone(scan.file_scanner._bucket_selector,
"value-only predicate must NOT produce a selector")
def test_range_on_pk_falls_back_to_full_scan(self):
"""``id > X`` is a range, not Equal/In, so cannot derive a bucket
set. Selector returns None — result must still be exact."""
table = self._create_pk_table('pk_range')
rows = [{'id': i, 'val': i} for i in range(50)]
self._write(table, rows)
pred = table.new_read_builder().new_predicate_builder().greater_or_equal(
'id', 40)
got, _ = self._read_with(table, pred)
self.assertEqual(sorted([r['id'] for r in got]), list(range(40, 50)))
# -- Mixed predicate: Equal on PK AND range on val ------------------
def test_pk_equal_with_unrelated_value_predicate_still_prunes(self):
table = self._create_pk_table('int_eq_with_val')
rows = [{'id': i, 'val': i} for i in range(50)]
self._write(table, rows)
pb = table.new_read_builder().new_predicate_builder()
pred = pb.and_predicates([
pb.equal('id', 25),
pb.greater_than('val', 20),
])
got, splits = self._read_with(table, pred)
self.assertEqual(got, [{'id': 25, 'val': 25}])
self.assertEqual(self._split_buckets(splits),
self._expected_buckets(table, [25]),
"Equal on PK still narrows to the writer's bucket "
"even when AND'd with a non-bucket-key predicate")
def test_early_filter_skips_full_entry_decode_for_pruned_buckets(self):
"""Entries the bucket selector rejects must never reach
``GenericRowDeserializer.from_bytes`` for their partition / key
stats. Without the early filter the count would scale with the
manifest entry count; with it, only the surviving entries pay
the deserialisation cost."""
from unittest import mock
from pypaimon.table.row import generic_row
table = self._create_pk_table('early_filter')
        # 8 separate single-row commits → 8 manifest entries spread over the
        # buckets (hash collisions permitting, not necessarily all distinct).
        # ``id = 0`` should reach only the entries in id 0's bucket.
for i in range(self.NUM_BUCKETS):
self._write(table, [{'id': i, 'val': i * 11}])
pred = table.new_read_builder().new_predicate_builder().equal('id', 0)
rb = table.new_read_builder().with_filter(pred)
real_from_bytes = generic_row.GenericRowDeserializer.from_bytes
calls = {'n': 0}
def counting(*args, **kwargs):
calls['n'] += 1
return real_from_bytes(*args, **kwargs)
with mock.patch.object(generic_row.GenericRowDeserializer,
'from_bytes',
side_effect=counting):
splits = rb.new_scan().plan().splits()
got = rb.new_read().to_arrow(splits).to_pylist() if splits else []
self.assertEqual(got, [{'id': 0, 'val': 0}])
# Each surviving entry decodes partition + min_key + max_key
# (3 ``from_bytes`` calls). Allow a small slack in case the planner
# touches extras, but assert it is well below 8 entries × 3 = 24.
self.assertLess(
calls['n'], 3 * self.NUM_BUCKETS,
"early filter should skip from_bytes for pruned entries; "
"got {} calls (would be {}+ without the filter)".format(
calls['n'], 3 * self.NUM_BUCKETS))
def test_init_bucket_selector_fails_open_when_bucket_keys_raises(self):
"""``TableSchema.bucket_keys`` raises if ``bucket-key`` references
an unknown column. The pre-Java-alignment selector path used to
catch ``Exception`` from instantiating ``FixedBucketRowKeyExtractor``
and silently skip pruning; that property must survive the move
of bucket-key resolution onto ``TableSchema``. Crashing the scan
on a misconfiguration would be worse than skipping the
optimisation."""
table = self._create_pk_table('init_fails_open')
self._write(table, [{'id': 1, 'val': 1}])
# Mutate the in-memory schema options to a broken value to
# simulate a corrupted/migrated catalog without rewriting it.
table.table_schema.options['bucket-key'] = 'nope_no_such_column'
rb = table.new_read_builder().with_filter(
table.new_read_builder().new_predicate_builder().equal('id', 1))
scanner = rb.new_scan().file_scanner
# Must NOT raise: the broken option falls back to "no pruning",
# and the scan still finds the row.
self.assertIsNone(scanner._init_bucket_selector())
got, _ = self._read_with(table, scanner.predicate)
self.assertEqual(got, [{'id': 1, 'val': 1}])
# -- Explicit bucket-key option ------------------------------------
def test_bucket_key_option_overrides_pk_for_pruning(self):
"""When the ``bucket-key`` option is set explicitly, the bucket
derivation must use it — not the trimmed primary keys. This is
the path that catches read/write hash divergence if a refactor
forgets the option."""
# PK = id, bucket-key = id explicitly (single key but exercises
# the explicit-config branch in ``_init_bucket_selector``).
table = self._create_pk_table('explicit_bk', bucket_key='id')
rows = [{'id': i, 'val': i * 3} for i in range(40)]
self._write(table, rows)
pred = table.new_read_builder().new_predicate_builder().equal('id', 17)
got, splits = self._read_with(table, pred)
self.assertEqual(got, [{'id': 17, 'val': 51}])
self.assertEqual(self._split_buckets(splits),
self._expected_buckets(table, [17]))
def test_per_partition_pruning_with_mixed_or(self):
"""``(part='a' AND id=1) OR (part='b' AND id=2)``: each partition
sees only the bucket for its own ``id`` literal. Without
per-partition predicate specialisation this query falls through
to "all buckets in both partitions"."""
opts = {'bucket': '4', 'file.format': 'parquet'}
pa_schema = pa.schema([
pa.field('part', pa.string(), nullable=False),
pa.field('id', pa.int64(), nullable=False),
('val', pa.int64()),
])
schema = Schema.from_pyarrow_schema(
pa_schema, primary_keys=['part', 'id'],
partition_keys=['part'], options=opts)
identifier = 'default.per_part_mixed_or'
self.catalog.create_table(identifier, schema, False)
table = self.catalog.get_table(identifier)
# Two partitions × three id values each → up to 6 (part, bucket)
# combinations after the writer hashes.
rows = []
for p in ('a', 'b'):
for i in (1, 2, 3):
rows.append({'part': p, 'id': i, 'val': i * 7})
wb = table.new_batch_write_builder()
w = wb.new_write()
c = wb.new_commit()
try:
w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema))
c.commit(w.prepare_commit())
finally:
w.close()
c.close()
pb = table.new_read_builder().new_predicate_builder()
mixed = PredicateBuilder.or_predicates([
PredicateBuilder.and_predicates([
pb.equal('part', 'a'),
pb.equal('id', 1),
]),
PredicateBuilder.and_predicates([
pb.equal('part', 'b'),
pb.equal('id', 2),
]),
])
got, splits = self._read_with(table, mixed)
# Correctness: only the two matching rows.
got_sorted = sorted(got, key=lambda r: (r['part'], r['id']))
self.assertEqual(
got_sorted,
[{'part': 'a', 'id': 1, 'val': 7},
{'part': 'b', 'id': 2, 'val': 14}])
# Pruning effectiveness: across both partitions we should see at
# most two distinct (partition, bucket) splits — one per branch.
# Without per-partition pruning we'd see every (partition, bucket)
# combo that exists on disk for the predicate's id literals.
self.assertLessEqual(len(splits), 2,
"per-partition pruning should keep ≤ 2 splits, "
"got %d" % len(splits))
# ---------------------------------------------------------------------------
# Layer 3 — Property: random PK tables, random Equal/In predicates,
# correctness vs oracle.
# ---------------------------------------------------------------------------
class BucketPruningPropertyTest(unittest.TestCase):
SEED = 0xB0CC
TRIALS = 30
@classmethod
def setUpClass(cls):
cls.tempdir = tempfile.mkdtemp()
cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
cls.catalog = CatalogFactory.create({'warehouse': cls.warehouse})
cls.catalog.create_database('default', False)
cls.rnd = random.Random(cls.SEED)
@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.tempdir, ignore_errors=True)
def _make_table(self, idx: int, num_buckets: int):
pa_schema = pa.schema([
pa.field('k', pa.int64(), nullable=False),
('v', pa.int64()),
])
schema = Schema.from_pyarrow_schema(
pa_schema,
primary_keys=['k'],
options={'bucket': str(num_buckets), 'file.format': 'parquet'},
)
name = 'default.bp_{}'.format(idx)
self.catalog.create_table(name, schema, False)
return self.catalog.get_table(name)
def _write(self, table, rows):
pa_schema = pa.schema([
pa.field('k', pa.int64(), nullable=False),
('v', pa.int64()),
])
wb = table.new_batch_write_builder()
w = wb.new_write()
c = wb.new_commit()
try:
w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema))
c.commit(w.prepare_commit())
finally:
w.close()
c.close()
@staticmethod
def _expected_buckets(table, keys) -> set:
"""Independent oracle: writer's bucket placement for the given keys."""
ext = FixedBucketRowKeyExtractor(table.table_schema)
pa_schema = pa.schema([
pa.field('k', pa.int64(), nullable=False),
('v', pa.int64()),
])
out = set()
for k in keys:
arr = pa.RecordBatch.from_pylist(
[{'k': k, 'v': 0}], schema=pa_schema)
out.update(ext._extract_buckets_batch(arr))
return out
def test_property_pk_equal_correctness(self):
for trial in range(self.TRIALS):
num_buckets = self.rnd.choice([2, 4, 8, 16])
table = self._make_table(trial, num_buckets)
keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100))
rows = [{'k': k, 'v': k * 13} for k in keys]
self._write(table, rows)
target = self.rnd.choice(keys)
pb = table.new_read_builder().new_predicate_builder()
pred = pb.equal('k', target)
rb = table.new_read_builder().with_filter(pred)
splits = rb.new_scan().plan().splits()
if splits:
got = rb.new_read().to_arrow(splits).to_pylist()
else:
got = []
self.assertEqual(got, [{'k': target, 'v': target * 13}],
"trial {} buckets={} target={}: result mismatch"
.format(trial, num_buckets, target))
# Pruning fired AND picked the writer's bucket. Without this
# cross-check a fail-open selector (i.e. no pruning) would
# still pass the result-equality assertion above.
self.assertEqual(self._split_buckets(splits),
self._expected_buckets(table, [target]),
"trial {}: bucket set != writer's placement"
.format(trial))
def test_property_pk_in_correctness(self):
for trial in range(self.TRIALS):
num_buckets = self.rnd.choice([2, 4, 8, 16])
offset = self.TRIALS + trial # avoid name clash with prev test
table = self._make_table(offset, num_buckets)
keys = self.rnd.sample(range(1000), self.rnd.randint(20, 100))
rows = [{'k': k, 'v': k * 13} for k in keys]
self._write(table, rows)
target_n = self.rnd.randint(1, min(10, len(keys)))
targets = self.rnd.sample(keys, target_n)
pb = table.new_read_builder().new_predicate_builder()
pred = pb.is_in('k', targets)
rb = table.new_read_builder().with_filter(pred)
splits = rb.new_scan().plan().splits()
if splits:
got = rb.new_read().to_arrow(splits).to_pylist()
else:
got = []
got_sorted = sorted(got, key=lambda r: r['k'])
want = sorted(
[{'k': k, 'v': k * 13} for k in targets],
key=lambda r: r['k'])
self.assertEqual(got_sorted, want,
"trial {}: IN result mismatch".format(trial))
self.assertEqual(self._split_buckets(splits),
self._expected_buckets(table, targets),
"trial {}: IN bucket set != writer's placement"
.format(trial))
@staticmethod
def _split_buckets(splits) -> set:
return {s.bucket for s in splits}
if __name__ == '__main__':
unittest.main()