| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| from pyspark.sql.connect.utils import check_dependencies |
| |
| check_dependencies(__name__) |
| |
| import array |
| import datetime |
| import decimal |
| |
| import pyarrow as pa |
| |
| from pyspark.sql.types import ( |
| _create_row, |
| Row, |
| DataType, |
| TimestampType, |
| TimestampNTZType, |
| MapType, |
| StructField, |
| StructType, |
| ArrayType, |
| BinaryType, |
| NullType, |
| DecimalType, |
| StringType, |
| UserDefinedType, |
| ) |
| |
| from pyspark.storagelevel import StorageLevel |
| import pyspark.sql.connect.proto as pb2 |
| from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names |
| |
| from typing import ( |
| Any, |
| Callable, |
| Sequence, |
| List, |
| ) |
| |
| |
| class LocalDataToArrowConversion: |
| """ |
| Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow. |
| Currently, only :class:`SparkSession` in Spark Connect can use this class. |
| """ |
| |
| @staticmethod |
| def _need_converter(dataType: DataType) -> bool: |
| if isinstance(dataType, NullType): |
| return True |
| elif isinstance(dataType, StructType): |
            # Structs may be Rows, which should be converted to dicts.
| return True |
| elif isinstance(dataType, ArrayType): |
| return LocalDataToArrowConversion._need_converter(dataType.elementType) |
| elif isinstance(dataType, MapType): |
            # Unlike PySpark, a conversion is always needed here,
            # since an Arrow Map requires a list of (key, value) tuples.
| return True |
| elif isinstance(dataType, BinaryType): |
| return True |
| elif isinstance(dataType, (TimestampType, TimestampNTZType)): |
            # Always needs conversion (timezone normalization / validation below)
| return True |
| elif isinstance(dataType, DecimalType): |
| # Convert Decimal('NaN') to None |
| return True |
| elif isinstance(dataType, StringType): |
| # Coercion to StringType is allowed |
| return True |
| elif isinstance(dataType, UserDefinedType): |
| return True |
| else: |
| return False |
| |
| @staticmethod |
| def _create_converter(dataType: DataType) -> Callable: |
| assert dataType is not None and isinstance(dataType, DataType) |
| |
| if not LocalDataToArrowConversion._need_converter(dataType): |
| return lambda value: value |
| |
| if isinstance(dataType, NullType): |
| return lambda value: None |
| |
| elif isinstance(dataType, StructType): |
| field_names = dataType.fieldNames() |
| dedup_field_names = _dedup_names(dataType.names) |
| |
| field_convs = [ |
| LocalDataToArrowConversion._create_converter(field.dataType) |
| for field in dataType.fields |
| ] |
| |
| def convert_struct(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, (tuple, dict)) or hasattr( |
| value, "__dict__" |
| ), f"{type(value)} {value}" |
| |
| _dict = {} |
| if ( |
| not isinstance(value, Row) |
| and not isinstance(value, tuple) # inherited namedtuple |
| and hasattr(value, "__dict__") |
| ): |
| value = value.__dict__ |
| if isinstance(value, dict): |
| for i, field in enumerate(field_names): |
| _dict[dedup_field_names[i]] = field_convs[i](value.get(field)) |
| else: |
| if len(value) != len(field_names): |
| raise ValueError( |
| f"Length mismatch: Expected axis has {len(field_names)} elements, " |
| f"new values have {len(value)} elements" |
| ) |
| for i in range(len(field_names)): |
| _dict[dedup_field_names[i]] = field_convs[i](value[i]) |
| |
| return _dict |
| |
| return convert_struct |
| |
| elif isinstance(dataType, ArrayType): |
| element_conv = LocalDataToArrowConversion._create_converter(dataType.elementType) |
| |
| def convert_array(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, (list, array.array)) |
| return [element_conv(v) for v in value] |
| |
| return convert_array |
| |
| elif isinstance(dataType, MapType): |
| key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType) |
| value_conv = LocalDataToArrowConversion._create_converter(dataType.valueType) |
| |
| def convert_map(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, dict) |
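                    # An Arrow map is built from a list of (key, value) pairs rather
                    # than a Python dict, e.g. {"a": 1, "b": 2} -> [("a", 1), ("b", 2)].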
| |
| _tuples = [] |
| for k, v in value.items(): |
| _tuples.append((key_conv(k), value_conv(v))) |
| |
| return _tuples |
| |
| return convert_map |
| |
| elif isinstance(dataType, BinaryType): |
| |
| def convert_binary(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, (bytes, bytearray)) |
| return bytes(value) |
| |
| return convert_binary |
| |
| elif isinstance(dataType, TimestampType): |
| |
| def convert_timestamp(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, datetime.datetime) |
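                    # Normalize to UTC; a naive datetime is interpreted as local time
                    # by astimezone().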
| return value.astimezone(datetime.timezone.utc) |
| |
| return convert_timestamp |
| |
| elif isinstance(dataType, TimestampNTZType): |
| |
| def convert_timestamp_ntz(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, datetime.datetime) and value.tzinfo is None |
| return value |
| |
| return convert_timestamp_ntz |
| |
| elif isinstance(dataType, DecimalType): |
| |
| def convert_decimal(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, decimal.Decimal) |
| return None if value.is_nan() else value |
| |
| return convert_decimal |
| |
| elif isinstance(dataType, StringType): |
| |
| def convert_string(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| # only atomic types are supported |
| assert isinstance( |
| value, |
| ( |
| bool, |
| int, |
| float, |
| str, |
| bytes, |
| bytearray, |
| decimal.Decimal, |
| datetime.date, |
| datetime.datetime, |
| datetime.timedelta, |
| ), |
| ) |
| if isinstance(value, bool): |
                        # To match PySpark, which converts bool to string on
                        # the JVM side (python.EvaluatePython.makeFromJava)
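                        # e.g. True -> "true", False -> "false"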
| return str(value).lower() |
| else: |
| return str(value) |
| |
| return convert_string |
| |
| elif isinstance(dataType, UserDefinedType): |
| udt: UserDefinedType = dataType |
| |
| conv = LocalDataToArrowConversion._create_converter(udt.sqlType()) |
| |
| def convert_udt(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| return conv(udt.serialize(value)) |
| |
| return convert_udt |
| |
| else: |
| return lambda value: value |
| |
| @staticmethod |
| def convert(data: Sequence[Any], schema: StructType) -> "pa.Table": |
| assert isinstance(data, list) and len(data) > 0 |
| |
| assert schema is not None and isinstance(schema, StructType) |
| |
| column_names = schema.fieldNames() |
| |
| column_convs = [ |
| LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields |
| ] |
| |
| pylist: List[List] = [[] for _ in range(len(column_names))] |
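        # Pivot the row-oriented input into one Python list per column,
        # applying the per-field converter to every value along the way.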
| |
| for item in data: |
| if ( |
| not isinstance(item, Row) |
| and not isinstance(item, tuple) # inherited namedtuple |
| and hasattr(item, "__dict__") |
| ): |
| item = item.__dict__ |
| if isinstance(item, dict): |
| for i, col in enumerate(column_names): |
| pylist[i].append(column_convs[i](item.get(col))) |
| else: |
| if len(item) != len(column_names): |
| raise ValueError( |
| f"Length mismatch: Expected axis has {len(column_names)} elements, " |
| f"new values have {len(item)} elements" |
| ) |
| for i in range(len(column_names)): |
| pylist[i].append(column_convs[i](item[i])) |
| |
| pa_schema = to_arrow_schema( |
| StructType( |
| [ |
| StructField( |
| field.name, _deduplicate_field_names(field.dataType), field.nullable |
| ) |
| for field in schema.fields |
| ] |
| ) |
| ) |
| |
| return pa.Table.from_arrays(pylist, schema=pa_schema) |
| |
| |
| class ArrowTableToRowsConversion: |
| """ |
| Conversion from Arrow Table to Rows. |
| Currently, only :class:`DataFrame` in Spark Connect can use this class. |
| """ |
| |
| @staticmethod |
| def _need_converter(dataType: DataType) -> bool: |
| if isinstance(dataType, NullType): |
| return True |
| elif isinstance(dataType, StructType): |
| return True |
| elif isinstance(dataType, ArrayType): |
| return ArrowTableToRowsConversion._need_converter(dataType.elementType) |
| elif isinstance(dataType, MapType): |
            # Unlike PySpark, a conversion is always needed here,
            # since the input from Arrow is a list of (key, value) tuples.
| return True |
| elif isinstance(dataType, BinaryType): |
| return True |
| elif isinstance(dataType, (TimestampType, TimestampNTZType)): |
| # Always remove the time zone info for now |
| return True |
| elif isinstance(dataType, UserDefinedType): |
| return True |
| else: |
| return False |
| |
| @staticmethod |
| def _create_converter(dataType: DataType) -> Callable: |
| assert dataType is not None and isinstance(dataType, DataType) |
| |
| if not ArrowTableToRowsConversion._need_converter(dataType): |
| return lambda value: value |
| |
| if isinstance(dataType, NullType): |
| return lambda value: None |
| |
| elif isinstance(dataType, StructType): |
| field_names = dataType.names |
| dedup_field_names = _dedup_names(field_names) |
| |
| field_convs = [ |
| ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields |
| ] |
| |
| def convert_struct(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, dict) |
| |
| _values = [ |
| field_convs[i](value.get(name, None)) |
| for i, name in enumerate(dedup_field_names) |
| ] |
| return _create_row(field_names, _values) |
| |
| return convert_struct |
| |
| elif isinstance(dataType, ArrayType): |
| element_conv = ArrowTableToRowsConversion._create_converter(dataType.elementType) |
| |
| def convert_array(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, list) |
| return [element_conv(v) for v in value] |
| |
| return convert_array |
| |
| elif isinstance(dataType, MapType): |
| key_conv = ArrowTableToRowsConversion._create_converter(dataType.keyType) |
| value_conv = ArrowTableToRowsConversion._create_converter(dataType.valueType) |
| |
| def convert_map(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, list) |
| assert all(isinstance(t, tuple) and len(t) == 2 for t in value) |
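                    # Arrow hands map data back as a list of (key, value) tuples;
                    # rebuild a Python dict, e.g. [("a", 1), ("b", 2)] -> {"a": 1, "b": 2}.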
                    return {key_conv(k): value_conv(v) for k, v in value}
| |
| return convert_map |
| |
| elif isinstance(dataType, BinaryType): |
| |
| def convert_binary(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| assert isinstance(value, bytes) |
| return bytearray(value) |
| |
| return convert_binary |
| |
| elif isinstance(dataType, TimestampType): |
| |
            def convert_timestamp(value: Any) -> Any:
| if value is None: |
| return None |
| else: |
| assert isinstance(value, datetime.datetime) |
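                    # Convert to the local timezone and drop the tzinfo so the result
                    # is a naive datetime in local time.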
| return value.astimezone().replace(tzinfo=None) |
| |
            return convert_timestamp
| |
| elif isinstance(dataType, TimestampNTZType): |
| |
            def convert_timestamp_ntz(value: Any) -> Any:
| if value is None: |
| return None |
| else: |
| assert isinstance(value, datetime.datetime) |
| return value |
| |
            return convert_timestamp_ntz
| |
| elif isinstance(dataType, UserDefinedType): |
| udt: UserDefinedType = dataType |
| |
| conv = ArrowTableToRowsConversion._create_converter(udt.sqlType()) |
| |
| def convert_udt(value: Any) -> Any: |
| if value is None: |
| return None |
| else: |
| return udt.deserialize(conv(value)) |
| |
| return convert_udt |
| |
| else: |
| return lambda value: value |
| |
| @staticmethod |
| def convert(table: "pa.Table", schema: StructType) -> List[Row]: |
| assert isinstance(table, pa.Table) |
| |
| assert schema is not None and isinstance(schema, StructType) |
| |
| field_converters = [ |
| ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields |
| ] |
| |
| columnar_data = [column.to_pylist() for column in table.columns] |
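        # Assemble one Row per table row by taking the i-th element from every
        # materialized column and applying the per-field converter.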
| |
| rows: List[Row] = [] |
        for i in range(table.num_rows):
| values = [field_converters[j](columnar_data[j][i]) for j in range(table.num_columns)] |
| rows.append(_create_row(fields=schema.fieldNames(), values=values)) |
| return rows |
| |
| |
| def storage_level_to_proto(storage_level: StorageLevel) -> pb2.StorageLevel: |
| assert storage_level is not None and isinstance(storage_level, StorageLevel) |
| return pb2.StorageLevel( |
| use_disk=storage_level.useDisk, |
| use_memory=storage_level.useMemory, |
| use_off_heap=storage_level.useOffHeap, |
| deserialized=storage_level.deserialized, |
| replication=storage_level.replication, |
| ) |
| |
| |
| def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel: |
| assert storage_level is not None and isinstance(storage_level, pb2.StorageLevel) |
| return StorageLevel( |
| useDisk=storage_level.use_disk, |
| useMemory=storage_level.use_memory, |
| useOffHeap=storage_level.use_off_heap, |
| deserialized=storage_level.deserialized, |
| replication=storage_level.replication, |
| ) |