blob: 89a2dbc547990d7366644e22dc1927cd15bb307e [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL WHERE clause parser for Paimon CLI.
Parses simple SQL-like WHERE expressions into Predicate objects.
Supported operators:
=, !=, <>, <, <=, >, >=,
IS NULL, IS NOT NULL,
IN (...), NOT IN (...),
BETWEEN ... AND ...,
LIKE '...'
Supported connectors: AND, OR (AND has higher precedence than OR).
Parenthesized grouping is supported.
Examples:
"age > 18"
"name = 'Alice' AND age >= 20"
"status IN ('active', 'pending')"
"score BETWEEN 60 AND 100"
"name LIKE 'A%'"
"deleted_at IS NULL"
"age > 18 OR (name = 'Bob' AND status = 'active')"
"""
import re
from typing import Any, Dict, List, Optional, Tuple

from pypaimon.common.predicate import Predicate
from pypaimon.common.predicate_builder import PredicateBuilder
from pypaimon.schema.data_types import AtomicType, DataField
def extract_fields_from_where(where_string: str, available_fields: set) -> set:
    """Extract all field names referenced in a WHERE clause.

    Tokenizes the clause and keeps every token that names a known field;
    operators, keywords, and literals simply never match the schema set.

    Args:
        where_string: The WHERE clause string.
        available_fields: Set of valid field names from the table schema.

    Returns:
        A set of field names referenced in the WHERE clause (empty for a
        blank or missing clause).
    """
    if not where_string or not where_string.strip():
        return set()
    return {
        token
        for token in _tokenize(where_string.strip())
        if token in available_fields
    }
def parse_where_clause(where_string: str, fields: List[DataField]) -> Optional[Predicate]:
    """Parse a SQL-like WHERE clause string into a Predicate.

    Args:
        where_string: The WHERE clause string (without the 'WHERE' keyword).
        fields: The table schema fields for type resolution.

    Returns:
        A Predicate object, or None if the string is empty.

    Raises:
        ValueError: If the WHERE clause cannot be parsed.
    """
    stripped = where_string.strip()
    if not stripped:
        return None
    type_map = _build_field_type_map(fields)
    builder = PredicateBuilder(fields)
    # OR is the lowest-precedence connector, so parsing starts there.
    predicate, leftover = _parse_or_expression(_tokenize(stripped), builder, type_map)
    # Anything left over means the clause had trailing garbage.
    if leftover:
        raise ValueError(
            f"Unexpected tokens after parsing: {' '.join(leftover)}"
        )
    return predicate
def _build_field_type_map(fields: List[DataField]) -> Dict[str, Optional[str]]:
    """Build a mapping from field name to its upper-cased base type string.

    Only AtomicType fields are supported for WHERE filtering; fields of any
    other type (ARRAY, MAP, ROW, etc.) are mapped to None so the parser can
    reject them with a clear error.
    """
    return {
        field.name: field.type.type.upper() if isinstance(field.type, AtomicType) else None
        for field in fields
    }
def _cast_literal(value_str: str, type_name: str) -> Any:
"""Cast a literal string to the appropriate Python type based on the field type."""
integer_types = {'TINYINT', 'SMALLINT', 'INT', 'INTEGER', 'BIGINT'}
float_types = {'FLOAT', 'DOUBLE'}
base_type = type_name.split('(')[0].strip()
if base_type in integer_types:
return int(value_str)
if base_type in float_types:
return float(value_str)
if base_type.startswith('DECIMAL') or base_type in ('DECIMAL', 'NUMERIC', 'DEC'):
return float(value_str)
if base_type == 'BOOLEAN':
return value_str.lower() in ('true', '1', 'yes')
return value_str
_TOKEN_PATTERN = re.compile(
r"""
'(?:[^'\\]|\\.)*' # single-quoted string
| "(?:[^"\\]|\\.)*" # double-quoted string
| <= # <=
| >= # >=
| <> # <>
| != # !=
| [=<>] # single-char operators
| [(),] # punctuation
| [^\s,()=<>!'"]+ # unquoted word / number
""",
re.VERBOSE,
)
def _tokenize(expression: str) -> List[str]:
"""Tokenize a WHERE clause string."""
return _TOKEN_PATTERN.findall(expression)
def _parse_or_expression(
    tokens: List[str],
    builder: PredicateBuilder,
    type_map: Dict[str, Optional[str]],
) -> Tuple[Predicate, List[str]]:
    """Parse an OR expression (lowest precedence).

    Args:
        tokens: Remaining token stream.
        builder: Predicate factory bound to the table schema.
        type_map: Field name -> base type string (None for non-atomic types).

    Returns:
        A (predicate, remaining_tokens) pair; remaining_tokens starts at the
        first token that is not part of this OR expression.
    """
    left, tokens = _parse_and_expression(tokens, builder, type_map)
    or_operands = [left]
    # Fold further AND-expressions while 'OR' connectors remain.
    while tokens and tokens[0].upper() == 'OR':
        tokens = tokens[1:]  # consume 'OR'
        right, tokens = _parse_and_expression(tokens, builder, type_map)
        or_operands.append(right)
    # Avoid wrapping a single operand in a redundant OR node.
    if len(or_operands) == 1:
        return or_operands[0], tokens
    return PredicateBuilder.or_predicates(or_operands), tokens
def _parse_and_expression(
    tokens: List[str],
    builder: PredicateBuilder,
    type_map: Dict[str, Optional[str]],
) -> Tuple[Predicate, List[str]]:
    """Parse an AND expression (binds tighter than OR).

    Args:
        tokens: Remaining token stream.
        builder: Predicate factory bound to the table schema.
        type_map: Field name -> base type string (None for non-atomic types).

    Returns:
        A (predicate, remaining_tokens) pair.
    """
    left, tokens = _parse_primary(tokens, builder, type_map)
    and_operands = [left]
    while tokens and tokens[0].upper() == 'AND':
        # Distinguish 'AND' as connector vs. 'AND' in 'BETWEEN ... AND ...'
        # BETWEEN's AND is consumed inside _parse_primary, so here it's always a connector.
        tokens = tokens[1:]  # consume 'AND'
        right, tokens = _parse_primary(tokens, builder, type_map)
        and_operands.append(right)
    # Avoid wrapping a single operand in a redundant AND node.
    if len(and_operands) == 1:
        return and_operands[0], tokens
    return PredicateBuilder.and_predicates(and_operands), tokens
def _parse_primary(
    tokens: List[str],
    builder: PredicateBuilder,
    type_map: Dict[str, Optional[str]],
) -> Tuple[Predicate, List[str]]:
    """Parse a primary expression: a single condition or a parenthesized group.

    Args:
        tokens: Remaining token stream.
        builder: Predicate factory bound to the table schema.
        type_map: Field name -> base type string (None for non-atomic types).

    Returns:
        A (predicate, remaining_tokens) pair.

    Raises:
        ValueError: On an unknown field, a non-atomic field type, an
            unsupported operator, or a truncated/malformed condition.
    """
    if not tokens:
        raise ValueError("Unexpected end of WHERE clause")
    # Parenthesized group: recurse back to the lowest-precedence level.
    if tokens[0] == '(':
        predicate, tokens = _parse_or_expression(tokens[1:], builder, type_map)
        if not tokens or tokens[0] != ')':
            raise ValueError("Missing closing parenthesis ')'")
        return predicate, tokens[1:]  # consume ')'
    # Otherwise the condition must start with a known, atomic-typed field.
    field_name = tokens[0]
    tokens = tokens[1:]
    if not tokens:
        raise ValueError(f"Unexpected end after field name '{field_name}'")
    if field_name not in type_map:
        raise ValueError(
            f"Unknown field '{field_name}'. "
            f"Available fields: {sorted(type_map.keys())}"
        )
    field_type = type_map[field_name]
    if field_type is None:
        raise ValueError(
            f"Field '{field_name}' has a non-atomic type (e.g., ARRAY, MAP, ROW) "
            f"which is not supported in WHERE clauses. "
            f"Only atomic type fields (INT, STRING, DOUBLE, etc.) can be used for filtering."
        )
    operator_token = tokens[0].upper()
    # IS NULL / IS NOT NULL
    if operator_token == 'IS':
        return _parse_is_condition(field_name, tokens[1:], builder)
    # NOT IN (...) / NOT BETWEEN ... AND ...
    if operator_token == 'NOT':
        return _parse_negated_condition(field_name, tokens[1:], builder, field_type)
    # IN (...)
    if operator_token == 'IN':
        values, tokens = _parse_in_list(tokens[1:], field_type)
        return builder.is_in(field_name, values), tokens
    # BETWEEN ... AND ...
    if operator_token == 'BETWEEN':
        lower_value, upper_value, tokens = _parse_between_bounds(
            field_name, tokens[1:], field_type, negated=False
        )
        return builder.between(field_name, lower_value, upper_value), tokens
    # LIKE 'pattern'
    if operator_token == 'LIKE':
        pattern_str, tokens = _consume_literal(tokens[1:])
        return builder.like(field_name, pattern_str), tokens
    # Comparison operators: =, !=, <>, <, <=, >, >=
    if operator_token in {'=', '!=', '<>', '<', '<=', '>', '>='}:
        value_str, tokens = _consume_literal(tokens[1:])
        value = _cast_literal(value_str, field_type)
        return _build_comparison(builder, field_name, operator_token, value), tokens
    raise ValueError(
        f"Unsupported operator '{tokens[0]}' for field '{field_name}'. "
        f"Supported: =, !=, <>, <, <=, >, >=, IS NULL, IS NOT NULL, IN, NOT IN, BETWEEN, LIKE"
    )


def _parse_is_condition(
    field_name: str,
    tokens: List[str],
    builder: PredicateBuilder,
) -> Tuple[Predicate, List[str]]:
    """Parse the tail of 'IS NULL' / 'IS NOT NULL' ('IS' already consumed)."""
    if not tokens:
        raise ValueError(f"Unexpected end after 'IS' for field '{field_name}'")
    next_token = tokens[0].upper()
    if next_token == 'NULL':
        return builder.is_null(field_name), tokens[1:]
    if next_token == 'NOT':
        tokens = tokens[1:]  # consume 'NOT'
        if not tokens or tokens[0].upper() != 'NULL':
            raise ValueError(f"Expected 'NULL' after 'IS NOT' for field '{field_name}'")
        return builder.is_not_null(field_name), tokens[1:]  # consume 'NULL'
    raise ValueError(f"Expected 'NULL' or 'NOT NULL' after 'IS' for field '{field_name}'")


def _parse_negated_condition(
    field_name: str,
    tokens: List[str],
    builder: PredicateBuilder,
    field_type: str,
) -> Tuple[Predicate, List[str]]:
    """Parse the tail of 'NOT IN (...)' / 'NOT BETWEEN ... AND ...' ('NOT' already consumed)."""
    if not tokens:
        raise ValueError(f"Expected 'IN' or 'BETWEEN' after 'NOT' for field '{field_name}'")
    next_keyword = tokens[0].upper()
    if next_keyword == 'IN':
        values, tokens = _parse_in_list(tokens[1:], field_type)
        return builder.is_not_in(field_name, values), tokens
    if next_keyword == 'BETWEEN':
        lower_value, upper_value, tokens = _parse_between_bounds(
            field_name, tokens[1:], field_type, negated=True
        )
        return builder.not_between(field_name, lower_value, upper_value), tokens
    raise ValueError(f"Expected 'IN' or 'BETWEEN' after 'NOT' for field '{field_name}'")


def _parse_between_bounds(
    field_name: str,
    tokens: List[str],
    field_type: str,
    negated: bool = False,
) -> Tuple[Any, Any, List[str]]:
    """Parse '<lower> AND <upper>' after a (NOT) BETWEEN keyword.

    Returns a (lower_value, upper_value, remaining_tokens) triple; `negated`
    only selects the wording of the error message.
    """
    lower_str, tokens = _consume_literal(tokens)
    lower_value = _cast_literal(lower_str, field_type)
    label = 'NOT BETWEEN' if negated else 'BETWEEN'
    if not tokens or tokens[0].upper() != 'AND':
        raise ValueError(f"Expected 'AND' in {label} expression for field '{field_name}'")
    tokens = tokens[1:]  # consume 'AND'
    upper_str, tokens = _consume_literal(tokens)
    upper_value = _cast_literal(upper_str, field_type)
    return lower_value, upper_value, tokens
def _build_comparison(
    builder: PredicateBuilder,
    field_name: str,
    operator: str,
    value: Any,
) -> Predicate:
    """Build a comparison predicate by dispatching on the operator token."""
    dispatch = {
        '=': builder.equal,
        '!=': builder.not_equal,
        '<>': builder.not_equal,
        '<': builder.less_than,
        '<=': builder.less_or_equal,
        '>': builder.greater_than,
        '>=': builder.greater_or_equal,
    }
    if operator not in dispatch:
        raise ValueError(f"Unknown comparison operator: {operator}")
    return dispatch[operator](field_name, value)
def _parse_in_list(tokens: List[str], field_type: str) -> Tuple[List[Any], List[str]]:
    """Parse a parenthesized IN list: (val1, val2, ...).

    Enforces the grammar '(' value (',' value)* ')': an empty list, a
    trailing or doubled comma, or two values with no comma between them
    are all rejected (the previous implementation silently accepted them).

    Args:
        tokens: Token stream positioned just after the IN keyword.
        field_type: Base type string used to cast each literal.

    Returns:
        A (values, remaining_tokens) pair.

    Raises:
        ValueError: If the list is malformed or unterminated.
    """
    if not tokens or tokens[0] != '(':
        raise ValueError("Expected '(' after IN")
    tokens = tokens[1:]  # consume '('
    values: List[Any] = []
    while True:
        if not tokens:
            raise ValueError("Missing closing ')' in IN list")
        # ')' directly after '(' or after ',' means a value is missing.
        if tokens[0] == ')':
            raise ValueError("Expected a value in IN list")
        value_str, tokens = _consume_literal(tokens)
        values.append(_cast_literal(value_str, field_type))
        if not tokens:
            raise ValueError("Missing closing ')' in IN list")
        if tokens[0] == ')':
            return values, tokens[1:]  # consume ')'
        if tokens[0] != ',':
            raise ValueError(f"Expected ',' or ')' in IN list, got '{tokens[0]}'")
        tokens = tokens[1:]  # consume ',' and expect another value
def _consume_literal(tokens: List[str]) -> (str, List[str]):
"""Consume a single literal value from the token stream.
Handles quoted strings (strips quotes) and unquoted values.
"""
if not tokens:
raise ValueError("Expected a literal value but reached end of expression")
token = tokens[0]
tokens = tokens[1:]
# Strip surrounding quotes from string literals
if (token.startswith("'") and token.endswith("'")) or \
(token.startswith('"') and token.endswith('"')):
return token[1:-1], tokens
return token, tokens