# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sqlglot
from superset.sql.dialects.pinot import Pinot
def test_pinot_dialect_registered() -> None:
    """
    Verify that the custom Pinot dialect is registered under the "pinot" key.
    """
    from superset.sql.parse import SQLGLOT_DIALECTS

    # A missing key makes .get() return None, which fails the comparison,
    # so this covers both membership and identity of the registered class.
    assert SQLGLOT_DIALECTS.get("pinot") == Pinot
def test_double_quotes_as_identifiers() -> None:
    """
    Double-quoted tokens must parse as identifiers, not string literals.
    """
    ast = sqlglot.parse_one('SELECT "column_name" FROM "table_name"', Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            '  "column_name"',
            'FROM "table_name"',
        ]
    )
    assert generated == expected
def test_single_quotes_for_strings() -> None:
    """
    Single-quoted tokens must be treated as string literals.
    """
    ast = sqlglot.parse_one("SELECT * FROM users WHERE name = 'John'", Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            "  *",
            "FROM users",
            "WHERE",
            "  name = 'John'",
        ]
    )
    assert generated == expected
def test_backticks_as_identifiers() -> None:
    """
    MySQL-style backtick identifiers parse correctly and are emitted with
    double quotes by the Pinot generator.
    """
    ast = sqlglot.parse_one("SELECT `column_name` FROM `table_name`", Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            '  "column_name"',
            'FROM "table_name"',
        ]
    )
    assert generated == expected
def test_mixed_identifier_quotes() -> None:
    """
    A query mixing double-quoted and backticked identifiers is normalized so
    every identifier is emitted with double quotes.
    """
    sql = (
        'SELECT "col1", `col2` FROM "table1" JOIN `table2` ON "table1".id = `table2`.id'
    )
    generated = Pinot().generate(expression=sqlglot.parse_one(sql, Pinot), pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            '  "col1",',
            '  "col2"',
            'FROM "table1"',
            'JOIN "table2"',
            '  ON "table1".id = "table2".id',
        ]
    )
    assert generated == expected
def test_string_with_escaped_quotes() -> None:
    """
    Doubled single quotes inside a string literal survive the round trip.
    """
    ast = sqlglot.parse_one("SELECT * FROM users WHERE name = 'O''Brien'", Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            "  *",
            "FROM users",
            "WHERE",
            "  name = 'O''Brien'",
        ]
    )
    assert generated == expected
def test_string_with_backslash_escape() -> None:
    """
    A string literal containing backslashes parses and regenerates cleanly.
    """
    sql = r"SELECT * FROM users WHERE path = 'C:\\Users\\John'"
    rendered = Pinot().generate(
        expression=sqlglot.parse_one(sql, Pinot),
        pretty=True,
    )
    # Only spot-check structure; the exact escaping policy is not pinned here.
    for fragment in ("WHERE", "path"):
        assert fragment in rendered
@pytest.mark.parametrize(
    "sql, expected",
    [
        (
            'SELECT COUNT(*) FROM "events" WHERE "type" = \'click\'',
            "\n".join(
                [
                    "SELECT",
                    "  COUNT(*)",
                    'FROM "events"',
                    "WHERE",
                    '  "type" = \'click\'',
                ]
            ),
        ),
        (
            'SELECT "user_id", SUM("amount") FROM "transactions" GROUP BY "user_id"',
            "\n".join(
                [
                    "SELECT",
                    '  "user_id",',
                    '  SUM("amount")',
                    'FROM "transactions"',
                    "GROUP BY",
                    '  "user_id"',
                ]
            ),
        ),
        (
            "SELECT * FROM \"orders\" WHERE \"status\" IN ('pending', 'shipped')",
            "\n".join(
                [
                    "SELECT",
                    "  *",
                    'FROM "orders"',
                    "WHERE",
                    '  "status" IN (\'pending\', \'shipped\')',
                ]
            ),
        ),
    ],
)
def test_various_queries(sql: str, expected: str) -> None:
    """
    Round-trip a handful of representative queries through the Pinot dialect.
    """
    generated = Pinot().generate(expression=sqlglot.parse_one(sql, Pinot), pretty=True)
    assert generated == expected
def test_aggregate_functions() -> None:
    """
    Aggregate calls over double-quoted identifiers are preserved.
    """
    sql = """
SELECT
  "category",
  COUNT(*),
  AVG("price"),
  MAX("quantity")
FROM "products"
GROUP BY "category"
"""
    generated = Pinot().generate(expression=sqlglot.parse_one(sql, Pinot), pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            '  "category",',
            "  COUNT(*),",
            '  AVG("price"),',
            '  MAX("quantity")',
            'FROM "products"',
            "GROUP BY",
            '  "category"',
        ]
    )
    assert generated == expected
def test_join_with_quoted_identifiers() -> None:
    """
    JOIN clauses keep their double-quoted table and column references.
    """
    sql = """
SELECT "u"."name", "o"."total"
FROM "users" AS "u"
JOIN "orders" AS "o" ON "u"."id" = "o"."user_id"
"""
    generated = Pinot().generate(expression=sqlglot.parse_one(sql, Pinot), pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            '  "u"."name",',
            '  "o"."total"',
            'FROM "users" AS "u"',
            'JOIN "orders" AS "o"',
            '  ON "u"."id" = "o"."user_id"',
        ]
    )
    assert generated == expected
def test_subquery_with_quoted_identifiers() -> None:
    """
    A derived table with a quoted alias round-trips correctly.
    """
    sql = 'SELECT * FROM (SELECT "id", "name" FROM "users") AS "subquery"'
    generated = Pinot().generate(expression=sqlglot.parse_one(sql, Pinot), pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            "  *",
            "FROM (",
            "  SELECT",
            '    "id",',
            '    "name"',
            '  FROM "users"',
            ') AS "subquery"',
        ]
    )
    assert generated == expected
def test_case_expression() -> None:
    """
    CASE expressions keep quoted identifiers and single-quoted literals.
    """
    sql = """
SELECT "name",
       CASE WHEN "age" < 18 THEN 'minor'
            WHEN "age" >= 18 THEN 'adult'
       END AS "category"
FROM "persons"
"""
    rendered = Pinot().generate(
        expression=sqlglot.parse_one(sql, Pinot),
        pretty=True,
    )
    # Identifiers stay double-quoted; string literals stay single-quoted.
    for fragment in ('"name"', '"age"', '"category"', "'minor'", "'adult'"):
        assert fragment in rendered
def test_cte_with_quoted_identifiers() -> None:
    """
    Common table expressions keep their quoted names and references.
    """
    sql = """
WITH "high_value_orders" AS (
    SELECT * FROM "orders" WHERE "total" > 1000
)
SELECT "customer_id", COUNT(*) FROM "high_value_orders" GROUP BY "customer_id"
"""
    rendered = Pinot().generate(
        expression=sqlglot.parse_one(sql, Pinot),
        pretty=True,
    )
    for fragment in (
        'WITH "high_value_orders" AS',
        '"orders"',
        '"total"',
        '"customer_id"',
    ):
        assert fragment in rendered
def test_order_by_with_quoted_identifiers() -> None:
    """
    ORDER BY keeps quoted identifiers; sqlglot spells out ASC explicitly.
    """
    sql = 'SELECT "name", "salary" FROM "employees" ORDER BY "salary" DESC, "name" ASC'
    generated = Pinot().generate(expression=sqlglot.parse_one(sql, Pinot), pretty=True)
    expected = "\n".join(
        [
            "SELECT",
            '  "name",',
            '  "salary"',
            'FROM "employees"',
            "ORDER BY",
            '  "salary" DESC,',
            '  "name" ASC',
        ]
    )
    assert generated == expected
def test_limit_and_offset() -> None:
    """
    LIMIT/OFFSET clauses survive the parse/generate round trip.
    """
    ast = sqlglot.parse_one('SELECT * FROM "products" LIMIT 10 OFFSET 20', Pinot)
    rendered = Pinot().generate(expression=ast, pretty=True)
    for fragment in ('"products"', "LIMIT 10"):
        assert fragment in rendered
def test_distinct() -> None:
    """
    SELECT DISTINCT is preserved alongside quoted identifiers.
    """
    ast = sqlglot.parse_one('SELECT DISTINCT "category" FROM "products"', Pinot)
    generated = Pinot().generate(expression=ast, pretty=True)
    expected = "\n".join(
        [
            "SELECT DISTINCT",
            '  "category"',
            'FROM "products"',
        ]
    )
    assert generated == expected
def test_cast_to_string() -> None:
    """
    CAST(... AS STRING) must keep the STRING type (no CHAR conversion).
    """
    ast = sqlglot.parse_one("SELECT CAST(cohort_size AS STRING) FROM table", Pinot)
    rendered = Pinot().generate(expression=ast)
    assert "STRING" in rendered
    # MySQL-style generators would rewrite this to CHAR; Pinot must not.
    assert "CHAR" not in rendered
def test_concat_with_cast_string() -> None:
    """
    CONCAT with CAST to STRING - verifies the original issue is fixed.

    The STRING type must be preserved (not converted to MySQL's CHAR) when a
    cast appears inside a function call.
    """
    sql = """
SELECT concat(a, cast(b AS string), ' - ')
FROM "default".c"""
    ast = sqlglot.parse_one(sql, Pinot)
    generated = Pinot().generate(expression=ast)
    # Verify STRING type is preserved (not converted to CHAR).
    # NOTE: the case-insensitive check subsumes the old
    # `"STRING" in generated or ...` form, whose first disjunct was redundant.
    assert "string" in generated.lower()
    assert "CHAR" not in generated
@pytest.mark.parametrize(
    "cast_type, expected_type",
    [
        ("INT", "INT"),
        ("TINYINT", "INT"),
        ("SMALLINT", "INT"),
        ("BIGINT", "LONG"),
        ("LONG", "LONG"),
        ("FLOAT", "FLOAT"),
        ("DOUBLE", "DOUBLE"),
        ("BOOLEAN", "BOOLEAN"),
        ("TIMESTAMP", "TIMESTAMP"),
        ("STRING", "STRING"),
        ("VARCHAR", "STRING"),
        ("CHAR", "STRING"),
        ("TEXT", "STRING"),
        ("BYTES", "BYTES"),
        ("BINARY", "BYTES"),
        ("VARBINARY", "BYTES"),
        ("JSON", "JSON"),
    ],
)
def test_type_mappings(cast_type: str, expected_type: str) -> None:
    """
    Each input type in a CAST maps onto the expected Pinot type.
    """
    query = f"SELECT CAST(col AS {cast_type}) FROM table"  # noqa: S608
    rendered = Pinot().generate(expression=sqlglot.parse_one(query, Pinot))
    assert expected_type in rendered
def test_unsigned_type() -> None:
    """
    Unsigned integer types go through the UNSIGNED_TYPE_MAPPING branch of
    the generator's datatype_sql method.
    """
    from sqlglot import exp

    # UBIGINT is one of the types covered by UNSIGNED_TYPE_MAPPING.
    unsigned = exp.DataType(this=exp.DataType.Type.UBIGINT)
    rendered = Pinot.Generator().datatype_sql(unsigned)
    for token in ("UNSIGNED", "BIGINT"):
        assert token in rendered
def test_date_trunc_preserved() -> None:
    """
    DATE_TRUNC must survive a round trip instead of collapsing to MySQL's
    DATE() function.
    """
    query = "SELECT DATE_TRUNC('day', dt_column) FROM table"
    roundtrip = sqlglot.parse_one(query, Pinot).sql(Pinot)
    assert "DATE_TRUNC" in roundtrip
    assert "date_trunc('day'" in roundtrip.lower()
    # Guard against the MySQL rewrite to DATE(dt_column).
    assert roundtrip != "SELECT DATE(dt_column) FROM table"
def test_cast_timestamp_preserved() -> None:
    """
    CAST(... AS TIMESTAMP) must stay a cast and not become a TIMESTAMP()
    function call (the MySQL rewrite).
    """
    query = "SELECT CAST(dt_column AS TIMESTAMP) FROM table"
    roundtrip = sqlglot.parse_one(query, Pinot).sql(Pinot)
    for fragment in ("CAST", "AS TIMESTAMP"):
        assert fragment in roundtrip
    # Guard against the MySQL-style TIMESTAMP(dt_column) rewrite.
    assert "TIMESTAMP(dt_column)" not in roundtrip
def test_date_trunc_with_cast_timestamp() -> None:
    """
    The original complex report query mixing DATE_TRUNC, nested CASTs, and
    DATETIMECONVERT round-trips with both constructs intact.
    """
    query = """
SELECT
  CAST(
    DATE_TRUNC(
      'day',
      CAST(
        DATETIMECONVERT(
          dt_epoch_ms, '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
        ) AS TIMESTAMP
      )
    ) AS TIMESTAMP
  ),
  SUM(a) + SUM(b)
FROM
  "default".c
WHERE
  dt_epoch_ms >= 1735690800000
  AND dt_epoch_ms < 1759328588000
  AND locality != 'US'
GROUP BY
  CAST(
    DATE_TRUNC(
      'day',
      CAST(
        DATETIMECONVERT(
          dt_epoch_ms, '1:MILLISECONDS:EPOCH',
          '1:MILLISECONDS:EPOCH', '1:MILLISECONDS'
        ) AS TIMESTAMP
      )
    ) AS TIMESTAMP
  )
LIMIT
  10000
"""
    roundtrip = sqlglot.parse_one(query, Pinot).sql(Pinot)
    # Both DATE_TRUNC and CAST must survive the round trip.
    assert "DATE_TRUNC" in roundtrip
    assert "CAST" in roundtrip
    # Neither may be rewritten into MySQL function forms.
    assert "TIMESTAMP(DATETIMECONVERT" not in roundtrip
    # DATE_TRUNC appears once in SELECT and once in GROUP BY.
    assert roundtrip.count("DATE_TRUNC") == 2
def test_pinot_date_add_parsing() -> None:
    """
    Pinot's Presto-style DATE_ADD('unit', n, ts) syntax parses through the
    SQLScript entry point.
    """
    from superset.sql.parse import SQLScript

    query = """
SELECT dt_epoch_ms FROM my_table WHERE dt_epoch_ms >= date_add('day', -180, now())
"""
    script = SQLScript(query, "pinot")
    assert len(script.statements) == 1
    assert not script.has_mutation()
def test_pinot_date_add_simple() -> None:
    """
    Simple DATE_ADD expressions parse and regenerate as DATE_ADD calls.
    """
    expressions = (
        "date_add('day', -180, now())",
        "DATE_ADD('month', 5, current_timestamp())",
        "date_add('year', 1, my_date_column)",
    )
    for expression in expressions:
        tree = sqlglot.parse_one(expression, Pinot)
        assert tree is not None
        # The regenerated SQL must still be a DATE_ADD call.
        assert "DATE_ADD" in tree.sql(dialect=Pinot).upper()
def test_pinot_date_add_unit_quoted() -> None:
    """
    The DATE_ADD unit argument must be emitted as a quoted string, since
    Pinot rejects a bare identifier there.
    """
    expression = "dt_epoch_ms >= date_add('day', -180, now())"
    roundtrip = sqlglot.parse_one(expression, Pinot).sql(Pinot)
    # Quoted 'DAY' is required; a bare DAY identifier is a regression.
    assert "DATE_ADD('DAY', -180, NOW())" in roundtrip
    assert "DATE_ADD(DAY," not in roundtrip
def test_pinot_date_sub_parsing() -> None:
    """
    Pinot's Presto-style DATE_SUB('unit', n, ts) syntax parses through the
    SQLScript entry point.
    """
    from superset.sql.parse import SQLScript

    script = SQLScript(
        "SELECT * FROM my_table WHERE dt >= date_sub('day', 7, now())",
        "pinot",
    )
    assert len(script.statements) == 1
    assert not script.has_mutation()
def test_pinot_date_sub_simple() -> None:
    """
    Simple DATE_SUB expressions parse and regenerate as DATE_SUB calls.
    """
    expressions = (
        "date_sub('day', 7, now())",
        "DATE_SUB('month', 3, current_timestamp())",
        "date_sub('hour', 24, my_date_column)",
    )
    for expression in expressions:
        tree = sqlglot.parse_one(expression, Pinot)
        assert tree is not None
        # The regenerated SQL must still be a DATE_SUB call.
        assert "DATE_SUB" in tree.sql(dialect=Pinot).upper()
def test_pinot_date_sub_unit_quoted() -> None:
    """
    The DATE_SUB unit argument must be emitted as a quoted string, since
    Pinot rejects a bare identifier there.
    """
    expression = "dt_epoch_ms >= date_sub('day', -180, now())"
    roundtrip = sqlglot.parse_one(expression, Pinot).sql(Pinot)
    # Quoted 'DAY' is required; a bare DAY identifier is a regression.
    assert "DATE_SUB('DAY', -180, NOW())" in roundtrip
    assert "DATE_SUB(DAY," not in roundtrip
def test_substr_cross_dialect_generation() -> None:
    """
    SUBSTR stays SUBSTR when generating Pinot SQL, while the MySQL dialect
    (which Pinot is based on) renders it as SUBSTRING.
    """
    parsed = sqlglot.parse_one("SELECT SUBSTR('hello', 0, 3) FROM users", Pinot)

    # Pinot output keeps the original SUBSTR spelling.
    pinot_output = parsed.sql(dialect=Pinot)
    assert "SUBSTR(" in pinot_output
    assert "SUBSTRING(" not in pinot_output

    # MySQL output converts the call to SUBSTRING.
    mysql_output = parsed.sql(dialect="mysql")
    assert "SUBSTRING(" in mysql_output

    # The two dialects therefore disagree on this query.
    assert pinot_output != mysql_output
@pytest.mark.parametrize(
    "function_name,sample_args",
    [
        # Math functions
        ("ABS", "-5"),
        ("CEIL", "3.14"),
        ("FLOOR", "3.14"),
        ("EXP", "2"),
        ("LN", "10"),
        ("SQRT", "16"),
        ("ROUNDDECIMAL", "3.14159, 2"),
        ("ADD", "1, 2, 3"),
        ("SUB", "10, 3"),
        ("MULT", "5, 4"),
        ("MOD", "10, 3"),
        # String functions
        ("UPPER", "'hello'"),
        ("LOWER", "'HELLO'"),
        ("REVERSE", "'hello'"),
        ("SUBSTR", "'hello', 0, 3"),
        ("CONCAT", "'hello', ' ', 'world'"),
        ("TRIM", "' hello '"),
        ("LTRIM", "' hello'"),
        ("RTRIM", "'hello '"),
        ("LENGTH", "'hello'"),
        ("STRPOS", "'hello', 'l', 1"),
        ("STARTSWITH", "'hello', 'he'"),
        ("REPLACE", "'hello', 'l', 'r'"),
        ("RPAD", "'hello', 10, 'x'"),
        ("LPAD", "'hello', 10, 'x'"),
        ("CODEPOINT", "'A'"),
        ("CHR", "65"),
        ("regexpExtract", "'foo123bar', '[0-9]+'"),
        ("regexpReplace", "'hello', 'l', 'r'"),
        ("remove", "'hello', 'l'"),
        ("urlEncoding", "'hello world'"),
        ("urlDecoding", "'hello%20world'"),
        ("fromBase64", "'aGVsbG8='"),
        ("toUtf8", "'hello'"),
        ("isSubnetOf", "'192.168.1.1', '192.168.0.0/16'"),
        # DateTime functions
        ("DATETRUNC", "'day', timestamp_col"),
        ("DATETIMECONVERT", "dt_col, '1:HOURS:EPOCH', '1:DAYS:EPOCH', '1:DAYS'"),
        ("TIMECONVERT", "timestamp_col, 'MILLISECONDS', 'SECONDS'"),
        ("NOW", ""),
        ("AGO", "'P1D'"),
        ("YEAR", "timestamp_col"),
        ("QUARTER", "timestamp_col"),
        ("MONTH", "timestamp_col"),
        ("WEEK", "timestamp_col"),
        ("DAY", "timestamp_col"),
        ("HOUR", "timestamp_col"),
        ("MINUTE", "timestamp_col"),
        ("SECOND", "timestamp_col"),
        ("MILLISECOND", "timestamp_col"),
        ("DAYOFWEEK", "timestamp_col"),
        ("DAYOFYEAR", "timestamp_col"),
        ("YEAROFWEEK", "timestamp_col"),
        ("toEpochSeconds", "timestamp_col"),
        ("toEpochMinutes", "timestamp_col"),
        ("toEpochHours", "timestamp_col"),
        ("toEpochDays", "timestamp_col"),
        ("fromEpochSeconds", "1234567890"),
        ("fromEpochMinutes", "20576131"),
        ("fromEpochHours", "342935"),
        ("fromEpochDays", "14288"),
        ("toDateTime", "timestamp_col, 'yyyy-MM-dd'"),
        ("fromDateTime", "'2024-01-01', 'yyyy-MM-dd'"),
        ("timezoneHour", "timestamp_col"),
        ("timezoneMinute", "timestamp_col"),
        ("DATE_ADD", "'day', 7, NOW()"),
        ("DATE_SUB", "'day', 7, NOW()"),
        ("TIMESTAMPADD", "'day', 7, timestamp_col"),
        ("TIMESTAMPDIFF", "'day', timestamp1, timestamp2"),
        ("dateTrunc", "'day', timestamp_col"),
        ("dateDiff", "'day', timestamp1, timestamp2"),
        ("dateAdd", "'day', 7, timestamp_col"),
        ("dateBin", "'day', timestamp_col, NOW()"),
        ("toIso8601", "timestamp_col"),
        ("fromIso8601", "'2024-01-01T00:00:00Z'"),
        # Aggregation functions
        ("COUNT", "*"),
        ("SUM", "amount"),
        ("AVG", "value"),
        ("MIN", "value"),
        ("MAX", "value"),
        ("DISTINCTCOUNT", "user_id"),
        ("DISTINCTCOUNTBITMAP", "user_id"),
        ("DISTINCTCOUNTHLL", "user_id"),
        ("DISTINCTCOUNTRAWHLL", "user_id"),
        ("DISTINCTCOUNTHLLPLUS", "user_id"),
        ("DISTINCTCOUNTRAWHLLPLUS", "user_id"),
        ("DISTINCTCOUNTSMARTHLL", "user_id"),
        ("DISTINCTCOUNTCPCSKETCH", "user_id"),
        ("DISTINCTCOUNTRAWCPCSKETCH", "user_id"),
        ("DISTINCTCOUNTTHETASKETCH", "user_id"),
        ("DISTINCTCOUNTRAWTHETASKETCH", "user_id"),
        ("DISTINCTCOUNTTUPLESKETCH", "user_id"),
        ("DISTINCTCOUNTRAWINTEGERSUMTUPLESKETCH", "user_id"),
        ("DISTINCTCOUNTULL", "user_id"),
        ("DISTINCTCOUNTRAWULL", "user_id"),
        ("SEGMENTPARTITIONEDDISTINCTCOUNT", "user_id"),
        ("SUMVALUESINTEGERSUMTUPLESKETCH", "value"),
        ("PERCENTILE", "value, 95"),
        ("PERCENTILEEST", "value, 95"),
        ("PERCENTILETDIGEST", "value, 95"),
        ("PERCENTILESMARTTDIGEST", "value, 95"),
        ("PERCENTILEKLL", "value, 95"),
        ("PERCENTILEKLLRAW", "value, 95"),
        ("HISTOGRAM", "value, 10"),
        ("MODE", "category"),
        ("MINMAXRANGE", "value"),
        ("SUMPRECISION", "value, 10"),
        ("ARG_MIN", "value, id"),
        ("ARG_MAX", "value, id"),
        ("COVAR_POP", "x, y"),
        ("COVAR_SAMP", "x, y"),
        ("LASTWITHTIME", "value, timestamp_col, 'LONG'"),
        ("FIRSTWITHTIME", "value, timestamp_col, 'LONG'"),
        ("ARRAY_AGG", "value"),
        # Multi-value functions
        ("COUNTMV", "tags"),
        ("MAXMV", "scores"),
        ("MINMV", "scores"),
        ("SUMMV", "scores"),
        ("AVGMV", "scores"),
        ("MINMAXRANGEMV", "scores"),
        ("PERCENTILEMV", "scores, 95"),
        ("PERCENTILEESTMV", "scores, 95"),
        ("PERCENTILETDIGESTMV", "scores, 95"),
        ("PERCENTILEKLLMV", "scores, 95"),
        ("DISTINCTCOUNTMV", "tags"),
        ("DISTINCTCOUNTBITMAPMV", "tags"),
        ("DISTINCTCOUNTHLLMV", "tags"),
        ("DISTINCTCOUNTRAWHLLMV", "tags"),
        ("DISTINCTCOUNTHLLPLUSMV", "tags"),
        ("DISTINCTCOUNTRAWHLLPLUSMV", "tags"),
        ("ARRAYLENGTH", "array_col"),
        ("MAP_VALUE", "map_col, 'key'"),
        ("VALUEIN", "value, 'val1', 'val2'"),
        # JSON functions
        ("JSONEXTRACTSCALAR", "json_col, '$.name', 'STRING'"),
        ("JSONEXTRACTKEY", "json_col, '$.data'"),
        ("JSONFORMAT", "json_col"),
        ("JSONPATH", "json_col, '$.name'"),
        ("JSONPATHLONG", "json_col, '$.id'"),
        ("JSONPATHDOUBLE", "json_col, '$.price'"),
        ("JSONPATHSTRING", "json_col, '$.name'"),
        ("JSONPATHARRAY", "json_col, '$.items'"),
        ("JSONPATHARRAYDEFAULTEMPTY", "json_col, '$.items'"),
        ("TOJSONMAPSTR", "map_col"),
        ("JSON_MATCH", "json_col, '\"$.name\"=''value'''"),
        ("JSON_EXTRACT_SCALAR", "json_col, '$.name', 'STRING'"),
        # Array functions
        ("arrayReverseInt", "int_array"),
        ("arrayReverseString", "string_array"),
        ("arraySortInt", "int_array"),
        ("arraySortString", "string_array"),
        ("arrayIndexOfInt", "int_array, 5"),
        ("arrayIndexOfString", "string_array, 'value'"),
        ("arrayContainsInt", "int_array, 5"),
        ("arrayContainsString", "string_array, 'value'"),
        ("arraySliceInt", "int_array, 0, 3"),
        ("arraySliceString", "string_array, 0, 3"),
        ("arrayDistinctInt", "int_array"),
        ("arrayDistinctString", "string_array"),
        ("arrayRemoveInt", "int_array, 5"),
        ("arrayRemoveString", "string_array, 'value'"),
        ("arrayUnionInt", "int_array1, int_array2"),
        ("arrayUnionString", "string_array1, string_array2"),
        ("arrayConcatInt", "int_array1, int_array2"),
        ("arrayConcatString", "string_array1, string_array2"),
        ("arrayElementAtInt", "int_array, 0"),
        ("arrayElementAtString", "string_array, 0"),
        ("arraySumInt", "int_array"),
        ("arrayValueConstructor", "1, 2, 3"),
        ("arrayToString", "array_col, ','"),
        # Geospatial functions
        ("ST_DISTANCE", "point1, point2"),
        ("ST_CONTAINS", "polygon, point"),
        ("ST_AREA", "polygon"),
        ("ST_GEOMFROMTEXT", "'POINT(1 2)'"),
        ("ST_GEOMFROMWKB", "wkb_col"),
        ("ST_GEOGFROMWKB", "wkb_col"),
        ("ST_GEOGFROMTEXT", "'POINT(1 2)'"),
        ("ST_POINT", "1.0, 2.0"),
        ("ST_POLYGON", "'POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'"),
        ("ST_ASBINARY", "geom_col"),
        ("ST_ASTEXT", "geom_col"),
        ("ST_GEOMETRYTYPE", "geom_col"),
        ("ST_EQUALS", "geom1, geom2"),
        ("ST_WITHIN", "geom1, geom2"),
        ("ST_UNION", "geom1, geom2"),
        ("ST_GEOMFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
        ("ST_GEOGFROMGEOJSON", '\'{"type":"Point","coordinates":[1,2]}\''),
        ("ST_ASGEOJSON", "geom_col"),
        ("toSphericalGeography", "geom_col"),
        ("toGeometry", "geog_col"),
        # Binary/Hash functions
        ("SHA", "'hello'"),
        ("SHA256", "'hello'"),
        ("SHA512", "'hello'"),
        ("SHA224", "'hello'"),
        ("MD5", "'hello'"),
        ("MD2", "'hello'"),
        ("toBase64", "'hello'"),
        ("fromUtf8", "bytes_col"),
        ("MurmurHash2", "'hello'"),
        ("MurmurHash3Bit32", "'hello'"),
        # Window functions
        ("ROW_NUMBER", ""),
        ("RANK", ""),
        ("DENSE_RANK", ""),
        # Funnel analysis
        ("FunnelMaxStep", "event_col, 'step1', 'step2', 'step3'"),
        ("FunnelMatchStep", "event_col, 'step1', 'step2', 'step3'"),
        ("FunnelCompleteCount", "event_col, 'step1', 'step2', 'step3'"),
        # Text search
        ("TEXT_MATCH", "text_col, 'search query'"),
        # Vector functions
        ("VECTOR_SIMILARITY", "vector1, vector2"),
        ("l2_distance", "vector1, vector2"),
        # Lookup
        ("LOOKUP", "'lookupTable', 'lookupColumn', 'keyColumn', keyValue"),
        # URL functions
        ("urlProtocol", "'https://example.com/path'"),
        ("urlDomain", "'https://example.com/path'"),
        ("urlPath", "'https://example.com/path'"),
        ("urlPort", "'https://example.com:8080/path'"),
        ("urlEncode", "'hello world'"),
        ("urlDecode", "'hello%20world'"),
        # Conditional
        ("COALESCE", "val1, val2, 'default'"),
        ("NULLIF", "val1, val2"),
        ("GREATEST", "1, 2, 3"),
        ("LEAST", "1, 2, 3"),
        # Other
        ("REGEXP_LIKE", "'hello', 'h.*'"),
        ("GROOVY", "'{return arg0 + arg1}', col1, col2"),
    ],
)
def test_pinot_function_names_preserved(function_name: str, sample_args: str) -> None:
    """
    Parse/generate round trips must leave Pinot function names untouched.

    If a name were silently rewritten (e.g. to a MySQL equivalent), the query
    would no longer match Pinot's function library, so each sampled call is
    checked case-insensitively in the regenerated SQL.
    """
    # Window functions need an OVER clause to be syntactically valid.
    if function_name in ["ROW_NUMBER", "RANK", "DENSE_RANK"]:
        sql = f"SELECT {function_name}() OVER (ORDER BY col) FROM table"  # noqa: S608
    else:
        sql = f"SELECT {function_name}({sample_args}) FROM table"  # noqa: S608

    regenerated = sqlglot.parse_one(sql, Pinot).sql(dialect=Pinot)
    assert function_name.upper() in regenerated.upper(), (
        f"Function {function_name} not preserved in output: {regenerated}"
    )