blob: ab20a56dc7a6cc43dd9399e25da79d612cd06c0f [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""MERGE INTO ... USING ... for Paimon data-evolution tables via Ray Datasets."""
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
import pyarrow as pa
from pypaimon.ray.data_evolution_merge_join import (
build_matched_update_ds,
build_not_matched_insert_ds,
build_self_merge_update_ds,
distributed_update_apply,
distributed_write_collect_msgs,
)
from pypaimon.ray.data_evolution_merge_transform import (
LiteralValue,
OnSpec,
SetSpec,
SourceColumnRef,
TargetColumnRef,
WhenMatched,
WhenNotMatched,
_NormalizedClause,
)
__all__ = ["merge_into", "WhenMatched", "WhenNotMatched"]
@dataclass(frozen=True)
class _PrepareCtx:
    """Immutable bag of values that _prepare hands to _build_datasets.

    Every field is computed once in _prepare and only read afterwards.
    """

    # Join-key column names on the target side, as returned by _normalize_on.
    target_on_cols: List[str]
    # Join-key column names on the source side, positionally paired with
    # target_on_cols.
    source_on_cols: List[str]
    # Target field names with blob columns removed; the only columns a SET
    # spec may assign.
    settable_field_names: List[str]
    # Every field name of the target table, blob columns included.
    full_target_field_names: List[str]
    # full_pa_schema restricted to settable_field_names.
    update_pa_schema: pa.Schema
    # Arrow schema derived from all fields of the target table schema.
    full_pa_schema: pa.Schema
    # Catalog options exactly as the caller passed them to merge_into.
    catalog_options: Dict[str, str]
    # Result of _is_self_merge: source is the target identifier itself and
    # both ON lists are exactly the row-id special field.
    is_self_merge: bool = False
def merge_into(
    target: str,
    source: Any,
    catalog_options: Dict[str, str],
    *,
    on: OnSpec,
    when_matched: Sequence[WhenMatched] = (),
    when_not_matched: Sequence[WhenNotMatched] = (),
    num_partitions: Optional[int] = None,
    ray_remote_args: Optional[Dict[str, Any]] = None,
    concurrency: Optional[int] = None,
) -> Dict[str, int]:
    """Run MERGE INTO ``target`` USING ``source`` and commit the result.

    The work is split into three stages: ``_prepare`` validates the request
    and resolves the table, ``_build_datasets`` creates the update / insert
    datasets pinned to the target's latest snapshot, and
    ``_execute_and_commit`` runs them and commits the produced messages.

    Args:
        target: Identifier of the target table, resolved through the catalog.
        source: A ray.data.Dataset, a table identifier string, a
            pyarrow.Table or a pandas.DataFrame.
        catalog_options: Options used to create the catalog.
        on: Join keys; a sequence of shared column names or a mapping of
            target column -> source column.
        when_matched: Update clauses; only the last may omit its condition.
        when_not_matched: Insert clauses; only the last may omit its
            condition.
        num_partitions: Partition count handed to the join helpers; derived
            from the cluster when ``None``.
        ray_remote_args: Forwarded to the distributed helpers.
        concurrency: Forwarded to the insert writer.

    Returns:
        A dict with the keys ``num_matched``, ``num_inserted`` and
        ``num_unchanged``.

    Raises:
        RuntimeError: If the installed ray is older than 2.50.
        ValueError: If the request fails validation in ``_prepare``.
    """
    _require_ray_join()
    partitions = _resolve_num_partitions(num_partitions)

    prepared = _prepare(
        target, source, catalog_options,
        list(when_matched), list(when_not_matched), on,
    )
    table, source_ds, matched, not_matched, ctx = prepared

    # Captured once so every later read and the update apply share one
    # view of the target.
    snapshot = table.snapshot_manager().get_latest_snapshot()

    built = _build_datasets(
        target, source_ds, matched, not_matched,
        ctx, snapshot, partitions, ray_remote_args,
    )
    update_ds, insert_ds, update_cols = built

    return _execute_and_commit(
        table, update_ds, insert_ds, update_cols,
        snapshot, partitions,
        ray_remote_args, concurrency,
    )
def _prepare(target, source, catalog_options, when_matched, when_not_matched, on):
    """Validate a merge request and resolve everything later stages need.

    Checks run in a fixed order, so the first failing one determines the
    error the caller sees.

    Args:
        target: Identifier of the target table.
        source: Source as accepted by _normalize_source.
        catalog_options: Options passed to CatalogFactory.create.
        when_matched: List of matched (update) clauses.
        when_not_matched: List of not-matched (insert) clauses.
        on: Join-key spec accepted by _normalize_on.

    Returns:
        ``(table, source_ds, matched_specs, not_matched_specs, ctx)``.
        ``source_ds`` is ``None`` for a self-merge.

    Raises:
        ValueError: For an empty clause list pair, an unconditional clause
            that is not last, a target table without data evolution or row
            tracking, an update of a partition column, an invalid condition
            reference, a self-merge with WHEN NOT MATCHED clauses, or a
            source missing required columns.
        TypeError: Propagated from _normalize_set_spec / _normalize_source
            for unsupported SET specs or source types.
    """
    if not when_matched and not when_not_matched:
        raise ValueError(
            "At least one of when_matched or when_not_matched must be non-empty."
        )
    # Every clause except the last needs a condition; an unconditional
    # clause earlier in the list would shadow everything after it.
    for label, clauses in [("when_matched", when_matched),
                           ("when_not_matched", when_not_matched)]:
        for i, clause in enumerate(clauses[:-1]):
            if clause.condition is None:
                raise ValueError(
                    f"Only the last {label} clause may omit its condition. "
                    f"Clause at index {i} has no condition, making subsequent "
                    f"clauses unreachable."
                )
    target_on_cols, source_on_cols = _normalize_on(on)
    from pypaimon.catalog.catalog_factory import CatalogFactory
    catalog = CatalogFactory.create(catalog_options)
    table = catalog.get_table(target)
    if not table.options.data_evolution_enabled():
        raise ValueError(
            f"merge_into requires 'data-evolution.enabled' = 'true' on '{target}'."
        )
    if not table.options.row_tracking_enabled():
        raise ValueError(
            f"merge_into requires 'row-tracking.enabled' = 'true' on '{target}'."
        )
    blob_cols = _blob_col_names(table)
    full_target_field_names = list(table.field_names)
    # SET specs only cover non-blob columns: update can't rewrite blob files
    # (data evolution puts them in dedicated .blob files), and insert leaves
    # blob columns null since the source can't carry them through SET="*".
    settable_field_names = [
        c for c in full_target_field_names if c not in blob_cols
    ]
    # target ON column -> source ON column, used to expand SET="*" and to
    # fill in join keys for inserts.
    on_map = dict(zip(target_on_cols, source_on_cols))
    matched_specs = [
        _NormalizedClause(
            spec=_normalize_set_spec(
                c.update, settable_field_names, on_map,
            ),
            condition=c.condition,
        )
        for c in when_matched
    ]
    if matched_specs and table.partition_keys:
        partition_set = set(table.partition_keys)
        for clause in matched_specs:
            modified_partition_cols = partition_set & set(clause.spec.keys())
            if modified_partition_cols:
                raise ValueError(
                    f"merge_into does not support updating partition columns "
                    f"{sorted(modified_partition_cols)}; cross-partition row "
                    f"movement is not implemented."
                )
    has_condition = any(
        c.condition is not None
        for c in list(when_matched) + list(when_not_matched)
    )
    if has_condition:
        # Imported lazily so the condition machinery is only required when a
        # clause actually carries a condition.
        from pypaimon.ray.merge_condition import (
            _require_datafusion, extract_target_columns,
        )
        _require_datafusion()
        # Not-matched conditions may not reference target columns at all.
        for c in when_not_matched:
            if c.condition is not None:
                t_refs = extract_target_columns(c.condition)
                if t_refs:
                    raise ValueError(
                        f"WhenNotMatched condition must not reference "
                        f"target columns (t.*), but found: {sorted(t_refs)}"
                    )
        # No condition may reference a blob column of the target.
        for c in list(when_matched) + list(when_not_matched):
            if c.condition is not None:
                blob_refs = extract_target_columns(c.condition) & blob_cols
                if blob_refs:
                    raise ValueError(
                        f"condition must not reference blob columns, "
                        f"but found: {sorted(blob_refs)}"
                    )
    not_matched_specs = []
    for c in when_not_matched:
        spec = _normalize_set_spec(
            c.insert, settable_field_names, on_map,
            allow_target_refs=False,
        )
        # Inserted rows always receive their join keys from the source when
        # the insert spec did not assign them explicitly.
        for tk, sk in on_map.items():
            if tk in settable_field_names and tk not in spec:
                spec[tk] = SourceColumnRef(sk)
        not_matched_specs.append(
            _NormalizedClause(spec=spec, condition=c.condition)
        )
    is_self_merge = _is_self_merge(target, source, target_on_cols, source_on_cols)
    if is_self_merge and not_matched_specs:
        raise ValueError(
            "Self-merge (source == target with ON _ROW_ID) does not "
            "support WHEN NOT MATCHED clauses."
        )
    if is_self_merge:
        # No separate source dataset is built; the source columns are the
        # target's own fields plus the ON columns.
        source_ds = None
        source_col_names = set(full_target_field_names) | set(source_on_cols)
    else:
        source_snapshot_id = None
        if isinstance(source, str):
            # Resolve the source table's latest snapshot id now and pass it
            # to the read so the source is read at a fixed snapshot.
            source_snapshot = (
                catalog.get_table(source)
                .snapshot_manager()
                .get_latest_snapshot()
            )
            if source_snapshot is not None:
                source_snapshot_id = source_snapshot.id
        source_ds = _normalize_source(
            source, catalog_options, source_snapshot_id=source_snapshot_id,
        )
        _validate_source_on_cols(source_ds, source_on_cols)
        source_col_names = set(_source_schema_or_raise(source_ds).names)
    _validate_source_has_target_cols(
        source_col_names, matched_specs + not_matched_specs,
    )
    if has_condition:
        from pypaimon.ray.merge_condition import extract_columns
        target_names = set(full_target_field_names)
        if is_self_merge:
            # The self-merge ON column is not a regular table field, so it
            # is added to the names a condition may reference.
            target_names |= set(target_on_cols)
        for c in list(when_matched) + list(when_not_matched):
            if c.condition is not None:
                for ref in extract_columns(c.condition):
                    # presumably refs are "<s|t>.<column>" strings; verify
                    # against extract_columns
                    prefix, col = ref.split(".", 1)
                    if prefix == "s" and col not in source_col_names:
                        raise ValueError(
                            f"condition references unknown source "
                            f"column '{col}'"
                        )
                    if prefix == "t" and col not in target_names:
                        raise ValueError(
                            f"condition references unknown target "
                            f"column '{col}'"
                        )
    from pypaimon.schema.data_types import PyarrowFieldParser
    full_pa_schema = PyarrowFieldParser.from_paimon_schema(
        table.table_schema.fields
    )
    # update_pa_schema strips blob (only non-blob cols are written by the
    # update path); full_pa_schema is the full table schema so the insert
    # writer gets every column (blob columns end up null).
    update_pa_schema = pa.schema(
        [full_pa_schema.field(c) for c in settable_field_names]
    )
    ctx = _PrepareCtx(
        target_on_cols=target_on_cols,
        source_on_cols=source_on_cols,
        settable_field_names=settable_field_names,
        full_target_field_names=full_target_field_names,
        update_pa_schema=update_pa_schema,
        full_pa_schema=full_pa_schema,
        catalog_options=catalog_options,
        is_self_merge=is_self_merge,
    )
    return table, source_ds, matched_specs, not_matched_specs, ctx
def _is_self_merge(target, source, target_on_cols, source_on_cols) -> bool:
    """Tell whether the merge joins the target with itself on the row id.

    True only when ``source`` is a string equal to ``target`` and both ON
    lists consist of exactly the row-id special field.
    """
    from pypaimon.table.special_fields import SpecialFields

    row_id_only = [SpecialFields.ROW_ID.name]
    if not isinstance(source, str):
        return False
    if not (source == target):
        return False
    if target_on_cols != row_id_only:
        return False
    return source_on_cols == row_id_only
def _build_datasets(
target, source_ds, matched_specs, not_matched_specs,
ctx: "_PrepareCtx", base_snapshot, num_partitions, ray_remote_args,
):
# Pin every target read to base_snapshot so all branches see the same
# snapshot the caller observed; otherwise concurrent commits in between
# would mix data from different snapshots.
base_snapshot_id = base_snapshot.id if base_snapshot is not None else None
update_ds = None
insert_ds = None
update_cols_union: List[str] = []
if ctx.is_self_merge:
if matched_specs and base_snapshot is not None:
update_cols_union = _union_update_cols(matched_specs)
update_ds = build_self_merge_update_ds(
target_identifier=target,
clauses=matched_specs,
target_field_names=ctx.full_target_field_names,
target_pa_schema=ctx.update_pa_schema,
update_cols=update_cols_union,
catalog_options=ctx.catalog_options,
resolve_target_projection=_resolve_target_projection,
snapshot_id=base_snapshot_id,
ray_remote_args=ray_remote_args,
)
return update_ds, insert_ds, update_cols_union
# Mirror Spark: matched/not-matched run as two independent joins
# (inner / left_anti). One unified left_outer join would force
# joined.materialize() to feed both branches, which can OOM on large merges.
if matched_specs and base_snapshot is not None:
update_cols_union = _union_update_cols(matched_specs)
update_ds = build_matched_update_ds(
target_identifier=target,
source_ds=source_ds,
target_on=ctx.target_on_cols,
source_on=ctx.source_on_cols,
clauses=matched_specs,
target_field_names=ctx.settable_field_names,
target_pa_schema=ctx.update_pa_schema,
update_cols=update_cols_union,
catalog_options=ctx.catalog_options,
num_partitions=num_partitions,
resolve_target_projection=_resolve_target_projection,
snapshot_id=base_snapshot_id,
ray_remote_args=ray_remote_args,
)
if not_matched_specs:
# Insert writes the full target schema; SET spec only covers
# settable cols, so blob columns fall through to null.
insert_ds = build_not_matched_insert_ds(
target_identifier=target,
source_ds=source_ds,
target_on=ctx.target_on_cols,
source_on=ctx.source_on_cols,
clauses=not_matched_specs,
target_field_names=ctx.full_target_field_names,
target_pa_schema=ctx.full_pa_schema,
catalog_options=ctx.catalog_options,
num_partitions=num_partitions,
snapshot_id=base_snapshot_id,
target_empty=base_snapshot is None,
ray_remote_args=ray_remote_args,
)
return update_ds, insert_ds, update_cols_union
def _execute_and_commit(
table, update_ds, insert_ds, update_cols_union,
base_snapshot, num_partitions,
ray_remote_args, concurrency,
):
update_msgs: list = []
num_updated = 0
if update_ds is not None:
try:
update_msgs, num_updated = distributed_update_apply(
update_ds, table, update_cols_union,
num_partitions=num_partitions,
ray_remote_args=ray_remote_args,
base_snapshot_id=(
base_snapshot.id
if base_snapshot is not None else None
),
)
except Exception as e:
_reraise_inner(e)
all_msgs: list = list(update_msgs)
num_inserted = 0
if insert_ds is not None:
try:
insert_msgs = distributed_write_collect_msgs(
insert_ds, table,
ray_remote_args=ray_remote_args, concurrency=concurrency,
)
except Exception as e:
_reraise_inner(e)
num_inserted = sum(
f.row_count for m in insert_msgs for f in m.new_files
)
all_msgs.extend(insert_msgs)
if all_msgs:
wb = table.new_batch_write_builder()
tc = wb.new_commit()
tc.commit(all_msgs)
tc.close()
# num_matched = rows that passed the condition and were updated
return {
"num_matched": num_updated,
"num_inserted": num_inserted,
"num_unchanged": 0,
}
def _normalize_on(on: OnSpec) -> Tuple[List[str], List[str]]:
if isinstance(on, Mapping):
target_cols = list(on.keys())
source_cols = list(on.values())
else:
target_cols = list(on)
source_cols = list(on)
if not target_cols:
raise ValueError("'on' must be non-empty.")
return target_cols, source_cols
def _resolve_num_partitions(num_partitions: Optional[int]) -> int:
if num_partitions is not None:
return num_partitions
try:
import ray
cpus = int(ray.cluster_resources().get("CPU", 4))
return max(1, cpus * 2)
except Exception:
return 4
def _require_ray_join() -> None:
    """Fail fast when the installed ray is older than 2.50.

    Raises:
        RuntimeError: If ``ray.__version__`` parses below 2.50.0.
    """
    import ray
    from packaging.version import parse

    minimum = parse("2.50.0")
    if parse(ray.__version__) >= minimum:
        return
    raise RuntimeError(
        f"merge_into requires ray>=2.50; "
        f"installed ray is {ray.__version__}."
    )
def _blob_col_names(table) -> set:
return {
f.name
for f in table.table_schema.fields
if getattr(f.type, "type", None) == "BLOB"
}
def _reraise_inner(err: BaseException) -> None:
"""Unwrap Ray's RayTaskError so callers see the worker-side exception."""
inner = err
cause = getattr(err, "cause", None) or getattr(err, "__cause__", None)
while cause is not None:
inner = cause
cause = getattr(inner, "cause", None) or getattr(inner, "__cause__", None)
if inner is err:
raise err
raise inner from err
def _union_update_cols(clauses: List[_NormalizedClause]) -> List[str]:
seen: List[str] = []
seen_set: set = set()
for clause in clauses:
for col in clause.spec.keys():
if col not in seen_set:
seen.append(col)
seen_set.add(col)
return seen
def _needed_target_cols(
clauses: List[_NormalizedClause],
on: Sequence[str],
update_cols: Sequence[str],
all_target_cols: Sequence[str],
) -> list:
# Target needs only: join keys, t.col refs, and cols that may fall back
# (not set by every clause). Cols all clauses set from source aren't read.
needed = set(on)
set_by_all = set(update_cols)
for clause in clauses:
for value in clause.spec.values():
if isinstance(value, TargetColumnRef):
needed.add(value.column)
set_by_all &= set(clause.spec.keys())
needed |= set(update_cols) - set_by_all
return [c for c in all_target_cols if c in needed]
def _resolve_target_projection(
    clauses: List[_NormalizedClause],
    target_on: Sequence[str],
    update_cols: Sequence[str],
    target_field_names: Sequence[str],
) -> list:
    """Compute the target columns to project for a set of matched clauses.

    Starts from ``_needed_target_cols`` and adds every known target column
    referenced by a clause condition. The result keeps the order of
    ``target_field_names``.
    """
    wanted = set(_needed_target_cols(
        clauses, target_on, update_cols, target_field_names,
    ))
    conditions = [
        clause.condition for clause in clauses
        if clause.condition is not None
    ]
    if conditions:
        # Imported lazily: only needed when a clause carries a condition.
        from pypaimon.ray.merge_condition import extract_target_columns
        known_fields = set(target_field_names)
        for condition in conditions:
            wanted |= extract_target_columns(condition) & known_fields
    return [name for name in target_field_names if name in wanted]
def _normalize_set_spec(
spec: SetSpec,
target_field_names: Sequence[str],
on_map: Optional[Mapping[str, str]] = None,
allow_target_refs: bool = True,
) -> Dict[str, Any]:
on_map = on_map or {}
if spec == "*":
return {
col: SourceColumnRef(on_map.get(col, col))
for col in target_field_names
}
if not isinstance(spec, Mapping):
raise TypeError(
f"SET spec must be '*' or a mapping, got {type(spec).__name__}"
)
if not spec:
raise ValueError("SET spec must not be empty")
target_set = set(target_field_names)
for key in spec:
if key not in target_set:
raise ValueError(
f"SET spec references unknown target column '{key}'"
)
result: Dict[str, Any] = {}
for key, val in spec.items():
if callable(val) and not isinstance(val, type):
raise TypeError(
"SET values must be source_col(), target_col(), "
"lit(), or literals, not callables"
)
if isinstance(val, SourceColumnRef):
result[key] = val
elif isinstance(val, TargetColumnRef):
if not allow_target_refs:
raise ValueError(
"INSERT spec must not reference target columns "
f"(t.*), but found: 't.{val.column}'"
)
if val.column not in target_set:
raise ValueError(
f"SET spec references unknown target column "
f"'{val.column}'"
)
result[key] = val
elif isinstance(val, LiteralValue):
result[key] = val
elif isinstance(val, str) and val.startswith("s."):
result[key] = SourceColumnRef(val[2:])
elif isinstance(val, str) and val.startswith("t."):
if not allow_target_refs:
raise ValueError(
"INSERT spec must not reference target columns "
f"(t.*), but found: '{val}'"
)
ref = val[2:]
if ref not in target_set:
raise ValueError(
f"SET spec references unknown target column '{ref}'"
)
result[key] = TargetColumnRef(ref)
else:
result[key] = LiteralValue(val)
return result
def _normalize_source(
    source: Any,
    catalog_options: Dict[str, str],
    source_snapshot_id: Optional[int] = None,
):
    """Coerce any supported source into a ray.data.Dataset.

    Args:
        source: A ray.data.Dataset (returned as is), a table identifier
            string (read via read_paimon), a pyarrow.Table or a
            pandas.DataFrame.
        catalog_options: Catalog options used when reading a table by name.
        source_snapshot_id: When given, forwarded to the table read as
            ``snapshot_id``; ignored for the other source kinds.

    Raises:
        TypeError: If ``source`` is none of the supported kinds.
    """
    import ray.data

    if isinstance(source, ray.data.Dataset):
        return source

    if isinstance(source, str):
        from pypaimon.ray.ray_paimon import read_paimon
        pin = (
            {} if source_snapshot_id is None
            else {"snapshot_id": source_snapshot_id}
        )
        return read_paimon(source, catalog_options, **pin)

    if isinstance(source, pa.Table):
        return ray.data.from_arrow(source)

    # pandas is optional; without it a DataFrame source cannot occur.
    try:
        import pandas
    except ImportError:
        pandas = None
    is_frame = pandas is not None and isinstance(source, pandas.DataFrame)
    if is_frame:
        return ray.data.from_pandas(source)

    raise TypeError(
        "source must be a ray.data.Dataset, a Paimon table identifier string, "
        f"a pyarrow.Table, or a pandas.DataFrame; got {type(source).__name__}."
    )
def _source_schema_or_raise(source_ds):
"""Get source schema; refuse to proceed if Ray can't tell us the columns."""
schema = source_ds.schema()
if schema is None:
raise ValueError(
"merge_into could not infer the source schema; pass a "
"ray.data.Dataset that has been materialized (e.g. via "
".materialize()) or constructed from pyarrow/pandas."
)
return schema
def _validate_source_on_cols(source_ds, on: Sequence[str]) -> None:
    """Ensure every source-side ON column exists in the source schema.

    Raises:
        ValueError: If any column in ``on`` is absent from the schema, or
            if the schema cannot be inferred.
    """
    names = set(_source_schema_or_raise(source_ds).names)
    absent = [col for col in on if col not in names]
    if not absent:
        return
    raise ValueError(
        f"'on' columns {absent} missing from source schema {list(names)}."
    )
def _validate_source_has_target_cols(
source_col_names: set,
specs: List[_NormalizedClause],
) -> None:
needed = set()
for clause in specs:
for val in clause.spec.values():
if isinstance(val, SourceColumnRef):
needed.add(val.column)
missing = sorted(needed - source_col_names)
if missing:
raise ValueError(
f"source is missing columns {missing} referenced by SET spec"
)