blob: aedf5e346d04ef91eb8ba7d04190a240ff7f1d79 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::{iter::zip, sync::Arc, vec};
use arrow_array::builder::BinaryBuilder;
use arrow_schema::DataType;
use datafusion_common::cast::as_float64_array;
use datafusion_common::error::Result;
use datafusion_common::scalar::ScalarValue;
use datafusion_expr::{
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
};
use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
use sedona_schema::{
datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY},
matchers::ArgMatcher,
};
use crate::{executor::WkbExecutor, st_setsrid::SRIDifiedKernel};
/// ST_Point() scalar UDF implementation
///
/// Native implementation to create geometries from coordinates.
/// See [`st_geogpoint_udf`] for the corresponding geography constructor.
pub fn st_point_udf() -> SedonaScalarUDF {
let kernel = Arc::new(STGeoFromPoint {
out_type: WKB_GEOMETRY,
});
let sridified_kernel = Arc::new(SRIDifiedKernel::new(kernel.clone()));
SedonaScalarUDF::new(
"st_point",
vec![sridified_kernel, kernel],
Volatility::Immutable,
Some(doc("ST_Point", "Geometry")),
)
}
/// ST_GeogPoint() scalar UDF implementation
///
/// Native implementation to create geometries from coordinates.
/// See [`st_geogpoint_udf`] for the corresponding geography constructor.
pub fn st_geogpoint_udf() -> SedonaScalarUDF {
let kernel = Arc::new(STGeoFromPoint {
out_type: WKB_GEOGRAPHY,
});
let sridified_kernel = Arc::new(SRIDifiedKernel::new(kernel.clone()));
SedonaScalarUDF::new(
"st_geogpoint",
vec![sridified_kernel, kernel],
Volatility::Immutable,
Some(doc("st_geogpoint", "Geography")),
)
}
fn doc(name: &str, out_type_name: &str) -> Documentation {
Documentation::builder(
DOC_SECTION_OTHER,
format!(
"Construct a Point {} from X and Y",
out_type_name.to_lowercase()
),
format!("{name} (x: Double, y: Double)"),
)
.with_argument("x", "double: X value")
.with_argument("y", "double: Y value")
.with_argument("srid", "srid: EPSG code to set (e.g., 4326)")
.with_sql_example(format!("{name}(-64.36, 45.09)"))
.build()
}
#[derive(Debug)]
struct STGeoFromPoint {
out_type: SedonaType,
}
impl SedonaScalarKernel for STGeoFromPoint {
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
let matcher = ArgMatcher::new(
vec![ArgMatcher::is_numeric(), ArgMatcher::is_numeric()],
self.out_type.clone(),
);
matcher.match_args(args)
}
fn invoke_batch(
&self,
arg_types: &[SedonaType],
args: &[ColumnarValue],
) -> Result<ColumnarValue> {
let executor = WkbExecutor::new(arg_types, args);
let x = &args[0].cast_to(&DataType::Float64, None)?;
let y = &args[1].cast_to(&DataType::Float64, None)?;
let mut item: [u8; 21] = [0x00; 21];
item[0] = 0x01;
item[1] = 0x01;
// Handle the Scalar case to ensure that Scalar + Scalar -> Scalar
if let (
ColumnarValue::Scalar(ScalarValue::Float64(x_float)),
ColumnarValue::Scalar(ScalarValue::Float64(y_float)),
) = (x, y)
{
if let (Some(x), Some(y)) = (x_float, y_float) {
populate_wkb_item(&mut item, x, y);
return Ok(ScalarValue::Binary(Some(item.to_vec())).into());
} else {
return Ok(ScalarValue::Binary(None).into());
}
}
// Ensure both sides are arrays before iterating
let x_array = x.to_array(executor.num_iterations())?;
let y_array = y.to_array(executor.num_iterations())?;
let x_f64 = as_float64_array(&x_array)?;
let y_f64 = as_float64_array(&y_array)?;
let mut builder = BinaryBuilder::with_capacity(
executor.num_iterations(),
item.len() * executor.num_iterations(),
);
for (x_elem, y_elem) in zip(x_f64, y_f64) {
match (x_elem, y_elem) {
(Some(x), Some(y)) => {
populate_wkb_item(&mut item, &x, &y);
builder.append_value(item);
}
_ => {
builder.append_null();
}
}
}
let new_array = builder.finish();
Ok(ColumnarValue::Array(Arc::new(new_array)))
}
}
fn populate_wkb_item(item: &mut [u8], x: &f64, y: &f64) {
item[5..13].copy_from_slice(&x.to_le_bytes());
item[13..21].copy_from_slice(&y.to_le_bytes());
}
#[cfg(test)]
mod tests {
use arrow_array::create_array;
use arrow_array::ArrayRef;
use arrow_schema::DataType;
use datafusion_expr::Literal;
use datafusion_expr::ScalarUDF;
use rstest::rstest;
use sedona_schema::crs::lnglat;
use sedona_schema::datatypes::Edges;
use sedona_testing::compare::assert_array_equal;
use sedona_testing::{create::create_array, testers::ScalarUdfTester};
use super::*;
#[test]
fn udf_metadata() {
let geom_from_point: ScalarUDF = st_point_udf().into();
assert_eq!(geom_from_point.name(), "st_point");
assert!(geom_from_point.documentation().is_some());
let geog_from_point: ScalarUDF = st_geogpoint_udf().into();
assert_eq!(geog_from_point.name(), "st_geogpoint");
assert!(geog_from_point.documentation().is_some());
}
#[rstest]
#[case(DataType::Float64, DataType::Float64)]
#[case(DataType::Float32, DataType::Float64)]
#[case(DataType::Float64, DataType::Float32)]
#[case(DataType::Float32, DataType::Float32)]
fn udf_invoke(#[case] lhs_type: DataType, #[case] rhs_type: DataType) {
let udf = st_point_udf();
let lhs_scalar_null = ScalarValue::Float64(None).cast_to(&lhs_type).unwrap();
let lhs_scalar = ScalarValue::Float64(Some(1.0)).cast_to(&lhs_type).unwrap();
let rhs_scalar_null = ScalarValue::Float64(None).cast_to(&rhs_type).unwrap();
let rhs_scalar = ScalarValue::Float64(Some(2.0)).cast_to(&rhs_type).unwrap();
let lhs_array: ArrayRef = create_array!(Float64, [Some(1.0), Some(2.0), None, None]);
let rhs_array: ArrayRef = create_array!(Float64, [Some(5.0), None, Some(7.0), None]);
let tester = ScalarUdfTester::new(
udf.into(),
vec![SedonaType::Arrow(lhs_type), SedonaType::Arrow(rhs_type)],
);
// Check scalars
let result = tester
.invoke_scalar_scalar(lhs_scalar.clone(), rhs_scalar.clone())
.unwrap();
tester.assert_scalar_result_equals(result, "POINT (1 2)");
// Check scalar null combinations
let result = tester
.invoke_scalar_scalar(lhs_scalar.clone(), rhs_scalar_null.clone())
.unwrap();
tester.assert_scalar_result_equals(result, ScalarValue::Null);
let result = tester
.invoke_scalar_scalar(lhs_scalar_null.clone(), rhs_scalar.clone())
.unwrap();
tester.assert_scalar_result_equals(result, ScalarValue::Null);
let result = tester
.invoke_scalar_scalar(lhs_scalar_null.clone(), rhs_scalar_null.clone())
.unwrap();
tester.assert_scalar_result_equals(result, ScalarValue::Null);
// Check array
assert_array_equal(
&tester
.invoke_array_array(lhs_array.clone(), rhs_array.clone())
.unwrap(),
&create_array(&[Some("POINT (1 5)"), None, None, None], &WKB_GEOMETRY),
);
// Check array/scalar combinations
assert_array_equal(
&tester
.invoke_array_scalar(lhs_array.clone(), rhs_scalar.clone())
.unwrap(),
&create_array(
&[Some("POINT (1 2)"), Some("POINT (2 2)"), None, None],
&WKB_GEOMETRY,
),
);
assert_array_equal(
&tester
.invoke_scalar_array(lhs_scalar.clone(), rhs_array.clone())
.unwrap(),
&create_array(
&[Some("POINT (1 5)"), None, Some("POINT (1 7)"), None],
&WKB_GEOMETRY,
),
);
}
#[rstest]
#[case(DataType::UInt32, 4326)]
#[case(DataType::Int32, 4326)]
#[case(DataType::Utf8, "4326")]
#[case(DataType::Utf8, "EPSG:4326")]
fn udf_invoke_with_srid(#[case] srid_type: DataType, #[case] srid_value: impl Literal + Copy) {
let udf = st_point_udf();
let tester = ScalarUdfTester::new(
udf.into(),
vec![
SedonaType::Arrow(DataType::Float64),
SedonaType::Arrow(DataType::Float64),
SedonaType::Arrow(srid_type),
],
);
let return_type = tester
.return_type_with_scalar_scalar_scalar(Some(1.0), Some(2.0), Some(srid_value))
.unwrap();
assert_eq!(return_type, SedonaType::Wkb(Edges::Planar, lnglat()));
let result = tester
.invoke_scalar_scalar_scalar(1.0, 2.0, srid_value)
.unwrap();
tester.assert_scalar_result_equals_with_return_type(result, "POINT (1 2)", return_type);
}
#[test]
fn udf_invoke_with_invalid_srid() {
let udf = st_point_udf();
let tester = ScalarUdfTester::new(
udf.into(),
vec![
SedonaType::Arrow(DataType::Float64),
SedonaType::Arrow(DataType::Float64),
SedonaType::Arrow(DataType::Utf8),
],
);
let return_type = tester.return_type_with_scalar_scalar_scalar(
Some(1.0),
Some(2.0),
Some("gazornenplat"),
);
assert!(return_type.is_err());
let result = tester.invoke_scalar_scalar_scalar(1.0, 2.0, "gazornenplat");
assert!(result.is_err());
}
#[test]
fn geog() {
let udf = st_geogpoint_udf();
let tester = ScalarUdfTester::new(
udf.into(),
vec![
SedonaType::Arrow(DataType::Float64),
SedonaType::Arrow(DataType::Float64),
],
);
tester.assert_return_type(WKB_GEOGRAPHY);
let result = tester.invoke_scalar_scalar(1.0, 2.0).unwrap();
tester.assert_scalar_result_equals(result, "POINT (1 2)");
}
}