| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| |
| from typing import Any, Dict, List, Optional, Sequence, Tuple |
| |
| import pyarrow as pa |
| |
| from pypaimon.ray.data_evolution_merge_transform import ( |
| SourceColumnRef, |
| _NormalizedClause, |
| build_update_schema, |
| vectorized_insert_transform, |
| vectorized_matched_transform, |
| ) |
| |
| |
| def _map_kwargs( |
| ray_remote_args: Optional[Dict[str, Any]], |
| ) -> Dict[str, Any]: |
| """Build kwargs for map_batches/map_groups; spread ray_remote_args because |
| those APIs take remote options as **kwargs, not under a 'ray_remote_args' |
| key.""" |
| kwargs: Dict[str, Any] = {"batch_format": "pyarrow"} |
| if ray_remote_args: |
| kwargs.update(ray_remote_args) |
| return kwargs |
| |
| |
def _build_matched_transform(
    clauses: List[_NormalizedClause],
    on_map: Dict[str, str],
    on_pairs: List[Tuple[str, str]],
    update_cols: List[str],
    row_id_name: str,
    update_schema: pa.Schema,
):
    """Return a per-batch function that applies WHEN MATCHED clauses in order.

    Every row is handled by at most one clause: the first one whose condition
    evaluates to TRUE, or the first unconditional clause reached. Rows for
    which a condition is FALSE or NULL fall through to the next clause, via
    ``COALESCE(NOT (cond), TRUE)``. Rows matched by no clause are dropped.

    Conditions are rewritten once, here, on the driver. Each one goes through
    ``rewrite_condition`` and then ``remap_source_on_keys`` with ``on_map``.
    The latter presumably maps source ON-key references to their target
    counterparts; confirm in ``merge_condition``.

    The returned callable takes a ``pa.Table`` and returns the concatenated
    outputs of ``vectorized_matched_transform``. When nothing matched, it
    returns ``update_schema.empty_table()``.
    """

    def _prepare(clause):
        # Pair each clause's spec with its pre-rewritten condition (or None).
        if clause.condition is None:
            return clause.spec, None
        from pypaimon.ray.merge_condition import (
            remap_source_on_keys, rewrite_condition,
        )
        return clause.spec, remap_source_on_keys(
            rewrite_condition(clause.condition), on_map,
        )

    prepared = [_prepare(clause) for clause in clauses]

    # Only import the filter helper when at least one clause is conditional.
    filter_fn = None
    if any(cond is not None for _, cond in prepared):
        from pypaimon.ray.merge_condition import filter_batch as filter_fn

    def _transform(batch: pa.Table) -> pa.Table:
        pending = batch
        outputs = []
        for spec, cond in prepared:
            if not pending.num_rows:
                break
            if cond is None:
                selected = pending
            else:
                selected = filter_fn(pending, cond, _pre_rewritten=True)
            if selected.num_rows == 0:
                # No row satisfied this clause; all of them stay pending.
                continue
            outputs.append(vectorized_matched_transform(
                selected, spec, on_pairs,
                update_cols, row_id_name,
                update_schema,
            ))
            partially_matched = (
                cond is not None and selected.num_rows < pending.num_rows
            )
            if partially_matched:
                # Keep rows where cond is FALSE *or* NULL for later clauses.
                pending = filter_fn(
                    pending, f"COALESCE(NOT ({cond}), TRUE)",
                    _pre_rewritten=True,
                )
            else:
                pending = pending.slice(0, 0)
        return (
            pa.concat_tables(outputs) if outputs
            else update_schema.empty_table()
        )

    return _transform
| |
| |
def build_self_merge_update_ds(
    *,
    target_identifier: str,
    clauses: List[_NormalizedClause],
    target_field_names: Sequence[str],
    target_pa_schema: pa.Schema,
    update_cols: Sequence[str],
    catalog_options: Dict[str, str],
    resolve_target_projection,
    snapshot_id: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
) -> Any:
    """Build the matched-update dataset for a MERGE whose source is the target.

    No join is performed. The target is read once, with a projection limited
    to the columns the clauses need. Every column except ``_ROW_ID`` is
    renamed to ``t.<col>`` and duplicated as ``s.<col>``. The matched clauses
    are then applied with ``_ROW_ID`` as the only ON key.

    Args:
        target_identifier: Table identifier passed to ``read_paimon``.
        clauses: Normalized WHEN MATCHED clauses.
        target_field_names: Column names of the target table, in order.
        target_pa_schema: Arrow schema of the target, used to build the
            update schema.
        update_cols: Columns the clauses write.
        catalog_options: Catalog options passed to ``read_paimon``.
        resolve_target_projection: Callable
            ``(clauses, on_cols, update_cols, field_names) -> iterable`` that
            returns the target columns to read.
        snapshot_id: Snapshot to read; ``None`` lets ``read_paimon`` choose.
        ray_remote_args: Extra Ray options for ``map_batches``.

    Returns:
        The Ray dataset produced by ``map_batches``. Its batches come from
        ``vectorized_matched_transform``, or are an empty table of the schema
        from ``build_update_schema``. (The previous ``-> Tuple`` annotation
        was incorrect: a single dataset is returned.)
    """
    from pypaimon.ray.ray_paimon import read_paimon
    from pypaimon.table.special_fields import SpecialFields

    row_id_name = SpecialFields.ROW_ID.name
    needed_cols = set(resolve_target_projection(
        clauses, [row_id_name], update_cols, target_field_names,
    ))
    # Source and target are the same table, so source column refs in SET
    # values are target columns that must be read too.
    for clause in clauses:
        for value in clause.spec.values():
            if isinstance(value, SourceColumnRef):
                needed_cols.add(value.column)

    # Likewise, ``s.<col>`` references in clause conditions need ``<col>`` read.
    conditions = [
        clause.condition for clause in clauses
        if clause.condition is not None
    ]
    if conditions:
        from pypaimon.ray.merge_condition import extract_columns
        target_set = set(target_field_names)
        for condition in conditions:
            for ref in extract_columns(condition):
                prefix, col = ref.split(".", 1)
                if prefix == "s" and col in target_set:
                    needed_cols.add(col)

    projection = [row_id_name] + [
        c for c in target_field_names if c in needed_cols
    ]

    target_ds = read_paimon(
        target_identifier, catalog_options,
        projection=projection, snapshot_id=snapshot_id,
    )
    update_schema = build_update_schema(target_pa_schema, update_cols, row_id_name)

    orig_names = target_ds.schema().names
    target_renamed = target_ds.rename_columns(
        {c: f"t.{c}" for c in orig_names}
    )

    def _add_source_aliases(batch: pa.Table) -> pa.Table:
        # Append an ``s.<col>`` alias for every ``t.<col>`` except _ROW_ID.
        # _ROW_ID is not aliased; the ON mapping below refers to it directly.
        columns = list(batch.columns)
        names = list(batch.schema.names)
        for orig in orig_names:
            if orig == row_id_name:
                continue
            t_col_name = f"t.{orig}"
            if t_col_name in names:
                idx = names.index(t_col_name)
                columns.append(columns[idx])
                names.append(f"s.{orig}")
        return pa.table(columns, names=names)

    aliased = target_renamed.map_batches(
        _add_source_aliases, **_map_kwargs(ray_remote_args),
    )

    _transform = _build_matched_transform(
        clauses,
        on_map={row_id_name: row_id_name},
        on_pairs=[(row_id_name, row_id_name)],
        update_cols=list(update_cols),
        row_id_name=row_id_name,
        update_schema=update_schema,
    )
    return aliased.map_batches(_transform, **_map_kwargs(ray_remote_args))
| |
| |
def build_matched_update_ds(
    *,
    target_identifier: str,
    source_ds,
    target_on: Sequence[str],
    source_on: Sequence[str],
    clauses: List[_NormalizedClause],
    target_field_names: Sequence[str],
    target_pa_schema: pa.Schema,
    update_cols: Sequence[str],
    catalog_options: Dict[str, str],
    num_partitions: int,
    resolve_target_projection,
    snapshot_id: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
) -> Any:
    """Build the matched-update dataset by inner-joining target and source.

    The target is read with ``_ROW_ID`` plus the columns returned by
    ``resolve_target_projection``, and its columns are prefixed ``t.``. The
    source columns are prefixed ``s.``. The two sides are inner-joined on
    ``target_on`` / ``source_on`` (paired positionally), and the WHEN MATCHED
    clauses are applied to the joined rows.

    Args:
        target_identifier: Table identifier passed to ``read_paimon``.
        source_ds: Ray dataset holding the MERGE source rows.
        target_on: Target ON-key columns.
        source_on: Source ON-key columns, positionally paired with
            ``target_on``.
        clauses: Normalized WHEN MATCHED clauses.
        target_field_names: Column names of the target table.
        target_pa_schema: Arrow schema of the target, used to build the
            update schema.
        update_cols: Columns the clauses write.
        catalog_options: Catalog options passed to ``read_paimon``.
        num_partitions: Partition count for the join.
        resolve_target_projection: Callable returning the target columns to
            read.
        snapshot_id: Snapshot to read; ``None`` lets ``read_paimon`` choose.
        ray_remote_args: Extra Ray options for ``map_batches``.

    Returns:
        The Ray dataset produced by ``map_batches`` over the joined rows. (The
        previous ``-> Tuple`` annotation was incorrect: a single dataset is
        returned.)
    """
    from pypaimon.ray.ray_paimon import read_paimon
    from pypaimon.table.special_fields import SpecialFields

    row_id_name = SpecialFields.ROW_ID.name
    needed_cols = resolve_target_projection(
        clauses, target_on, update_cols, target_field_names,
    )
    # _ROW_ID always goes first and is never duplicated in the projection.
    projection = [row_id_name] + [c for c in needed_cols if c != row_id_name]

    target_ds = read_paimon(
        target_identifier, catalog_options,
        projection=projection, snapshot_id=snapshot_id,
    )
    update_schema = build_update_schema(target_pa_schema, update_cols, row_id_name)

    target_renamed = target_ds.rename_columns(
        {c: f"t.{c}" for c in target_ds.schema().names}
    )
    source_cols = list(source_ds.schema().names)
    source_renamed = source_ds.rename_columns(
        {c: f"s.{c}" for c in source_cols}
    )

    joined = target_renamed.join(
        source_renamed,
        join_type="inner",
        num_partitions=num_partitions,
        on=tuple(f"t.{c}" for c in target_on),
        right_on=tuple(f"s.{c}" for c in source_on),
    )

    # (source_key, target_key) pairs; also used as the source->target
    # remapping for clause conditions.
    on_pairs = list(zip(source_on, target_on))
    _transform = _build_matched_transform(
        clauses,
        on_map=dict(on_pairs),
        on_pairs=on_pairs,
        update_cols=list(update_cols),
        row_id_name=row_id_name,
        update_schema=update_schema,
    )
    return joined.map_batches(_transform, **_map_kwargs(ray_remote_args))
| |
| |
def distributed_update_apply(
    update_ds,
    table,
    write_update_cols: Sequence[str],
    *,
    num_partitions: int,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    base_snapshot_id: Optional[int] = None,
) -> Tuple[list, int]:
    """Apply matched-row updates in parallel, one group per target data file.

    Each row of ``update_ds`` carries a target ``_ROW_ID`` and the new values
    for ``write_update_cols``. Processing runs in three steps:

    1. Each row is tagged with the largest planner first-row-id that is
       <= its ``_ROW_ID``.
    2. Rows are grouped by that tag.
    3. Each group goes to a ``TableUpdateByRowId`` worker, whose
       ``update_columns`` result is pickled back to the driver.

    Args:
        update_ds: Ray dataset (pyarrow batches) with ``_ROW_ID`` plus the
            update columns.
        table: Target table.
        write_update_cols: Columns to overwrite; each must be in
            ``table.field_names``.
        num_partitions: Upper bound on the number of groupby partitions.
        ray_remote_args: Extra Ray options spread into ``map_batches`` /
            ``map_groups``.
        base_snapshot_id: Snapshot the update rows were read from. It
            overrides the snapshot id in the shared files info so the
            commit-time conflict check covers the read-to-plan window.
            Defaults to the planner's snapshot id.

    Returns:
        ``(messages, num_updated)``: the concatenated worker messages and the
        total number of rows handed to workers. Returns ``([], 0)`` when the
        planner has no first row ids.

    Raises:
        ValueError: An update column is not in the table. A ``_ROW_ID`` is
            null or outside every valid range. Several rows target the same
            ``_ROW_ID``. Errors raised inside Ray tasks may surface wrapped by
            Ray.
    """
    import pickle
    import uuid
    from dataclasses import replace

    import numpy as np
    import pyarrow.compute as pc
    import ray

    from pypaimon.snapshot.snapshot import BATCH_COMMIT_IDENTIFIER
    from pypaimon.table.special_fields import SpecialFields
    from pypaimon.write.table_update_by_row_id import TableUpdateByRowId

    row_id_name = SpecialFields.ROW_ID.name
    cols = list(write_update_cols)

    for col in cols:
        if col not in table.field_names:
            raise ValueError(
                f"Column '{col}' is not in target table schema."
            )

    planner = TableUpdateByRowId(
        table,
        "_merge_into_planner_" + uuid.uuid4().hex[:8],
        BATCH_COMMIT_IDENTIFIER,
    )
    # np.searchsorted below requires ascending input. Sort explicitly rather
    # than rely on the planner's iteration order; this is a no-op when it is
    # already sorted.
    sorted_first_row_ids = sorted(planner.first_row_ids)
    if not sorted_first_row_ids:
        return [], 0

    # Pin the commit-time conflict check to the snapshot the join was built
    # on, so commits landing between read and planning are detected.
    check_from_snapshot = (
        base_snapshot_id if base_snapshot_id is not None
        else planner.snapshot_id
    )

    # Put the file metadata into Ray's object store once and pass a single
    # ref to workers. This avoids a manifest re-scan per task and avoids
    # serializing the metadata into every task closure.
    files_info = replace(
        planner._snapshot_files_info(),
        snapshot_id=check_from_snapshot,
    )
    precomputed_info_ref = ray.put(files_info)

    frid_col = "_FIRST_ROW_ID"
    sorted_frids_arr = np.asarray(sorted_first_row_ids, dtype=np.int64)
    valid_ranges = planner.valid_row_id_ranges
    range_starts = np.asarray([r.from_ for r in valid_ranges], dtype=np.int64)
    range_ends = np.asarray([r.to for r in valid_ranges], dtype=np.int64)

    def _assign_frid(batch: pa.Table) -> pa.Table:
        # Append _FIRST_ROW_ID: the owning file's first row id for each row.
        if batch.num_rows == 0:
            return batch.append_column(
                frid_col, pa.array([], type=pa.int64())
            )
        rid_col = batch.column(row_id_name)
        if rid_col.null_count:
            raise ValueError(
                "_ROW_ID is null; planner snapshot is stale "
                "or matched rows come from a different table."
            )
        rids = rid_col.to_numpy(zero_copy_only=False)
        # Every row id must fall inside some inclusive [from_, to] range.
        in_range = np.zeros(len(rids), dtype=bool)
        for s, e in zip(range_starts, range_ends):
            in_range |= (rids >= s) & (rids <= e)
        if not in_range.all():
            bad = rids[~in_range][0]
            raise ValueError(
                f"_ROW_ID {bad} does not belong to any valid range "
                f"{[f'[{r.from_}, {r.to}]' for r in valid_ranges]}; "
                f"planner snapshot is stale or matched rows come "
                f"from a different table."
            )
        idx = np.searchsorted(
            sorted_frids_arr, rids, side="right"
        ) - 1
        # An id below the smallest first row id gives -1, which numpy would
        # silently wrap to the LAST file. Fail loudly instead.
        below = idx < 0
        if below.any():
            bad = rids[below][0]
            raise ValueError(
                f"_ROW_ID {bad} is smaller than every data file's first "
                f"row id (min {sorted_frids_arr[0]}); planner snapshot is "
                f"stale or matched rows come from a different table."
            )
        frids = sorted_frids_arr[idx]
        return batch.append_column(
            frid_col, pa.array(frids, type=pa.int64())
        )

    map_kwargs = _map_kwargs(ray_remote_args)
    with_frid = update_ds.map_batches(_assign_frid, **map_kwargs)

    captured_table = table
    captured_cols = cols

    def _apply_group(group: pa.Table) -> pa.Table:
        # Runs in a Ray task; returns one row of (pickled messages, count).
        if group.num_rows == 0:
            return pa.Table.from_pydict({
                "msgs_blob": pa.array([], type=pa.binary()),
                "n_updated": pa.array([], type=pa.int64()),
            })

        if (
            pc.count_distinct(group.column(row_id_name)).as_py()
            != group.num_rows
        ):
            raise ValueError(
                "MERGE matched multiple source rows to the same "
                "target _ROW_ID. Deduplicate the source before "
                "merging."
            )

        for_update = group.drop_columns([frid_col])
        worker = TableUpdateByRowId(
            captured_table,
            "_merge_into_shard_" + uuid.uuid4().hex[:8],
            BATCH_COMMIT_IDENTIFIER,
            _precomputed_files_info=ray.get(precomputed_info_ref),
        )
        msgs = worker.update_columns(for_update, list(captured_cols))
        # Pickle is used only for messages produced by our own workers
        # within this job, never for external input.
        return pa.Table.from_pydict({
            "msgs_blob": [pickle.dumps(msgs)],
            "n_updated": pa.array(
                [for_update.num_rows], type=pa.int64()
            ),
        })

    # One group per target data file; the partition count is bounded by the
    # file count and num_partitions.
    group_partitions = max(
        1, min(len(sorted_first_row_ids), num_partitions)
    )
    msgs_ds = with_frid.groupby(
        frid_col, num_partitions=group_partitions
    ).map_groups(_apply_group, **map_kwargs)

    all_msgs: list = []
    num_updated = 0
    for batch in msgs_ds.iter_batches(batch_format="pyarrow"):
        for blob in batch.column("msgs_blob").to_pylist():
            all_msgs.extend(pickle.loads(blob))
        for n in batch.column("n_updated").to_pylist():
            num_updated += n
    return all_msgs, num_updated
| |
| |
def build_not_matched_insert_ds(
    *,
    target_identifier: str,
    source_ds,
    target_on: Sequence[str],
    source_on: Sequence[str],
    clauses: List[_NormalizedClause],
    target_field_names: Sequence[str],
    target_pa_schema: pa.Schema,
    catalog_options: Dict[str, str],
    num_partitions: int,
    target_empty: bool = False,
    snapshot_id: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
):
    """Build the dataset of rows to INSERT for WHEN NOT MATCHED clauses.

    Source columns are prefixed ``s.``. Unless ``target_empty`` is set, source
    rows whose ON keys have a match in the target are removed with a
    ``left_anti`` join against the target's ON-key columns (prefixed ``t.``).
    The NOT MATCHED clauses are then applied in order: each row is handled by
    the first clause whose condition is TRUE, or by the first unconditional
    clause reached. Rows where a condition is FALSE or NULL fall through to
    later clauses.

    Args:
        target_identifier: Table identifier passed to ``read_paimon``.
        source_ds: Ray dataset holding the MERGE source rows.
        target_on: Target ON-key columns.
        source_on: Source ON-key columns, positionally paired with
            ``target_on``.
        clauses: Normalized WHEN NOT MATCHED clauses.
        target_field_names: Target column names, passed to
            ``vectorized_insert_transform``.
        target_pa_schema: Output schema for inserted rows.
        catalog_options: Catalog options passed to ``read_paimon``.
        num_partitions: Partition count for the anti-join.
        target_empty: When True, skip reading the target; every source row
            counts as unmatched.
        snapshot_id: Snapshot to read; ``None`` lets ``read_paimon`` choose.
        ray_remote_args: Extra Ray options for ``map_batches``.

    Returns:
        The Ray dataset produced by ``map_batches``. Each output batch is
        passed through ``_coerce_large_string_types``.
    """
    from pypaimon.ray.ray_paimon import read_paimon
    from pypaimon.ray.shuffle import _coerce_large_string_types

    captured_field_names = list(target_field_names)
    out_schema = target_pa_schema

    source_cols = list(source_ds.schema().names)
    source_renamed = source_ds.rename_columns(
        {c: f"s.{c}" for c in source_cols}
    )

    if target_empty:
        unmatched = source_renamed
    else:
        target_ds = read_paimon(
            target_identifier, catalog_options,
            projection=list(target_on), snapshot_id=snapshot_id,
        )
        target_renamed = target_ds.rename_columns(
            {c: f"t.{c}" for c in target_on}
        )
        unmatched = source_renamed.join(
            target_renamed,
            join_type="left_anti",
            num_partitions=num_partitions,
            on=tuple(f"s.{c}" for c in source_on),
            right_on=tuple(f"t.{c}" for c in target_on),
        )

    # Rewrite each clause condition once, on the driver.
    prepared_clauses = []
    for clause in clauses:
        rewritten = None
        if clause.condition is not None:
            from pypaimon.ray.merge_condition import rewrite_condition
            rewritten = rewrite_condition(clause.condition)
        prepared_clauses.append((clause.spec, rewritten))

    _filter_batch_nm = None
    if any(r is not None for _, r in prepared_clauses):
        from pypaimon.ray.merge_condition import filter_batch as _filter_batch_nm

    def _transform(batch: pa.Table) -> pa.Table:
        remaining = batch
        parts = []
        for spec, rewritten in prepared_clauses:
            if remaining.num_rows == 0:
                break
            if rewritten is None:
                matched = remaining
            else:
                matched = _filter_batch_nm(
                    remaining, rewritten, _pre_rewritten=True,
                )
            if matched.num_rows == 0:
                # No row satisfied this clause, so NOT(cond) would keep every
                # remaining row. Skip the redundant second filter pass.
                continue
            parts.append(vectorized_insert_transform(
                matched, spec, captured_field_names, out_schema
            ))
            if rewritten is not None and matched.num_rows < remaining.num_rows:
                # Rows where cond is FALSE or NULL stay for later clauses.
                not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
                remaining = _filter_batch_nm(
                    remaining, not_cond, _pre_rewritten=True,
                )
            else:
                remaining = remaining.slice(0, 0)
        if not parts:
            return _coerce_large_string_types(out_schema.empty_table())
        return _coerce_large_string_types(pa.concat_tables(parts))

    return unmatched.map_batches(
        _transform, **_map_kwargs(ray_remote_args)
    )
| |
| |
def distributed_write_collect_msgs(
    insert_ds,
    table,
    *,
    ray_remote_args: Optional[Dict[str, Any]],
    concurrency: Optional[int],
) -> list:
    """Write ``insert_ds`` through a Paimon datasink and return its messages.

    A ``PaimonDatasink`` subclass (with ``overwrite=False``) replaces
    ``on_write_complete``. It stores the non-empty messages from
    ``_extract_write_returns`` instead of running the parent's completion
    hook. Presumably this is so the caller can commit them together with
    other messages; confirm against the caller.

    ``ray_remote_args`` and ``concurrency`` are forwarded to
    ``write_datasink`` only when they are not None.
    """
    from pypaimon.write.ray_datasink import PaimonDatasink

    class _CollectingDatasink(PaimonDatasink):
        def __init__(self, t):
            super().__init__(t, overwrite=False)
            self.collected: list = []

        def on_write_complete(self, write_result):
            # Intentionally does not call super().on_write_complete.
            gathered: list = []
            for messages in self._extract_write_returns(write_result):
                gathered.extend(m for m in messages if not m.is_empty())
            self.collected = gathered

    sink = _CollectingDatasink(table)
    optional_args = (
        ("ray_remote_args", ray_remote_args),
        ("concurrency", concurrency),
    )
    write_kwargs: Dict[str, Any] = {
        name: value for name, value in optional_args if value is not None
    }
    insert_ds.write_datasink(sink, **write_kwargs)
    return sink.collected