blob: 5497406c5cd22dde6f51a332c87efc0f7f771e81 [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import re
from typing import Mapping, Set
import pyarrow as pa
# Matches an alias-qualified column reference of the form ``s.<name>`` or
# ``t.<name>``. Group 1 captures the alias letter (``s`` or ``t``) and group 2
# captures the column name (one or more word characters).
# NOTE(review): ``s``/``t`` presumably stand for the merge source/target
# tables -- confirm against the caller.
_COL_REF_PATTERN = re.compile(r'\b([st])\.(\w+)\b')
def _require_datafusion():
try:
import datafusion
return datafusion
except ImportError:
raise ImportError(
"merge_into condition expressions require the 'datafusion' "
"package. Install it with: pip install pypaimon[sql]"
)
_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")
def _strip_string_literals(condition: str) -> str:
return _STRING_LITERAL.sub('', condition)
def rewrite_condition(condition: str) -> str:
    """Double-quote every ``s.<col>`` / ``t.<col>`` reference in *condition*.

    Each reference found outside a single-quoted string literal is replaced
    by one quoted identifier (``s.id`` becomes ``"s.id"``). String literals
    are copied through unchanged, so text inside them that merely looks like
    a column reference is left alone.

    Args:
        condition: The condition expression to rewrite.

    Returns:
        The rewritten expression.
    """
    quoted_ref = r'"\1.\2"'
    pieces = []
    cursor = 0
    for literal in _STRING_LITERAL.finditer(condition):
        begin, end = literal.span()
        # Text between the previous literal and this one: quote column refs.
        pieces.append(_COL_REF_PATTERN.sub(quoted_ref, condition[cursor:begin]))
        # The literal itself is emitted verbatim.
        pieces.append(literal.group())
        cursor = end
    # Whatever follows the last literal (or the whole string if none).
    pieces.append(_COL_REF_PATTERN.sub(quoted_ref, condition[cursor:]))
    return ''.join(pieces)
def remap_source_on_keys(
    rewritten: str, on_map: Mapping[str, str],
) -> str:
    """Replace quoted ``"s.<key>"`` identifiers with their ``"t.<key>"`` match.

    For each ``source column -> target column`` entry of *on_map*, every
    occurrence of ``"s.<source column>"`` outside single-quoted string
    literals is replaced by ``"t.<target column>"``. Entries are applied one
    after another in the mapping's iteration order.

    Args:
        rewritten: An expression whose column references are already
            double-quoted (the form produced by ``rewrite_condition``).
        on_map: Mapping from source column name to target column name.

    Returns:
        The expression with the mapped source references replaced.
    """
    for source_col, target_col in on_map.items():
        needle = f'"s.{source_col}"'
        substitute = f'"t.{target_col}"'
        # Literal positions are recomputed on every pass because the previous
        # replacement may have changed the string.
        code_chunks = _STRING_LITERAL.split(rewritten)
        literals = _STRING_LITERAL.findall(rewritten)
        # split() always yields exactly one more chunk than there are
        # literals; interleave them back in their original order.
        rebuilt = [code_chunks[0].replace(needle, substitute)]
        for literal, chunk in zip(literals, code_chunks[1:]):
            rebuilt.append(literal)
            rebuilt.append(chunk.replace(needle, substitute))
        rewritten = ''.join(rebuilt)
    return rewritten
def filter_batch(
    batch: pa.Table, condition: str, _pre_rewritten: bool = False,
) -> pa.Table:
    """Return the rows of *batch* that satisfy *condition*.

    The table is registered as ``_batch`` in a fresh DataFusion session and
    filtered with a ``SELECT * ... WHERE`` query.

    Args:
        batch: The table to filter. An empty table is returned as-is without
            importing DataFusion.
        condition: The filter expression. It is passed through
            ``rewrite_condition`` first unless ``_pre_rewritten`` is true.
        _pre_rewritten: When true, *condition* is used verbatim.

    Returns:
        The query result as an Arrow table.

    Raises:
        ImportError: If the ``datafusion`` package is not available.
    """
    if not batch.num_rows:
        return batch

    df_module = _require_datafusion()
    if _pre_rewritten:
        predicate = condition
    else:
        predicate = rewrite_condition(condition)

    session = df_module.SessionContext()
    # The record batches of the table are registered as a single partition.
    partitions = [batch.to_batches()]
    session.register_record_batches("_batch", partitions)
    # NOTE(review): the predicate is interpolated directly into the SQL text,
    # so callers must only pass trusted condition expressions.
    query = f'SELECT * FROM _batch WHERE {predicate}'
    return session.sql(query).to_arrow_table()
def apply_condition(
    batch: pa.Table, rewritten: str, empty_schema: pa.Schema,
) -> pa.Table:
    """Filter *batch* by an already-rewritten condition.

    Args:
        batch: The table to filter.
        rewritten: The condition, passed to ``filter_batch`` with
            ``_pre_rewritten=True`` so it is not rewritten again.
        empty_schema: Schema used to build the result when no row matches.

    Returns:
        The matching rows, or ``empty_schema.empty_table()`` when the
        filtered table has no rows.
    """
    matched = filter_batch(batch, rewritten, _pre_rewritten=True)
    return matched if matched.num_rows != 0 else empty_schema.empty_table()
def extract_columns(condition: str) -> Set[str]:
    """Collect the ``s.<col>`` / ``t.<col>`` references used in *condition*.

    String literals are removed before scanning, so look-alike text inside
    them is not reported.

    Returns:
        A set of ``"<alias>.<column>"`` strings.
    """
    code_only = _strip_string_literals(condition)
    references = set()
    # findall() yields one (alias, column) tuple per match.
    for alias, column in _COL_REF_PATTERN.findall(code_only):
        references.add(f"{alias}.{column}")
    return references
def extract_target_columns(condition: str) -> Set[str]:
    """Collect the column names referenced through the ``t.`` alias.

    String literals are removed before scanning. Only the bare column name
    is returned (without the ``t.`` prefix); ``s.`` references are ignored.

    Returns:
        A set of column names.
    """
    code_only = _strip_string_literals(condition)
    target_columns = set()
    for reference in _COL_REF_PATTERN.finditer(code_only):
        alias, column = reference.groups()
        if alias != "t":
            continue
        target_columns.add(column)
    return target_columns