// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
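//! End-to-end SQL integration tests for DataFusion, running queries against
//! CSV and Parquet test data through the ExecutionContext API.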
use std::cell::RefCell;
use std::env;
use std::rc::Rc;
use std::sync::Arc;
extern crate arrow;
extern crate datafusion;
use arrow::array::*;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::execution::relation::Relation;
use datafusion::logicalplan::LogicalPlan;
const DEFAULT_BATCH_SIZE: usize = 1024 * 1024;
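/// Builds and optimizes a logical plan for a GROUP BY aggregate over the NYC taxi
/// schema, and checks that the optimizer pushes the projection into the table scan
/// so that only `passenger_count` and `fare_amount` remain in the projected schema.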
#[test]
fn nyc() -> Result<()> {
// schema for NYC taxi csv files
let schema = Schema::new(vec![
Field::new("VendorID", DataType::Utf8, true),
Field::new("tpep_pickup_datetime", DataType::Utf8, true),
Field::new("tpep_dropoff_datetime", DataType::Utf8, true),
Field::new("passenger_count", DataType::Utf8, true),
Field::new("trip_distance", DataType::Float64, true),
Field::new("RatecodeID", DataType::Utf8, true),
Field::new("store_and_fwd_flag", DataType::Utf8, true),
Field::new("PULocationID", DataType::Utf8, true),
Field::new("DOLocationID", DataType::Utf8, true),
Field::new("payment_type", DataType::Utf8, true),
Field::new("fare_amount", DataType::Float64, true),
Field::new("extra", DataType::Float64, true),
Field::new("mta_tax", DataType::Float64, true),
Field::new("tip_amount", DataType::Float64, true),
Field::new("tolls_amount", DataType::Float64, true),
Field::new("improvement_surcharge", DataType::Float64, true),
Field::new("total_amount", DataType::Float64, true),
]);
let mut ctx = ExecutionContext::new();
ctx.register_csv("tripdata", "file.csv", &schema, true);
let logical_plan = ctx.create_logical_plan(
"SELECT passenger_count, MIN(fare_amount), MAX(fare_amount) \
FROM tripdata GROUP BY passenger_count",
)?;
let optimized_plan = ctx.optimize(&logical_plan)?;
match optimized_plan.as_ref() {
LogicalPlan::Aggregate { input, .. } => match input.as_ref() {
LogicalPlan::TableScan {
ref projected_schema,
..
} => {
assert_eq!(2, projected_schema.fields().len());
assert_eq!(projected_schema.field(0).name(), "passenger_count");
assert_eq!(projected_schema.field(1).name(), "fare_amount");
}
_ => assert!(false),
},
_ => assert!(false),
}
Ok(())
}
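/// Runs a simple projection query against the alltypes_plain.parquet test file
/// and compares the tab-delimited output with the expected rows.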
#[test]
fn parquet_query() {
let mut ctx = ExecutionContext::new();
register_alltypes_parquet(&mut ctx);
let sql = "SELECT id, string_col FROM alltypes_plain";
let actual = execute(&mut ctx, sql);
let expected = "4\t\"0\"\n5\t\"1\"\n6\t\"0\"\n7\t\"1\"\n2\t\"0\"\n3\t\"1\"\n0\t\"0\"\n1\t\"1\"\n".to_string();
assert_eq!(expected, actual);
}
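/// Reads the single_nan.parquet test file and checks that each resulting batch
/// has exactly one row and one column.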
#[test]
fn parquet_single_nan_schema() {
let mut ctx = ExecutionContext::new();
let testdata = env::var("PARQUET_TEST_DATA").expect("PARQUET_TEST_DATA not defined");
ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata))
.unwrap();
let sql = "SELECT mycol FROM single_nan";
let relation = ctx.sql(&sql, 1024 * 1024).unwrap();
let mut results = relation.borrow_mut();
while let Some(batch) = results.next().unwrap() {
assert_eq!(1, batch.num_rows());
assert_eq!(1, batch.num_columns());
}
}
#[test]
fn csv_count_star() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT COUNT(*), COUNT(1), COUNT(c1) FROM aggregate_test_100";
let actual = execute(&mut ctx, sql);
let expected = "100\t100\t100\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_with_predicate() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4";
let actual = execute(&mut ctx, sql);
let expected = "\"e\"\t0.39144436569161134\n\"d\"\t0.38870280983958583\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_group_by_int_min_max() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2";
let actual = execute(&mut ctx, sql);
let expected = "4\t0.02182578039211991\t0.9237877978193884\n5\t0.01479305307777301\t0.9723580396501548\n2\t0.16301110515739792\t0.991517828651004\n3\t0.047343434291126085\t0.9293883502480845\n1\t0.05636955101974106\t0.9965400387585364\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_avg() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT avg(c12) FROM aggregate_test_100";
let actual = execute(&mut ctx, sql);
let expected = "0.5089725099127211\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_group_by_avg() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1";
let actual = execute(&mut ctx, sql);
let expected = "\"a\"\t0.48754517466109415\n\"e\"\t0.48600669271341534\n\"d\"\t0.48855379387549824\n\"c\"\t0.6600456536439784\n\"b\"\t0.41040709263815384\n".to_string();
assert_eq!(expected, actual);
}
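/// Computes AVG(c12) with a batch size of 4 so the aggregation must combine
/// partial results across multiple record batches; the assertion uses a
/// tolerance because floating point summation order affects the result.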
#[test]
fn csv_query_avg_multi_batch() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT avg(c12) FROM aggregate_test_100";
let plan = ctx.create_logical_plan(&sql).unwrap();
let plan = ctx.optimize(&plan).unwrap();
let results = ctx.execute(&plan, 4).unwrap();
let mut relation = results.borrow_mut();
let batch = relation.next().unwrap().unwrap();
let column = batch.column(0);
let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
let actual = array.value(0);
let expected = 0.5089725;
// Due to floating point accuracy, different batch sizes will lead to slightly
// different answers.
assert!((expected - actual).abs() < 0.01);
}
//#[test]
//fn csv_query_group_by_avg_multi_batch() {
// let mut ctx = ExecutionContext::new();
// register_aggregate_csv(&mut ctx);
// //TODO add ORDER BY once supported, to make this test deterministic
// let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1";
// let plan = ctx.create_logical_plan(&sql).unwrap();
// let results = ctx.execute(&plan, 4).unwrap();
// let mut relation = results.borrow_mut();
// let mut actual_vec = Vec::new();
// while let Some(batch) = relation.next().unwrap() {
// let column = batch.column(1);
// let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
//
// for row_index in 0..batch.num_rows() {
// actual_vec.push(array.value(row_index));
// }
// }
//
// let expect_vec = vec![0.41040709, 0.48754517, 0.48855379, 0.66004565];
//
// actual_vec.sort_by(|a, b| a.partial_cmp(b).unwrap());
//
// println!("actual : {:?}", actual_vec);
// println!("expected: {:?}", expect_vec);
//
// actual_vec
// .iter()
// .zip(expect_vec.iter())
// .for_each(|(actual, expect)| {
// // Due to floating point accuracy, different batch sizes will lead to
// // slightly different answers.
// assert!((expect - actual).abs() < 0.01);
// });
//}
#[test]
fn csv_query_count() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT count(c12) FROM aggregate_test_100";
let actual = execute(&mut ctx, sql);
let expected = "100\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_group_by_int_count() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT count(c12) FROM aggregate_test_100 GROUP BY c1";
let actual = execute(&mut ctx, sql);
let expected = "\"a\"\t21\n\"e\"\t21\n\"d\"\t18\n\"c\"\t21\n\"b\"\t19\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_group_by_string_min_max() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
//TODO add ORDER BY once supported, to make this test deterministic
let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1";
let actual = execute(&mut ctx, sql);
let expected =
"\"a\"\t0.02182578039211991\t0.9800193410444061\n\"e\"\t0.01479305307777301\t0.9965400387585364\n\"d\"\t0.061029375346466685\t0.9748360509016578\n\"c\"\t0.0494924465469434\t0.991517828651004\n\"b\"\t0.04893135681998029\t0.9185813970744787\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_cast() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4";
let actual = execute(&mut ctx, sql);
let expected = "0.39144436569161134\n0.38870280983958583\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_cast_literal() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 2";
let actual = execute(&mut ctx, sql);
let expected = "0.9294097332465232\t1.0\n0.3114712539863804\t1.0\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_limit() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT 0 FROM aggregate_test_100 LIMIT 2";
let actual = execute(&mut ctx, sql);
let expected = "0\n0\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_limit_bigger_than_nbr_of_rows() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200";
let actual = execute(&mut ctx, sql);
let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_limit_with_same_nbr_of_rows() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100";
let actual = execute(&mut ctx, sql);
let expected = "2\n5\n1\n1\n5\n4\n3\n3\n1\n4\n1\n4\n3\n2\n1\n1\n2\n1\n3\n2\n4\n1\n5\n4\n2\n1\n4\n5\n2\n3\n4\n2\n1\n5\n3\n1\n2\n3\n3\n3\n2\n4\n1\n3\n2\n5\n2\n1\n4\n1\n4\n2\n5\n4\n2\n3\n4\n4\n4\n5\n4\n2\n1\n2\n4\n2\n3\n5\n1\n1\n4\n2\n1\n2\n1\n1\n5\n4\n5\n2\n3\n2\n4\n1\n3\n4\n3\n2\n5\n3\n3\n2\n5\n5\n4\n1\n3\n3\n4\n4\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_limit_zero() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT 0 FROM aggregate_test_100 LIMIT 0";
let actual = execute(&mut ctx, sql);
let expected = "".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_create_external_table() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv_by_sql(&mut ctx);
let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1";
let actual = execute(&mut ctx, sql);
let expected = "\"c\"\t2\t1\t18109\t2033001162\t-6513304855495910254\t25\t43062\t1491205016\t10\t0.110830784\t0.9294097332465232\t\"6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW\"\n".to_string();
assert_eq!(expected, actual);
}
#[test]
fn csv_query_external_table_count() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv_by_sql(&mut ctx);
let sql = "SELECT COUNT(c12) FROM aggregate_test_100";
let actual = execute(&mut ctx, sql);
let expected = "100\n".to_string();
assert_eq!(expected, actual);
}
//TODO Uncomment the following test when ORDER BY is implemented,
// to be able to test ORDER BY + LIMIT
/*
#[test]
fn csv_query_limit_with_order_by() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx);
let sql = "SELECT c7 FROM aggregate_test_100 ORDER BY c7 ASC LIMIT 2";
let actual = execute(&mut ctx, sql);
let expected = "0\n2\n".to_string();
assert_eq!(expected, actual);
}
*/
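/// Schema matching the columns of the aggregate_test_100.csv test file.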
fn aggr_test_schema() -> Arc<Schema> {
Arc::new(Schema::new(vec![
Field::new("c1", DataType::Utf8, false),
Field::new("c2", DataType::UInt32, false),
Field::new("c3", DataType::Int8, false),
Field::new("c4", DataType::Int16, false),
Field::new("c5", DataType::Int32, false),
Field::new("c6", DataType::Int64, false),
Field::new("c7", DataType::UInt8, false),
Field::new("c8", DataType::UInt16, false),
Field::new("c9", DataType::UInt32, false),
Field::new("c10", DataType::UInt64, false),
Field::new("c11", DataType::Float32, false),
Field::new("c12", DataType::Float64, false),
Field::new("c13", DataType::Utf8, false),
]))
}
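/// Registers aggregate_test_100 via a CREATE EXTERNAL TABLE statement
/// (the SQL DDL path) rather than the programmatic register_csv API.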
fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) {
let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
// TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once
// unsigned is supported.
ctx.sql(
&format!(
"
CREATE EXTERNAL TABLE aggregate_test_100 (
c1 VARCHAR NOT NULL,
c2 INT NOT NULL,
c3 SMALLINT NOT NULL,
c4 SMALLINT NOT NULL,
c5 INT NOT NULL,
c6 BIGINT NOT NULL,
c7 SMALLINT NOT NULL,
c8 INT NOT NULL,
c9 BIGINT NOT NULL,
c10 VARCHAR NOT NULL,
c11 FLOAT NOT NULL,
c12 DOUBLE NOT NULL,
c13 VARCHAR NOT NULL
)
STORED AS CSV
WITH HEADER ROW
LOCATION '{}/csv/aggregate_test_100.csv'
",
testdata
),
1024,
)
.unwrap();
}
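/// Registers aggregate_test_100 through the programmatic API using the schema
/// from aggr_test_schema().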
fn register_aggregate_csv(ctx: &mut ExecutionContext) {
let testdata = env::var("ARROW_TEST_DATA").expect("ARROW_TEST_DATA not defined");
let schema = aggr_test_schema();
register_csv(
ctx,
"aggregate_test_100",
&format!("{}/csv/aggregate_test_100.csv", testdata),
&schema,
);
}
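/// Thin wrapper around ExecutionContext::register_csv for a CSV file with a header row.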
fn register_csv(
ctx: &mut ExecutionContext,
name: &str,
filename: &str,
schema: &Arc<Schema>,
) {
ctx.register_csv(name, filename, &schema, true);
}
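/// Registers the alltypes_plain.parquet file from PARQUET_TEST_DATA as the
/// table alltypes_plain.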
fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
let testdata = env::var("PARQUET_TEST_DATA").expect("PARQUET_TEST_DATA not defined");
ctx.register_parquet(
"alltypes_plain",
&format!("{}/alltypes_plain.parquet", testdata),
)
.unwrap();
}
/// Execute query and return result set as tab delimited string
fn execute(ctx: &mut ExecutionContext, sql: &str) -> String {
let plan = ctx.create_logical_plan(&sql).unwrap();
let plan = ctx.optimize(&plan).unwrap();
let results = ctx.execute(&plan, DEFAULT_BATCH_SIZE).unwrap();
result_str(&results)
}
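/// Formats every row of every batch produced by the relation as a tab-delimited
/// line, downcasting each column to its concrete Arrow array type; Utf8 columns
/// are printed with Debug formatting, so they appear quoted.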
fn result_str(results: &Rc<RefCell<dyn Relation>>) -> String {
let mut relation = results.borrow_mut();
let mut str = String::new();
while let Some(batch) = relation.next().unwrap() {
for row_index in 0..batch.num_rows() {
for column_index in 0..batch.num_columns() {
if column_index > 0 {
str.push_str("\t");
}
let column = batch.column(column_index);
match column.data_type() {
DataType::Int8 => {
let array = column.as_any().downcast_ref::<Int8Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::Int16 => {
let array = column.as_any().downcast_ref::<Int16Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::Int32 => {
let array = column.as_any().downcast_ref::<Int32Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::Int64 => {
let array = column.as_any().downcast_ref::<Int64Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::UInt8 => {
let array = column.as_any().downcast_ref::<UInt8Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::UInt16 => {
let array =
column.as_any().downcast_ref::<UInt16Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::UInt32 => {
let array =
column.as_any().downcast_ref::<UInt32Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::UInt64 => {
let array =
column.as_any().downcast_ref::<UInt64Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::Float32 => {
let array =
column.as_any().downcast_ref::<Float32Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::Float64 => {
let array =
column.as_any().downcast_ref::<Float64Array>().unwrap();
str.push_str(&format!("{:?}", array.value(row_index)));
}
DataType::Utf8 => {
let array =
column.as_any().downcast_ref::<BinaryArray>().unwrap();
let s =
String::from_utf8(array.value(row_index).to_vec()).unwrap();
str.push_str(&format!("{:?}", s));
}
_ => str.push_str("???"),
}
}
str.push_str("\n");
}
}
str
}