#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
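"""
Conversion helpers for Spark Connect: local Python data to and from Arrow
tables, and :class:`StorageLevel` to and from its protobuf representation.
"""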
from pyspark.sql.connect.utils import check_dependencies
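# Verify that the optional Spark Connect dependencies (such as pyarrow and
# grpcio) are installed before importing anything that relies on them.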
check_dependencies(__name__)

import array
import datetime
import decimal
import pyarrow as pa
from pyspark.sql.types import (
_create_row,
Row,
DataType,
TimestampType,
TimestampNTZType,
MapType,
StructField,
StructType,
ArrayType,
BinaryType,
NullType,
DecimalType,
StringType,
UserDefinedType,
)
from pyspark.storagelevel import StorageLevel
import pyspark.sql.connect.proto as pb2
from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names
from typing import (
Any,
Callable,
Sequence,
List,
)


class LocalDataToArrowConversion:
"""
    Conversion from local data (other than pandas DataFrame and numpy ndarray) to Arrow.
Currently, only :class:`SparkSession` in Spark Connect can use this class.
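
    A minimal usage sketch (illustrative only; the field names and data are
    made up for this example)::

        schema = StructType([
            StructField("name", StringType(), True),
            StructField("city", StringType(), True),
        ])
        table = LocalDataToArrowConversion.convert(
            [("Alice", "NYC"), ("Bob", "SF")], schema
        )
        # ``table`` is a pyarrow.Table with columns "name" and "city"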
"""

    @staticmethod
def _need_converter(dataType: DataType) -> bool:
if isinstance(dataType, NullType):
return True
elif isinstance(dataType, StructType):
            # Structs may be Rows; convert them to dicts.
return True
elif isinstance(dataType, ArrayType):
return LocalDataToArrowConversion._need_converter(dataType.elementType)
elif isinstance(dataType, MapType):
            # Unlike PySpark, conversion is always needed here,
            # since an Arrow Map requires a list of key-value tuples.
return True
elif isinstance(dataType, BinaryType):
return True
elif isinstance(dataType, (TimestampType, TimestampNTZType)):
# Always truncate
return True
elif isinstance(dataType, DecimalType):
# Convert Decimal('NaN') to None
return True
elif isinstance(dataType, StringType):
# Coercion to StringType is allowed
return True
elif isinstance(dataType, UserDefinedType):
return True
else:
return False

    @staticmethod
def _create_converter(dataType: DataType) -> Callable:
assert dataType is not None and isinstance(dataType, DataType)
if not LocalDataToArrowConversion._need_converter(dataType):
return lambda value: value
if isinstance(dataType, NullType):
return lambda value: None
elif isinstance(dataType, StructType):
field_names = dataType.fieldNames()
dedup_field_names = _dedup_names(dataType.names)
field_convs = [
LocalDataToArrowConversion._create_converter(field.dataType)
for field in dataType.fields
]
def convert_struct(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, (tuple, dict)) or hasattr(
value, "__dict__"
), f"{type(value)} {value}"
_dict = {}
if (
not isinstance(value, Row)
and not isinstance(value, tuple) # inherited namedtuple
and hasattr(value, "__dict__")
):
value = value.__dict__
if isinstance(value, dict):
for i, field in enumerate(field_names):
_dict[dedup_field_names[i]] = field_convs[i](value.get(field))
else:
if len(value) != len(field_names):
raise ValueError(
f"Length mismatch: Expected axis has {len(field_names)} elements, "
f"new values have {len(value)} elements"
)
for i in range(len(field_names)):
_dict[dedup_field_names[i]] = field_convs[i](value[i])
return _dict
return convert_struct
elif isinstance(dataType, ArrayType):
element_conv = LocalDataToArrowConversion._create_converter(dataType.elementType)
def convert_array(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, (list, array.array))
return [element_conv(v) for v in value]
return convert_array
elif isinstance(dataType, MapType):
key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
value_conv = LocalDataToArrowConversion._create_converter(dataType.valueType)
def convert_map(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, dict)
                    return [(key_conv(k), value_conv(v)) for k, v in value.items()]
return convert_map
elif isinstance(dataType, BinaryType):
def convert_binary(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, (bytes, bytearray))
return bytes(value)
return convert_binary
elif isinstance(dataType, TimestampType):
def convert_timestamp(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, datetime.datetime)
return value.astimezone(datetime.timezone.utc)
return convert_timestamp
elif isinstance(dataType, TimestampNTZType):
def convert_timestamp_ntz(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, datetime.datetime) and value.tzinfo is None
return value
return convert_timestamp_ntz
elif isinstance(dataType, DecimalType):
def convert_decimal(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, decimal.Decimal)
return None if value.is_nan() else value
return convert_decimal
elif isinstance(dataType, StringType):
def convert_string(value: Any) -> Any:
if value is None:
return None
else:
# only atomic types are supported
assert isinstance(
value,
(
bool,
int,
float,
str,
bytes,
bytearray,
decimal.Decimal,
datetime.date,
datetime.datetime,
datetime.timedelta,
),
)
if isinstance(value, bool):
                    # To match PySpark, which converts bool to string on
                    # the JVM side (python.EvaluatePython.makeFromJava).
return str(value).lower()
else:
return str(value)
return convert_string
elif isinstance(dataType, UserDefinedType):
udt: UserDefinedType = dataType
conv = LocalDataToArrowConversion._create_converter(udt.sqlType())
def convert_udt(value: Any) -> Any:
if value is None:
return None
else:
return conv(udt.serialize(value))
return convert_udt
else:
return lambda value: value

    @staticmethod
def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
assert isinstance(data, list) and len(data) > 0
assert schema is not None and isinstance(schema, StructType)
column_names = schema.fieldNames()
column_convs = [
LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields
]
pylist: List[List] = [[] for _ in range(len(column_names))]
for item in data:
if (
not isinstance(item, Row)
and not isinstance(item, tuple) # inherited namedtuple
and hasattr(item, "__dict__")
):
item = item.__dict__
if isinstance(item, dict):
for i, col in enumerate(column_names):
pylist[i].append(column_convs[i](item.get(col)))
else:
if len(item) != len(column_names):
raise ValueError(
f"Length mismatch: Expected axis has {len(column_names)} elements, "
f"new values have {len(item)} elements"
)
for i in range(len(column_names)):
pylist[i].append(column_convs[i](item[i]))
pa_schema = to_arrow_schema(
StructType(
[
StructField(
field.name, _deduplicate_field_names(field.dataType), field.nullable
)
for field in schema.fields
]
)
)
return pa.Table.from_arrays(pylist, schema=pa_schema)


class ArrowTableToRowsConversion:
"""
Conversion from Arrow Table to Rows.
Currently, only :class:`DataFrame` in Spark Connect can use this class.
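
    A minimal usage sketch (illustrative only; assumes ``table`` is a
    :class:`pyarrow.Table` whose columns match ``schema``)::

        rows = ArrowTableToRowsConversion.convert(table, schema)
        # each element is a Row keyed by the schema's field names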
"""

    @staticmethod
def _need_converter(dataType: DataType) -> bool:
if isinstance(dataType, NullType):
return True
elif isinstance(dataType, StructType):
return True
elif isinstance(dataType, ArrayType):
return ArrowTableToRowsConversion._need_converter(dataType.elementType)
elif isinstance(dataType, MapType):
            # Unlike PySpark, conversion is always needed here,
            # since the input from Arrow is a list of key-value tuples.
return True
elif isinstance(dataType, BinaryType):
return True
elif isinstance(dataType, (TimestampType, TimestampNTZType)):
# Always remove the time zone info for now
return True
elif isinstance(dataType, UserDefinedType):
return True
else:
return False

    @staticmethod
def _create_converter(dataType: DataType) -> Callable:
assert dataType is not None and isinstance(dataType, DataType)
if not ArrowTableToRowsConversion._need_converter(dataType):
return lambda value: value
if isinstance(dataType, NullType):
return lambda value: None
elif isinstance(dataType, StructType):
field_names = dataType.names
dedup_field_names = _dedup_names(field_names)
field_convs = [
ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields
]
def convert_struct(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, dict)
_values = [
field_convs[i](value.get(name, None))
for i, name in enumerate(dedup_field_names)
]
return _create_row(field_names, _values)
return convert_struct
elif isinstance(dataType, ArrayType):
element_conv = ArrowTableToRowsConversion._create_converter(dataType.elementType)
def convert_array(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, list)
return [element_conv(v) for v in value]
return convert_array
elif isinstance(dataType, MapType):
key_conv = ArrowTableToRowsConversion._create_converter(dataType.keyType)
value_conv = ArrowTableToRowsConversion._create_converter(dataType.valueType)
def convert_map(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, list)
assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
                    return {key_conv(k): value_conv(v) for k, v in value}
return convert_map
elif isinstance(dataType, BinaryType):
def convert_binary(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, bytes)
return bytearray(value)
return convert_binary
elif isinstance(dataType, TimestampType):
            def convert_timestamp(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, datetime.datetime)
return value.astimezone().replace(tzinfo=None)
            return convert_timestamp
elif isinstance(dataType, TimestampNTZType):
            def convert_timestamp_ntz(value: Any) -> Any:
if value is None:
return None
else:
assert isinstance(value, datetime.datetime)
return value
            return convert_timestamp_ntz
elif isinstance(dataType, UserDefinedType):
udt: UserDefinedType = dataType
conv = ArrowTableToRowsConversion._create_converter(udt.sqlType())
def convert_udt(value: Any) -> Any:
if value is None:
return None
else:
return udt.deserialize(conv(value))
return convert_udt
else:
return lambda value: value

    @staticmethod
def convert(table: "pa.Table", schema: StructType) -> List[Row]:
assert isinstance(table, pa.Table)
assert schema is not None and isinstance(schema, StructType)
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
]
columnar_data = [column.to_pylist() for column in table.columns]
rows: List[Row] = []
        for i in range(table.num_rows):
values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)]
rows.append(_create_row(fields=schema.fieldNames(), values=values))
return rows


def storage_level_to_proto(storage_level: StorageLevel) -> pb2.StorageLevel:
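    """Convert a :class:`StorageLevel` into its protobuf representation."""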
assert storage_level is not None and isinstance(storage_level, StorageLevel)
return pb2.StorageLevel(
use_disk=storage_level.useDisk,
use_memory=storage_level.useMemory,
use_off_heap=storage_level.useOffHeap,
deserialized=storage_level.deserialized,
replication=storage_level.replication,
)


def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
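    """Convert a protobuf ``StorageLevel`` message back into a :class:`StorageLevel`."""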
assert storage_level is not None and isinstance(storage_level, pb2.StorageLevel)
return StorageLevel(
useDisk=storage_level.use_disk,
useMemory=storage_level.use_memory,
useOffHeap=storage_level.use_off_heap,
deserialized=storage_level.deserialized,
replication=storage_level.replication,
)
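

# A minimal round-trip sketch (illustrative only):
#
#   level = StorageLevel.MEMORY_AND_DISK
#   restored = proto_to_storage_level(storage_level_to_proto(level))
#   # ``restored`` carries the same flags and replication as ``level``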