// blob: 1d9b587f15b93403d87befeb5b7703719587260e [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::{any::Any, sync::Arc};
use arrow::{
array::{AsArray, RecordBatch, StringArray, UInt8Array},
datatypes::UInt64Type,
};
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::physical_expr::LexRequirement;
use datafusion::{
datasource::{
file_format::{
csv::CsvFormatFactory, file_compression_type::FileCompressionType,
FileFormat, FileFormatFactory,
},
physical_plan::{FileScanConfig, FileSinkConfig},
MemTable,
},
error::Result,
execution::context::SessionState,
physical_plan::ExecutionPlan,
prelude::SessionContext,
};
use datafusion_common::{GetExt, Statistics};
use datafusion_physical_expr::PhysicalExpr;
use object_store::{ObjectMeta, ObjectStore};
use tempfile::tempdir;
/// Example of a custom file format that reads and writes TSV files.
///
/// `TSVFileFactory` is responsible for creating instances of `TSVFileFormat`.
/// Once the factory is registered with the `SessionState`, it is used to
/// facilitate SQL operations on TSV files, such as the `COPY TO` statement
/// issued in `main`.
///
/// The format is a thin wrapper around the CSV file format and exists for
/// demonstration purposes.
#[derive(Debug)]
struct TSVFileFormat {
    /// Wrapped CSV format that schema inference, statistics and plan
    /// creation are delegated to.
    csv_file_format: Arc<dyn FileFormat>,
}
impl TSVFileFormat {
pub fn new(csv_file_format: Arc<dyn FileFormat>) -> Self {
Self { csv_file_format }
}
}
#[async_trait::async_trait]
impl FileFormat for TSVFileFormat {
fn as_any(&self) -> &dyn Any {
self
}
fn get_ext(&self) -> String {
"tsv".to_string()
}
fn get_ext_with_compression(
&self,
c: &FileCompressionType,
) -> datafusion::error::Result<String> {
if c == &FileCompressionType::UNCOMPRESSED {
Ok("tsv".to_string())
} else {
todo!("Compression not supported")
}
}
async fn infer_schema(
&self,
state: &SessionState,
store: &Arc<dyn ObjectStore>,
objects: &[ObjectMeta],
) -> Result<SchemaRef> {
self.csv_file_format
.infer_schema(state, store, objects)
.await
}
async fn infer_stats(
&self,
state: &SessionState,
store: &Arc<dyn ObjectStore>,
table_schema: SchemaRef,
object: &ObjectMeta,
) -> Result<Statistics> {
self.csv_file_format
.infer_stats(state, store, table_schema, object)
.await
}
async fn create_physical_plan(
&self,
state: &SessionState,
conf: FileScanConfig,
filters: Option<&Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn ExecutionPlan>> {
self.csv_file_format
.create_physical_plan(state, conf, filters)
.await
}
async fn create_writer_physical_plan(
&self,
input: Arc<dyn ExecutionPlan>,
state: &SessionState,
conf: FileSinkConfig,
order_requirements: Option<LexRequirement>,
) -> Result<Arc<dyn ExecutionPlan>> {
self.csv_file_format
.create_writer_physical_plan(input, state, conf, order_requirements)
.await
}
}
/// Factory that produces TSV file formats.
///
/// It wraps the CSV file format factory and exists for demonstration
/// purposes.
#[derive(Default, Debug)]
pub struct TSVFileFactory {
    /// CSV factory used to build the format that each TSV format wraps.
    csv_file_factory: CsvFormatFactory,
}
impl TSVFileFactory {
    /// Creates a TSV factory backed by a new `CsvFormatFactory`.
    pub fn new() -> Self {
        let csv_file_factory = CsvFormatFactory::new();
        Self { csv_file_factory }
    }
}
impl FileFormatFactory for TSVFileFactory {
fn create(
&self,
state: &SessionState,
format_options: &std::collections::HashMap<String, String>,
) -> Result<std::sync::Arc<dyn FileFormat>> {
let mut new_options = format_options.clone();
new_options.insert("format.delimiter".to_string(), "\t".to_string());
let csv_file_format = self.csv_file_factory.create(state, &new_options)?;
let tsv_file_format = Arc::new(TSVFileFormat::new(csv_file_format));
Ok(tsv_file_format)
}
fn default(&self) -> std::sync::Arc<dyn FileFormat> {
todo!()
}
fn as_any(&self) -> &dyn Any {
self
}
}
impl GetExt for TSVFileFactory {
    /// Returns the extension associated with this factory: `"tsv"`.
    fn get_ext(&self) -> String {
        String::from("tsv")
    }
}
#[tokio::main]
async fn main() -> Result<()> {
// Create a new context with the default configuration
let mut state = SessionStateBuilder::new().with_default_features().build();
// Register the custom file format
let file_format = Arc::new(TSVFileFactory::new());
state.register_file_format(file_format, true).unwrap();
// Create a new context with the custom file format
let ctx = SessionContext::new_with_state(state);
let mem_table = create_mem_table();
ctx.register_table("mem_table", mem_table).unwrap();
let temp_dir = tempdir().unwrap();
let table_save_path = temp_dir.path().join("mem_table.tsv");
let d = ctx
.sql(&format!(
"COPY mem_table TO '{}' STORED AS TSV;",
table_save_path.display(),
))
.await?;
let results = d.collect().await?;
println!(
"Number of inserted rows: {:?}",
(results[0]
.column_by_name("count")
.unwrap()
.as_primitive::<UInt64Type>()
.value(0))
);
Ok(())
}
/// Creates a simple in-memory table with an `id` (`UInt8`) column and a
/// `data` (`Utf8`) column, holding the rows `(1, "foo")` and `(2, "bar")`.
///
/// # Panics
///
/// Panics if the record batch or the table cannot be constructed.
fn create_mem_table() -> Arc<MemTable> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::UInt8, false),
        Field::new("data", DataType::Utf8, false),
    ]));

    let ids = UInt8Array::from(vec![1, 2]);
    let data = StringArray::from(vec!["foo", "bar"]);
    let batch =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(data)]).unwrap();

    // A single partition containing a single batch.
    let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
    Arc::new(table)
}