// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! See `main.rs` for how to run it.

use std::any::Any;
use std::sync::Arc;
use arrow::array::{RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef};
use datafusion::assert_batches_eq;
use datafusion::common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion::common::{assert_contains, exec_datafusion_err, Result};
use datafusion::datasource::listing::{
ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl,
};
use datafusion::execution::context::SessionContext;
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion::parquet::arrow::ArrowWriter;
use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr::{expressions, ScalarFunctionExpr};
use datafusion::prelude::SessionConfig;
use datafusion::scalar::ScalarValue;
use datafusion_physical_expr_adapter::{
DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory,
};
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{ObjectStore, PutPayload};

// Example showing how to implement custom filter rewriting for JSON shredding.
//
// JSON shredding is a technique for optimizing queries on semi-structured data
// by materializing commonly accessed fields into separate columns for better
// columnar storage performance.
//
// In this example, each file stores both:
// - the original JSON column, e.g. data = '{"age": 30}'
// - a shredded flat column, e.g. _data.name = 'Alice', holding the `name`
//   field that was extracted out of the JSON
//
// The ListingTable is configured with a custom PhysicalExprAdapter that
// rewrites expressions like `json_get_str('name', data)` to use the
// pre-computed flat column `_data.name` when it exists. This allows the
// query engine to:
// 1. Push down predicates for better filtering
// 2. Avoid expensive JSON parsing at query time
// 3. Leverage columnar storage benefits for the materialized fields
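//
// Concretely, for the query below the filter `json_get_str('name', data) =
// 'Bob'` is rewritten inside the Parquet scan to compare the physical
// `_data.name` column directly, which lets row groups be pruned from their
// statistics without parsing any JSON for the filter.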
pub async fn json_shredding() -> Result<()> {
println!("=== Creating example data with flat columns and underscore prefixes ===");
// Create sample data with flat columns using underscore prefixes
let (table_schema, batch) = create_sample_data();
let store = InMemory::new();
let buf = {
let mut buf = vec![];
let props = WriterProperties::builder()
.set_max_row_group_size(2)
.build();
let mut writer = ArrowWriter::try_new(&mut buf, batch.schema(), Some(props))
.expect("creating writer");
writer.write(&batch).expect("Writing batch");
writer.close().unwrap();
buf
};
let path = Path::from("example.parquet");
let payload = PutPayload::from_bytes(buf.into());
store.put(&path, payload).await?;
// Set up query execution
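    // Enable Parquet filter pushdown so the (rewritten) predicate is
    // evaluated during the scan itself; the `EXPLAIN ANALYZE` below confirms
    // this via the `pushdown_rows_pruned` metric.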
let mut cfg = SessionConfig::new();
cfg.options_mut().execution.parquet.pushdown_filters = true;
let ctx = SessionContext::new_with_config(cfg);
ctx.runtime_env().register_object_store(
ObjectStoreUrl::parse("memory://")?.as_ref(),
Arc::new(store),
);
    // Create a table provider backed by a ListingTable whose expression
    // adapter factory rewrites `json_get_str` calls to the shredded columns
let listing_table_config =
ListingTableConfig::new(ListingTableUrl::parse("memory:///example.parquet")?)
.infer_options(&ctx.state())
.await?
.with_schema(table_schema)
.with_expr_adapter_factory(Arc::new(ShreddedJsonRewriterFactory));
let table = ListingTable::try_new(listing_table_config).unwrap();
let table_provider = Arc::new(table);
// Register our table
ctx.register_table("structs", table_provider)?;
ctx.register_udf(ScalarUDF::new_from_impl(JsonGetStr::default()));
println!("\n=== Showing all data ===");
let batches = ctx.sql("SELECT * FROM structs").await?.collect().await?;
arrow::util::pretty::print_batches(&batches)?;
println!("\n=== Running query with flat column access and filter ===");
let query = "SELECT json_get_str('age', data) as age FROM structs WHERE json_get_str('name', data) = 'Bob'";
println!("Query: {query}");
let batches = ctx.sql(query).await?.collect().await?;
#[rustfmt::skip]
let expected = [
"+-----+",
"| age |",
"+-----+",
"| 25 |",
"+-----+",
];
arrow::util::pretty::print_batches(&batches)?;
assert_batches_eq!(expected, &batches);
println!("\n=== Running explain analyze to confirm row group pruning ===");
let batches = ctx
.sql(&format!("EXPLAIN ANALYZE {query}"))
.await?
.collect()
.await?;
let plan = format!("{}", arrow::util::pretty::pretty_format_batches(&batches)?);
println!("{plan}");
assert_contains!(&plan, "row_groups_pruned_statistics=2 total → 1 matched");
assert_contains!(&plan, "pushdown_rows_pruned=1");
Ok(())
}

/// Create the example data with flat columns using underscore prefixes.
///
/// This demonstrates the logical data structure:
/// - Table schema: What users see (just the 'data' JSON column)
/// - File schema: What's physically stored (both 'data' and materialized '_data.name')
///
/// The naming convention uses underscore prefixes to indicate shredded columns:
/// - `data` -> original JSON column
/// - `_data.name` -> materialized field from JSON data.name
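///
/// For example, the first row is stored physically as
/// `data = '{"age": 30}'` and `_data.name = 'Alice'`.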
fn create_sample_data() -> (SchemaRef, RecordBatch) {
// The table schema only has the main data column - this is what users query against
let table_schema = Schema::new(vec![Field::new("data", DataType::Utf8, false)]);
// The file schema has both the main column and the shredded flat column with underscore prefix
// This represents the actual physical storage with pre-computed columns
let file_schema = Schema::new(vec![
Field::new("data", DataType::Utf8, false), // Original JSON data
Field::new("_data.name", DataType::Utf8, false), // Materialized name field
]);
let batch = create_sample_record_batch(&file_schema);
(Arc::new(table_schema), batch)
}

/// Create the actual RecordBatch with sample data
fn create_sample_record_batch(file_schema: &Schema) -> RecordBatch {
// Build a RecordBatch with flat columns
let data_array = StringArray::from(vec![
r#"{"age": 30}"#,
r#"{"age": 25}"#,
r#"{"age": 35}"#,
r#"{"age": 22}"#,
]);
let names_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Dave"]);
RecordBatch::try_new(
Arc::new(file_schema.clone()),
vec![Arc::new(data_array), Arc::new(names_array)],
)
.unwrap()
}

/// Scalar UDF that uses serde_json to extract JSON fields as strings.
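///
/// For example, `json_get_str('age', data)` returns the value of the `age`
/// field in the `data` column, rendered as a string.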
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct JsonGetStr {
signature: Signature,
}

impl Default for JsonGetStr {
fn default() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for JsonGetStr {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"json_get_str"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Utf8)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
assert!(
args.args.len() == 2,
"json_get_str requires exactly 2 arguments"
);
let key = match &args.args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(key))) => key,
_ => {
return Err(exec_datafusion_err!(
"json_get_str first argument must be a string"
))
}
};
// We expect a string array that contains JSON strings
let json_array = match &args.args[1] {
ColumnarValue::Array(array) => array
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| {
exec_datafusion_err!(
"json_get_str second argument must be a string array"
)
})?,
_ => {
return Err(exec_datafusion_err!(
"json_get_str second argument must be a string array"
))
}
};
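        // Stringify the requested field from each JSON document via
        // `serde_json::Value::to_string`. Note that string values keep their
        // JSON quotes this way; in this example only the numeric `age` field
        // is extracted through the UDF (the `name` filter is served by the
        // shredded column), so the output stays unquoted.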
let values = json_array
.iter()
.map(|value| {
value.and_then(|v| {
let json_value: serde_json::Value =
serde_json::from_str(v).unwrap_or_default();
json_value.get(key).map(|v| v.to_string())
})
})
.collect::<StringArray>();
Ok(ColumnarValue::Array(Arc::new(values)))
}
}

/// Factory for creating ShreddedJsonRewriter instances
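///
/// `create` receives the table's logical file schema and the file's physical
/// schema, so the resulting adapter can check which shredded columns are
/// actually present in the file being scanned.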
#[derive(Debug)]
struct ShreddedJsonRewriterFactory;

impl PhysicalExprAdapterFactory for ShreddedJsonRewriterFactory {
fn create(
&self,
logical_file_schema: SchemaRef,
physical_file_schema: SchemaRef,
) -> Arc<dyn PhysicalExprAdapter> {
let default_factory = DefaultPhysicalExprAdapterFactory;
let default_adapter = default_factory
.create(logical_file_schema.clone(), physical_file_schema.clone());
Arc::new(ShreddedJsonRewriter {
logical_file_schema,
physical_file_schema,
default_adapter,
partition_values: Vec::new(),
})
}
}

/// Rewriter that converts json_get_str calls to direct flat column references
/// and wraps DefaultPhysicalExprAdapter for standard schema adaptation
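/// Any partition values are forwarded to the wrapped default adapter.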
#[derive(Debug)]
struct ShreddedJsonRewriter {
logical_file_schema: SchemaRef,
physical_file_schema: SchemaRef,
default_adapter: Arc<dyn PhysicalExprAdapter>,
partition_values: Vec<(FieldRef, ScalarValue)>,
}

impl PhysicalExprAdapter for ShreddedJsonRewriter {
fn rewrite(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
// First try our custom JSON shredding rewrite
let rewritten = expr
.transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema))
.data()?;
// Then apply the default adapter as a fallback to handle standard schema differences
// like type casting, missing columns, and partition column handling
let default_adapter = if !self.partition_values.is_empty() {
self.default_adapter
.with_partition_values(self.partition_values.clone())
} else {
self.default_adapter.clone()
};
default_adapter.rewrite(rewritten)
}
fn with_partition_values(
&self,
partition_values: Vec<(FieldRef, ScalarValue)>,
) -> Arc<dyn PhysicalExprAdapter> {
Arc::new(ShreddedJsonRewriter {
logical_file_schema: self.logical_file_schema.clone(),
physical_file_schema: self.physical_file_schema.clone(),
default_adapter: self.default_adapter.clone(),
partition_values,
})
}
}

impl ShreddedJsonRewriter {
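    /// Rewrite `json_get_str('<field>', <column>)` into a direct reference to
    /// the physical column `_<column>.<field>` when that column exists in the
    /// file schema and has type `Utf8`; otherwise the expression is returned
    /// unchanged.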
fn rewrite_impl(
&self,
expr: Arc<dyn PhysicalExpr>,
physical_file_schema: &Schema,
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
if func.name() == "json_get_str" && func.args().len() == 2 {
// Get the key from the first argument
if let Some(literal) = func.args()[0]
.as_any()
.downcast_ref::<expressions::Literal>()
{
if let ScalarValue::Utf8(Some(field_name)) = literal.value() {
// Get the column from the second argument
if let Some(column) = func.args()[1]
.as_any()
.downcast_ref::<expressions::Column>()
{
let column_name = column.name();
// Check if there's a flat column with underscore prefix
let flat_column_name = format!("_{column_name}.{field_name}");
if let Ok(flat_field_index) =
physical_file_schema.index_of(&flat_column_name)
{
let flat_field =
physical_file_schema.field(flat_field_index);
if flat_field.data_type() == &DataType::Utf8 {
// Replace the whole expression with a direct column reference
let new_expr = Arc::new(expressions::Column::new(
&flat_column_name,
flat_field_index,
))
as Arc<dyn PhysicalExpr>;
return Ok(Transformed {
data: new_expr,
tnr: TreeNodeRecursion::Stop,
transformed: true,
});
}
}
}
}
}
}
}
Ok(Transformed::no(expr))
}
}