#!/usr/bin/env python3
"""
Microbenchmark comparing DataFusion and DuckDB performance
for SQL string functions on Parquet files.
"""
import tempfile
import time
import math
import os
from dataclasses import dataclass
from pathlib import Path
import pyarrow as pa
import pyarrow.parquet as pq
import datafusion
import duckdb
@dataclass
class BenchmarkResult:
"""Stores benchmark results for a single function."""
function_name: str
datafusion_time_ms: float
duckdb_time_ms: float
rows: int
@property
def speedup(self) -> float:
"""DuckDB time / DataFusion time (>1 means DataFusion is faster)."""
if self.datafusion_time_ms == 0:
return float('inf')
return self.duckdb_time_ms / self.datafusion_time_ms
@dataclass
class StringFunction:
"""Defines a string function with syntax for both engines."""
name: str
datafusion_expr: str # Expression using {col} as placeholder for column name
duckdb_expr: str # Expression using {col} as placeholder for column name
# String functions to benchmark
# {col} will be replaced with the actual column name
STRING_FUNCTIONS = [
StringFunction("trim", "trim({col})", "trim({col})"),
StringFunction("ltrim", "ltrim({col})", "ltrim({col})"),
StringFunction("rtrim", "rtrim({col})", "rtrim({col})"),
StringFunction("lower", "lower({col})", "lower({col})"),
StringFunction("upper", "upper({col})", "upper({col})"),
StringFunction("length", "length({col})", "length({col})"),
StringFunction("char_length", "char_length({col})", "length({col})"),
StringFunction("reverse", "reverse({col})", "reverse({col})"),
StringFunction("repeat_3", "repeat({col}, 3)", "repeat({col}, 3)"),
StringFunction("concat", "concat({col}, {col})", "concat({col}, {col})"),
StringFunction("concat_ws", "concat_ws('-', {col}, {col})", "concat_ws('-', {col}, {col})"),
StringFunction("substring_1_5", "substring({col}, 1, 5)", "substring({col}, 1, 5)"),
StringFunction("left_5", "left({col}, 5)", "left({col}, 5)"),
StringFunction("right_5", "right({col}, 5)", "right({col}, 5)"),
StringFunction("lpad_20", "lpad({col}, 20, '*')", "lpad({col}, 20, '*')"),
StringFunction("rpad_20", "rpad({col}, 20, '*')", "rpad({col}, 20, '*')"),
StringFunction("replace", "replace({col}, 'a', 'X')", "replace({col}, 'a', 'X')"),
StringFunction("translate", "translate({col}, 'aeiou', '12345')", "translate({col}, 'aeiou', '12345')"),
StringFunction("ascii", "ascii({col})", "ascii({col})"),
StringFunction("md5", "md5({col})", "md5({col})"),
StringFunction("sha256", "sha256({col})", "sha256({col})"),
StringFunction("btrim", "btrim({col}, ' ')", "trim({col}, ' ')"),
StringFunction("split_part", "split_part({col}, ' ', 1)", "split_part({col}, ' ', 1)"),
StringFunction("starts_with", "starts_with({col}, 'test')", "starts_with({col}, 'test')"),
StringFunction("ends_with", "ends_with({col}, 'data')", "ends_with({col}, 'data')"),
StringFunction("strpos", "strpos({col}, 'e')", "strpos({col}, 'e')"),
StringFunction("regexp_replace", "regexp_replace({col}, '[aeiou]', '*')", "regexp_replace({col}, '[aeiou]', '*', 'g')"),
]
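# A few entries deliberately use engine-specific spellings (char_length vs
# length, btrim vs two-argument trim), and regexp_replace passes the 'g' flag
# to both engines so each performs a global (all-occurrences) replacement.
# Example: the "lpad_20" entry expands, for both engines, to
#   SELECT lpad(str_col, 20, '*') FROM test_data
# once run_benchmarks() substitutes {col}.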
def generate_test_data(num_rows: int = 1_000_000, use_string_view: bool = False) -> pa.Table:
"""Generate test data with various string patterns."""
import random
import string
random.seed(42) # For reproducibility
# Generate diverse string data
strings = []
for i in range(num_rows):
# Mix of different string patterns
pattern_type = i % 5
if pattern_type == 0:
# Short strings with spaces
s = f" test_{i % 1000} "
elif pattern_type == 1:
# Longer strings
s = ''.join(random.choices(string.ascii_lowercase, k=20))
elif pattern_type == 2:
# Mixed case with numbers
s = f"TestData_{i}_Value"
elif pattern_type == 3:
# Strings with special patterns
s = f"hello world {i % 100} data"
else:
# Random length strings
length = random.randint(5, 50)
s = ''.join(random.choices(string.ascii_letters + string.digits + ' ', k=length))
strings.append(s)
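    # pa.string_view() (Arrow Utf8View) is only available in recent pyarrow
    # releases; pa.string() works everywhere.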
str_type = pa.string_view() if use_string_view else pa.string()
table = pa.table({
'str_col': pa.array(strings, type=str_type)
})
return table
def setup_datafusion(parquet_path: str) -> datafusion.SessionContext:
"""Create and configure DataFusion context with single thread/partition."""
config = datafusion.SessionConfig().with_target_partitions(1)
ctx = datafusion.SessionContext(config)
ctx.register_parquet('test_data', parquet_path)
return ctx
def setup_duckdb(parquet_path: str) -> duckdb.DuckDBPyConnection:
"""Create and configure DuckDB connection with single thread."""
conn = duckdb.connect(':memory:', config={'threads': 1})
conn.execute(f"CREATE VIEW test_data AS SELECT * FROM read_parquet('{parquet_path}')")
return conn
def benchmark_datafusion(ctx: datafusion.SessionContext, expr: str,
warmup: int = 2, iterations: int = 5) -> float:
"""Benchmark a query in DataFusion, return average time in ms."""
query = f"SELECT {expr} FROM test_data"
# Warmup runs
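    # (warmup passes let the OS file cache fill and each engine complete a full
    # plan/execute cycle before any timing starts)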
for _ in range(warmup):
ctx.sql(query).collect()
# Timed runs
times = []
for _ in range(iterations):
start = time.perf_counter()
ctx.sql(query).collect()
end = time.perf_counter()
times.append((end - start) * 1000) # Convert to ms
return sum(times) / len(times)
def benchmark_duckdb(conn: duckdb.DuckDBPyConnection, expr: str,
warmup: int = 2, iterations: int = 5) -> float:
"""Benchmark a query in DuckDB, return average time in ms."""
query = f"SELECT {expr} FROM test_data"
# Use fetch_arrow_table() for fair comparison with DataFusion's collect()
# Both return Arrow data without Python object conversion overhead
for _ in range(warmup):
conn.execute(query).fetch_arrow_table()
# Timed runs
times = []
for _ in range(iterations):
start = time.perf_counter()
conn.execute(query).fetch_arrow_table()
end = time.perf_counter()
times.append((end - start) * 1000) # Convert to ms
return sum(times) / len(times)
def run_benchmarks(num_rows: int = 1_000_000,
warmup: int = 2,
iterations: int = 5,
use_string_view: bool = False) -> list[BenchmarkResult]:
"""Run all benchmarks and return results."""
results = []
with tempfile.TemporaryDirectory() as tmpdir:
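        # The temporary directory, and the Parquet file inside it, are removed
        # automatically when this with-block exits.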
parquet_path = os.path.join(tmpdir, 'test_data.parquet')
# Generate and save test data
str_type = "StringView" if use_string_view else "String"
print(f"Generating {num_rows:,} rows of test data (type: {str_type})...")
table = generate_test_data(num_rows, use_string_view)
pq.write_table(table, parquet_path)
print(f"Parquet file written to: {parquet_path}")
print(f"File size: {os.path.getsize(parquet_path) / 1024 / 1024:.2f} MB")
# Setup engines
print("\nSetting up DataFusion...")
df_ctx = setup_datafusion(parquet_path)
print("Setting up DuckDB...")
duck_conn = setup_duckdb(parquet_path)
# Run benchmarks
print(f"\nRunning benchmarks ({warmup} warmup, {iterations} iterations each)...\n")
col = 'str_col'
for func in STRING_FUNCTIONS:
df_expr = func.datafusion_expr.format(col=col)
duck_expr = func.duckdb_expr.format(col=col)
print(f" Benchmarking: {func.name}...", end=" ", flush=True)
try:
df_time = benchmark_datafusion(df_ctx, df_expr, warmup, iterations)
except Exception as e:
print(f"DataFusion error: {e}")
df_time = float('nan')
try:
duck_time = benchmark_duckdb(duck_conn, duck_expr, warmup, iterations)
except Exception as e:
print(f"DuckDB error: {e}")
duck_time = float('nan')
result = BenchmarkResult(
function_name=func.name,
datafusion_time_ms=df_time,
duckdb_time_ms=duck_time,
rows=num_rows
)
results.append(result)
# Print progress
            if not math.isnan(df_time) and not math.isnan(duck_time):
faster = "DataFusion" if df_time < duck_time else "DuckDB"
ratio = max(df_time, duck_time) / min(df_time, duck_time)
print(f"done ({faster} {ratio:.2f}x faster)")
else:
print("done (with errors)")
duck_conn.close()
return results
def format_results_markdown(results: list[BenchmarkResult], use_string_view: bool = False) -> str:
"""Format benchmark results as a markdown table."""
str_type = "StringView" if use_string_view else "String"
lines = [
"# String Function Microbenchmarks: DataFusion vs DuckDB",
"",
f"**DataFusion version:** {datafusion.__version__} ",
f"**DuckDB version:** {duckdb.__version__} ",
f"**Rows:** {results[0].rows:,} ",
f"**String type:** {str_type} ",
"**Configuration:** Single thread, single partition",
"",
f"| Function | DataFusion {datafusion.__version__} (ms) | DuckDB {duckdb.__version__} (ms) | Speedup | Faster |",
"|----------|----------------:|------------:|--------:|--------|",
]
for r in results:
        if math.isnan(r.datafusion_time_ms) or math.isnan(r.duckdb_time_ms):
            # One or both engines failed on this function
lines.append(f"| {r.function_name} | ERROR | ERROR | N/A | N/A |")
else:
speedup = r.speedup
if speedup > 1:
faster = "DataFusion"
speedup_str = f"{speedup:.2f}x"
else:
faster = "DuckDB"
speedup_str = f"{1/speedup:.2f}x"
lines.append(
f"| {r.function_name} | {r.datafusion_time_ms:.2f} | "
f"{r.duckdb_time_ms:.2f} | {speedup_str} | {faster} |"
)
# Summary statistics
    # Exclude functions where either engine errored (time recorded as NaN)
    valid_results = [r for r in results
                     if not math.isnan(r.datafusion_time_ms)
                     and not math.isnan(r.duckdb_time_ms)]
if valid_results:
df_wins = sum(1 for r in valid_results if r.speedup > 1)
duck_wins = len(valid_results) - df_wins
df_total = sum(r.datafusion_time_ms for r in valid_results)
duck_total = sum(r.duckdb_time_ms for r in valid_results)
lines.extend([
"",
"## Summary",
"",
f"- **Functions tested:** {len(valid_results)}",
f"- **DataFusion faster:** {df_wins} functions",
f"- **DuckDB faster:** {duck_wins} functions",
f"- **Total DataFusion time:** {df_total:.2f} ms",
f"- **Total DuckDB time:** {duck_total:.2f} ms",
])
return "\n".join(lines)
def main():
import argparse
parser = argparse.ArgumentParser(
description="Benchmark string functions: DataFusion vs DuckDB"
)
parser.add_argument(
"--rows", type=int, default=1_000_000,
help="Number of rows in test data (default: 1,000,000)"
)
parser.add_argument(
"--warmup", type=int, default=2,
help="Number of warmup iterations (default: 2)"
)
parser.add_argument(
"--iterations", type=int, default=5,
help="Number of timed iterations (default: 5)"
)
parser.add_argument(
"--output", type=str, default=None,
help="Output file for markdown results (default: stdout)"
)
parser.add_argument(
"--string-view", action="store_true",
help="Use StringView type instead of String (default: False)"
)
args = parser.parse_args()
print("=" * 60)
print("String Function Microbenchmarks: DataFusion vs DuckDB")
print("=" * 60)
results = run_benchmarks(
num_rows=args.rows,
warmup=args.warmup,
iterations=args.iterations,
use_string_view=args.string_view
)
markdown = format_results_markdown(results, use_string_view=args.string_view)
print("\n" + "=" * 60)
print("RESULTS")
print("=" * 60 + "\n")
print(markdown)
if args.output:
with open(args.output, 'w') as f:
f.write(markdown)
print(f"\nResults saved to: {args.output}")
if __name__ == "__main__":
main()