| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| import sys |
| import decimal |
| import time |
| import math |
| import datetime |
| import calendar |
| import json |
| import re |
| import base64 |
| from array import array |
| import ctypes |
| from collections.abc import Iterable |
| from functools import reduce |
| from typing import ( |
| cast, |
| overload, |
| Any, |
| Callable, |
| ClassVar, |
| Dict, |
| Iterator, |
| List, |
| Optional, |
| Union, |
| Tuple, |
| Type, |
| TypeVar, |
| TYPE_CHECKING, |
| ) |
| |
| from pyspark.util import is_remote_only |
| from pyspark.serializers import CloudPickleSerializer |
| from pyspark.sql.utils import has_numpy, get_active_spark_context |
| from pyspark.sql.variant_utils import VariantUtils |
| from pyspark.errors import ( |
| PySparkNotImplementedError, |
| PySparkTypeError, |
| PySparkValueError, |
| PySparkIndexError, |
| PySparkRuntimeError, |
| PySparkAttributeError, |
| PySparkKeyError, |
| ) |
| |
| if has_numpy: |
| import numpy as np |
| |
| if TYPE_CHECKING: |
| import numpy as np |
| from py4j.java_gateway import GatewayClient, JavaGateway, JavaClass |
| |
| T = TypeVar("T") |
| U = TypeVar("U") |
| |
| __all__ = [ |
| "DataType", |
| "NullType", |
| "CharType", |
| "StringType", |
| "VarcharType", |
| "BinaryType", |
| "BooleanType", |
| "DateType", |
| "TimestampType", |
| "TimestampNTZType", |
| "DecimalType", |
| "DoubleType", |
| "FloatType", |
| "ByteType", |
| "IntegerType", |
| "LongType", |
| "DayTimeIntervalType", |
| "YearMonthIntervalType", |
| "CalendarIntervalType", |
| "Row", |
| "ShortType", |
| "ArrayType", |
| "MapType", |
| "StructField", |
| "StructType", |
| "VariantType", |
| "VariantVal", |
| ] |
| |
| |
| class DataType: |
| """Base class for data types.""" |
| |
| def __repr__(self) -> str: |
| return self.__class__.__name__ + "()" |
| |
| def __hash__(self) -> int: |
| return hash(str(self)) |
| |
| def __eq__(self, other: Any) -> bool: |
| return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ |
| |
| def __ne__(self, other: Any) -> bool: |
| return not self.__eq__(other) |
| |
| @classmethod |
| def typeName(cls) -> str: |
| return cls.__name__[:-4].lower() |
| |
| def simpleString(self) -> str: |
| return self.typeName() |
| |
| def jsonValue(self) -> Union[str, Dict[str, Any]]: |
| return self.typeName() |
| |
| def json(self) -> str: |
| return json.dumps(self.jsonValue(), separators=(",", ":"), sort_keys=True) |
| |
| def needConversion(self) -> bool: |
| """ |
| Whether this type needs conversion between Python objects and internal SQL objects. |
| |
| This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. |
| """ |
| return False |
| |
| def toInternal(self, obj: Any) -> Any: |
| """ |
| Converts a Python object into an internal SQL object. |
| """ |
| return obj |
| |
| def fromInternal(self, obj: Any) -> Any: |
| """ |
| Converts an internal SQL object into a native Python object. |
| """ |
| return obj |
| |
| def _as_nullable(self) -> "DataType": |
| return self |
| |
| @classmethod |
| def fromDDL(cls, ddl: str) -> "DataType": |
| """ |
| Creates :class:`DataType` for a given DDL-formatted string. |
| |
| .. versionadded:: 4.0.0 |
| |
| Parameters |
| ---------- |
| ddl : str |
| DDL-formatted string representation of types, e.g. |
| :class:`pyspark.sql.types.DataType.simpleString`, except that the top-level struct |
| type can omit the ``struct<>`` for compatibility with |
| ``spark.createDataFrame`` and Python UDFs. |
| |
| Returns |
| ------- |
| :class:`DataType` |
| |
| Examples |
| -------- |
| Create a StructType by the corresponding DDL formatted string. |
| |
| >>> from pyspark.sql.types import DataType |
| >>> DataType.fromDDL("b string, a int") |
| StructType([StructField('b', StringType(), True), StructField('a', IntegerType(), True)]) |
| |
| Create a single DataType by the corresponding DDL formatted string. |
| |
| >>> DataType.fromDDL("decimal(10,10)") |
| DecimalType(10,10) |
| |
| Create a StructType by the legacy string format. |
| |
| >>> DataType.fromDDL("b: string, a: int") |
| StructType([StructField('b', StringType(), True), StructField('a', IntegerType(), True)]) |
| """ |
| from pyspark.sql import SparkSession |
| from pyspark.sql.functions import udf |
| |
| # Intentionally uses SparkSession so one implementation can be shared with/without |
| # Spark Connect. |
| schema = ( |
| SparkSession.active().range(0).select(udf(lambda x: x, returnType=ddl)("id")).schema |
| ) |
| assert len(schema) == 1 |
| return schema[0].dataType |
| |
| |
| # This singleton pattern does not work with pickle; you will get |
| # another object after pickling and unpickling. |
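| # (Illustrative: ``pickle.loads(pickle.dumps(BooleanType())) is BooleanType()`` evaluates |
| # to False, although the two objects still compare equal via ``DataType.__eq__``.) |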
| class DataTypeSingleton(type): |
| """Metaclass for DataType""" |
| |
| _instances: ClassVar[Dict[Type["DataTypeSingleton"], "DataTypeSingleton"]] = {} |
| |
| def __call__(cls: Type[T]) -> T: |
| if cls not in cls._instances: # type: ignore[attr-defined] |
| cls._instances[cls] = super( # type: ignore[misc, attr-defined] |
| DataTypeSingleton, cls |
| ).__call__() |
| return cls._instances[cls] # type: ignore[attr-defined] |
| |
| |
| class NullType(DataType, metaclass=DataTypeSingleton): |
| """Null type. |
| |
| The data type representing None, used for the types that cannot be inferred. |
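| |
| Examples |
| -------- |
| Note that the SQL type name of :class:`NullType` is ``void``: |
| |
| >>> NullType().simpleString() |
| 'void' |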
| """ |
| |
| @classmethod |
| def typeName(cls) -> str: |
| return "void" |
| |
| |
| class AtomicType(DataType): |
| """An internal type used to represent everything that is not |
| null, UDTs, arrays, structs, or maps.""" |
| |
| |
| class NumericType(AtomicType): |
| """Numeric data types.""" |
| |
| |
| class IntegralType(NumericType, metaclass=DataTypeSingleton): |
| """Integral data types.""" |
| |
| pass |
| |
| |
| class FractionalType(NumericType): |
| """Fractional data types.""" |
| |
| |
| class StringType(AtomicType): |
| """String data type. |
| |
| Parameters |
| ---------- |
| collation : str |
| name of the collation, default is UTF8_BINARY. |
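| |
| Examples |
| -------- |
| The below examples illustrate the default and a collated string type: |
| |
| >>> StringType() |
| StringType() |
| >>> StringType("UNICODE") |
| StringType('UNICODE') |
| >>> StringType("UNICODE").simpleString() |
| 'string collate UNICODE' |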
| """ |
| |
| collationNames = ["UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI"] |
| |
| def __init__(self, collation: Optional[str] = None): |
| self.collationId = 0 if collation is None else self.collationNameToId(collation) |
| |
| @classmethod |
| def fromCollationId(cls, collationId: int) -> "StringType": |
| return StringType(StringType.collationNames[collationId]) |
| |
| def collationIdToName(self) -> str: |
| if self.collationId == 0: |
| return "" |
| else: |
| return " collate %s" % StringType.collationNames[self.collationId] |
| |
| @classmethod |
| def collationNameToId(cls, collationName: str) -> int: |
| return StringType.collationNames.index(collationName) |
| |
| def simpleString(self) -> str: |
| return "string" + self.collationIdToName() |
| |
| def jsonValue(self) -> str: |
| return "string" + self.collationIdToName() |
| |
| def __repr__(self) -> str: |
| return ( |
| "StringType('%s')" % StringType.collationNames[self.collationId] |
| if self.collationId != 0 |
| else "StringType()" |
| ) |
| |
| |
| class CharType(AtomicType): |
| """Char data type |
| |
| Parameters |
| ---------- |
| length : int |
| the length limitation. |
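| |
| Examples |
| -------- |
| A brief illustration of the length-parameterized type: |
| |
| >>> CharType(5) |
| CharType(5) |
| >>> CharType(5).simpleString() |
| 'char(5)' |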
| """ |
| |
| def __init__(self, length: int): |
| self.length = length |
| |
| def simpleString(self) -> str: |
| return "char(%d)" % (self.length) |
| |
| def jsonValue(self) -> str: |
| return "char(%d)" % (self.length) |
| |
| def __repr__(self) -> str: |
| return "CharType(%d)" % (self.length) |
| |
| |
| class VarcharType(AtomicType): |
| """Varchar data type |
| |
| Parameters |
| ---------- |
| length : int |
| the length limitation. |
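| |
| Examples |
| -------- |
| A brief illustration of the length-parameterized type: |
| |
| >>> VarcharType(100) |
| VarcharType(100) |
| >>> VarcharType(100).simpleString() |
| 'varchar(100)' |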
| """ |
| |
| def __init__(self, length: int): |
| self.length = length |
| |
| def simpleString(self) -> str: |
| return "varchar(%d)" % (self.length) |
| |
| def jsonValue(self) -> str: |
| return "varchar(%d)" % (self.length) |
| |
| def __repr__(self) -> str: |
| return "VarcharType(%d)" % (self.length) |
| |
| |
| class BinaryType(AtomicType, metaclass=DataTypeSingleton): |
| """Binary (byte array) data type.""" |
| |
| pass |
| |
| |
| class BooleanType(AtomicType, metaclass=DataTypeSingleton): |
| """Boolean data type.""" |
| |
| pass |
| |
| |
| class DateType(AtomicType, metaclass=DataTypeSingleton): |
| """Date (datetime.date) data type.""" |
| |
| EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() |
| |
| def needConversion(self) -> bool: |
| return True |
| |
| def toInternal(self, d: datetime.date) -> int: |
| if d is not None: |
| return d.toordinal() - self.EPOCH_ORDINAL |
| |
| def fromInternal(self, v: int) -> datetime.date: |
| if v is not None: |
| return datetime.date.fromordinal(v + self.EPOCH_ORDINAL) |
| |
| |
| class TimestampType(AtomicType, metaclass=DataTypeSingleton): |
| """Timestamp (datetime.datetime) data type.""" |
| |
| def needConversion(self) -> bool: |
| return True |
| |
| def toInternal(self, dt: datetime.datetime) -> int: |
| if dt is not None: |
| seconds = ( |
| calendar.timegm(dt.utctimetuple()) if dt.tzinfo else time.mktime(dt.timetuple()) |
| ) |
| return int(seconds) * 1000000 + dt.microsecond |
| |
| def fromInternal(self, ts: int) -> datetime.datetime: |
| if ts is not None: |
| # using int to avoid precision loss in float |
| return datetime.datetime.fromtimestamp(ts // 1000000).replace(microsecond=ts % 1000000) |
| |
| |
| class TimestampNTZType(AtomicType, metaclass=DataTypeSingleton): |
| """Timestamp (datetime.datetime) data type without timezone information.""" |
| |
| def needConversion(self) -> bool: |
| return True |
| |
| @classmethod |
| def typeName(cls) -> str: |
| return "timestamp_ntz" |
| |
| def toInternal(self, dt: datetime.datetime) -> int: |
| if dt is not None: |
| seconds = calendar.timegm(dt.timetuple()) |
| return int(seconds) * 1000000 + dt.microsecond |
| |
| def fromInternal(self, ts: int) -> datetime.datetime: |
| if ts is not None: |
| # using int to avoid precision loss in float |
| return datetime.datetime.utcfromtimestamp(ts // 1000000).replace( |
| microsecond=ts % 1000000 |
| ) |
| |
| |
| class DecimalType(FractionalType): |
| """Decimal (decimal.Decimal) data type. |
| |
| The DecimalType must have fixed precision (the maximum total number of digits) |
| and scale (the number of digits to the right of the decimal point). For example, (5, 2) can |
| support values in the range [-999.99, 999.99]. |
| |
| The precision can be up to 38; the scale must be less than or equal to the precision. |
| |
| When creating a DecimalType, the default precision and scale is (10, 0). When inferring |
| schema from decimal.Decimal objects, it will be DecimalType(38, 18). |
| |
| Parameters |
| ---------- |
| precision : int, optional |
| the maximum (i.e. total) number of digits (default: 10) |
| scale : int, optional |
| the number of digits to the right of the decimal point (default: 0) |
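| |
| Examples |
| -------- |
| The below examples illustrate the default and an explicit precision/scale: |
| |
| >>> DecimalType() |
| DecimalType(10,0) |
| >>> DecimalType(5, 2).simpleString() |
| 'decimal(5,2)' |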
| """ |
| |
| def __init__(self, precision: int = 10, scale: int = 0): |
| self.precision = precision |
| self.scale = scale |
| self.hasPrecisionInfo = True # this is a public API |
| |
| def simpleString(self) -> str: |
| return "decimal(%d,%d)" % (self.precision, self.scale) |
| |
| def jsonValue(self) -> str: |
| return "decimal(%d,%d)" % (self.precision, self.scale) |
| |
| def __repr__(self) -> str: |
| return "DecimalType(%d,%d)" % (self.precision, self.scale) |
| |
| |
| class DoubleType(FractionalType, metaclass=DataTypeSingleton): |
| """Double data type, representing double precision floats.""" |
| |
| pass |
| |
| |
| class FloatType(FractionalType, metaclass=DataTypeSingleton): |
| """Float data type, representing single precision floats.""" |
| |
| pass |
| |
| |
| class ByteType(IntegralType): |
| """Byte data type, representing signed 8-bit integers.""" |
| |
| def simpleString(self) -> str: |
| return "tinyint" |
| |
| |
| class IntegerType(IntegralType): |
| """Int data type, representing signed 32-bit integers.""" |
| |
| def simpleString(self) -> str: |
| return "int" |
| |
| |
| class LongType(IntegralType): |
| """Long data type, representing signed 64-bit integers. |
| |
| If the values are beyond the range of [-9223372036854775808, 9223372036854775807], |
| please use :class:`DecimalType`. |
| """ |
| |
| def simpleString(self) -> str: |
| return "bigint" |
| |
| |
| class ShortType(IntegralType): |
| """Short data type, representing signed 16-bit integers.""" |
| |
| def simpleString(self) -> str: |
| return "smallint" |
| |
| |
| class AnsiIntervalType(AtomicType): |
| """The interval type which conforms to the ANSI SQL standard.""" |
| |
| pass |
| |
| |
| class DayTimeIntervalType(AnsiIntervalType): |
| """DayTimeIntervalType (datetime.timedelta).""" |
| |
| DAY = 0 |
| HOUR = 1 |
| MINUTE = 2 |
| SECOND = 3 |
| |
| _fields = { |
| DAY: "day", |
| HOUR: "hour", |
| MINUTE: "minute", |
| SECOND: "second", |
| } |
| |
| _inverted_fields = dict(zip(_fields.values(), _fields.keys())) |
| |
| def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): |
| if startField is None and endField is None: |
| # Default matches the Scala side. |
| startField = DayTimeIntervalType.DAY |
| endField = DayTimeIntervalType.SECOND |
| elif startField is not None and endField is None: |
| endField = startField |
| |
| fields = DayTimeIntervalType._fields |
| if startField not in fields.keys() or endField not in fields.keys(): |
| raise PySparkRuntimeError( |
| error_class="INVALID_INTERVAL_CASTING", |
| message_parameters={"start_field": str(startField), "end_field": str(endField)}, |
| ) |
| self.startField = startField |
| self.endField = endField |
| |
| def _str_repr(self) -> str: |
| fields = DayTimeIntervalType._fields |
| start_field_name = fields[self.startField] |
| end_field_name = fields[self.endField] |
| if start_field_name == end_field_name: |
| return "interval %s" % start_field_name |
| else: |
| return "interval %s to %s" % (start_field_name, end_field_name) |
| |
| simpleString = _str_repr |
| |
| jsonValue = _str_repr |
| |
| def __repr__(self) -> str: |
| return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) |
| |
| def needConversion(self) -> bool: |
| return True |
| |
| def toInternal(self, dt: datetime.timedelta) -> Optional[int]: |
| if dt is not None: |
| return (((dt.days * 86400) + dt.seconds) * 1_000_000) + dt.microseconds |
| |
| def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: |
| if micros is not None: |
| return datetime.timedelta(microseconds=micros) |
| |
| |
| class YearMonthIntervalType(AnsiIntervalType): |
| """YearMonthIntervalType, representing year-month intervals of the ANSI SQL standard.""" |
| |
| YEAR = 0 |
| MONTH = 1 |
| |
| _fields = { |
| YEAR: "year", |
| MONTH: "month", |
| } |
| |
| _inverted_fields = dict(zip(_fields.values(), _fields.keys())) |
| |
| def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None): |
| if startField is None and endField is None: |
| # Default matches the Scala side. |
| startField = YearMonthIntervalType.YEAR |
| endField = YearMonthIntervalType.MONTH |
| elif startField is not None and endField is None: |
| endField = startField |
| |
| fields = YearMonthIntervalType._fields |
| if startField not in fields.keys() or endField not in fields.keys(): |
| raise PySparkRuntimeError( |
| error_class="INVALID_INTERVAL_CASTING", |
| message_parameters={"start_field": str(startField), "end_field": str(endField)}, |
| ) |
| self.startField = startField |
| self.endField = endField |
| |
| def _str_repr(self) -> str: |
| fields = YearMonthIntervalType._fields |
| start_field_name = fields[self.startField] |
| end_field_name = fields[self.endField] |
| if start_field_name == end_field_name: |
| return "interval %s" % start_field_name |
| else: |
| return "interval %s to %s" % (start_field_name, end_field_name) |
| |
| simpleString = _str_repr |
| |
| jsonValue = _str_repr |
| |
| def __repr__(self) -> str: |
| return "%s(%d, %d)" % (type(self).__name__, self.startField, self.endField) |
| |
| |
| class CalendarIntervalType(DataType, metaclass=DataTypeSingleton): |
| """The data type representing calendar intervals. |
| |
| The calendar interval is stored internally in three components: |
| - an integer value representing the number of `months` in this interval. |
| - an integer value representing the number of `days` in this interval. |
| - a long value representing the number of `microseconds` in this interval. |
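| |
| Examples |
| -------- |
| Note that the SQL type name is ``interval``: |
| |
| >>> CalendarIntervalType().simpleString() |
| 'interval' |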
| """ |
| |
| @classmethod |
| def typeName(cls) -> str: |
| return "interval" |
| |
| |
| class ArrayType(DataType): |
| """Array data type. |
| |
| Parameters |
| ---------- |
| elementType : :class:`DataType` |
| :class:`DataType` of each element in the array. |
| containsNull : bool, optional |
| whether the array can contain null (None) values. |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType |
| |
| The below example demonstrates how to create :class:`ArrayType`: |
| |
| >>> arr = ArrayType(StringType()) |
| |
| The array can contain null (None) values by default: |
| |
| >>> ArrayType(StringType()) == ArrayType(StringType(), True) |
| True |
| >>> ArrayType(StringType(), False) == ArrayType(StringType()) |
| False |
| """ |
| |
| def __init__(self, elementType: DataType, containsNull: bool = True): |
| assert isinstance(elementType, DataType), "elementType %s should be an instance of %s" % ( |
| elementType, |
| DataType, |
| ) |
| self.elementType = elementType |
| self.containsNull = containsNull |
| |
| def simpleString(self) -> str: |
| return "array<%s>" % self.elementType.simpleString() |
| |
| def _as_nullable(self) -> "ArrayType": |
| return ArrayType(self.elementType._as_nullable(), containsNull=True) |
| |
| def toNullable(self) -> "ArrayType": |
| """ |
| Returns the same data type, but with all nullability fields set to true |
| (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). |
| |
| .. versionadded:: 4.0.0 |
| |
| Returns |
| ------- |
| :class:`ArrayType` |
| |
| Examples |
| -------- |
| Example 1: Simple nullability conversion |
| |
| >>> ArrayType(IntegerType(), containsNull=False).toNullable() |
| ArrayType(IntegerType(), True) |
| |
| Example 2: Nested nullability conversion |
| |
| >>> ArrayType( |
| ... StructType([ |
| ... StructField("b", IntegerType(), nullable=False), |
| ... StructField("c", ArrayType(IntegerType(), containsNull=False)) |
| ... ]), |
| ... containsNull=False |
| ... ).toNullable() |
| ArrayType(StructType([StructField('b', IntegerType(), True), |
| StructField('c', ArrayType(IntegerType(), True), True)]), True) |
| """ |
| return self._as_nullable() |
| |
| def __repr__(self) -> str: |
| return "ArrayType(%s, %s)" % (self.elementType, str(self.containsNull)) |
| |
| def jsonValue(self) -> Dict[str, Any]: |
| return { |
| "type": self.typeName(), |
| "elementType": self.elementType.jsonValue(), |
| "containsNull": self.containsNull, |
| } |
| |
| @classmethod |
| def fromJson(cls, json: Dict[str, Any]) -> "ArrayType": |
| return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"]) |
| |
| def needConversion(self) -> bool: |
| return self.elementType.needConversion() |
| |
| def toInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: |
| if not self.needConversion(): |
| return obj |
| return obj and [self.elementType.toInternal(v) for v in obj] |
| |
| def fromInternal(self, obj: List[Optional[T]]) -> List[Optional[T]]: |
| if not self.needConversion(): |
| return obj |
| return obj and [self.elementType.fromInternal(v) for v in obj] |
| |
| |
| class MapType(DataType): |
| """Map data type. |
| |
| Parameters |
| ---------- |
| keyType : :class:`DataType` |
| :class:`DataType` of the keys in the map. |
| valueType : :class:`DataType` |
| :class:`DataType` of the values in the map. |
| valueContainsNull : bool, optional |
| indicates whether values can contain null (None) values. |
| |
| Notes |
| ----- |
| Keys in a map data type are not allowed to be null (None). |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.types import IntegerType, FloatType, MapType, StringType |
| |
| The below example demonstrates how to create :class:`MapType`: |
| |
| >>> map_type = MapType(StringType(), IntegerType()) |
| |
| The values of the map can contain null (``None``) values by default: |
| |
| >>> (MapType(StringType(), IntegerType()) |
| ... == MapType(StringType(), IntegerType(), True)) |
| True |
| >>> (MapType(StringType(), IntegerType(), False) |
| ... == MapType(StringType(), FloatType())) |
| False |
| """ |
| |
| def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): |
| assert isinstance(keyType, DataType), "keyType %s should be an instance of %s" % ( |
| keyType, |
| DataType, |
| ) |
| assert isinstance(valueType, DataType), "valueType %s should be an instance of %s" % ( |
| valueType, |
| DataType, |
| ) |
| self.keyType = keyType |
| self.valueType = valueType |
| self.valueContainsNull = valueContainsNull |
| |
| def simpleString(self) -> str: |
| return "map<%s,%s>" % (self.keyType.simpleString(), self.valueType.simpleString()) |
| |
| def _as_nullable(self) -> "MapType": |
| return MapType( |
| self.keyType._as_nullable(), self.valueType._as_nullable(), valueContainsNull=True |
| ) |
| |
| def toNullable(self) -> "MapType": |
| """ |
| Returns the same data type, but with all nullability fields set to true |
| (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). |
| |
| .. versionadded:: 4.0.0 |
| |
| Returns |
| ------- |
| :class:`MapType` |
| |
| Examples |
| -------- |
| Example 1: Simple nullability conversion |
| |
| >>> MapType(IntegerType(), StringType(), valueContainsNull=False).toNullable() |
| MapType(IntegerType(), StringType(), True) |
| |
| Example 2: Nested nullability conversion |
| |
| >>> MapType( |
| ... StringType(), |
| ... MapType( |
| ... IntegerType(), |
| ... ArrayType(IntegerType(), containsNull=False), |
| ... valueContainsNull=False |
| ... ), |
| ... valueContainsNull=False |
| ... ).toNullable() |
| MapType(StringType(), MapType(IntegerType(), ArrayType(IntegerType(), True), True), True) |
| """ |
| return self._as_nullable() |
| |
| def __repr__(self) -> str: |
| return "MapType(%s, %s, %s)" % (self.keyType, self.valueType, str(self.valueContainsNull)) |
| |
| def jsonValue(self) -> Dict[str, Any]: |
| return { |
| "type": self.typeName(), |
| "keyType": self.keyType.jsonValue(), |
| "valueType": self.valueType.jsonValue(), |
| "valueContainsNull": self.valueContainsNull, |
| } |
| |
| @classmethod |
| def fromJson(cls, json: Dict[str, Any]) -> "MapType": |
| return MapType( |
| _parse_datatype_json_value(json["keyType"]), |
| _parse_datatype_json_value(json["valueType"]), |
| json["valueContainsNull"], |
| ) |
| |
| def needConversion(self) -> bool: |
| return self.keyType.needConversion() or self.valueType.needConversion() |
| |
| def toInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: |
| if not self.needConversion(): |
| return obj |
| return obj and dict( |
| (self.keyType.toInternal(k), self.valueType.toInternal(v)) for k, v in obj.items() |
| ) |
| |
| def fromInternal(self, obj: Dict[T, Optional[U]]) -> Dict[T, Optional[U]]: |
| if not self.needConversion(): |
| return obj |
| return obj and dict( |
| (self.keyType.fromInternal(k), self.valueType.fromInternal(v)) for k, v in obj.items() |
| ) |
| |
| |
| class StructField(DataType): |
| """A field in :class:`StructType`. |
| |
| Parameters |
| ---------- |
| name : str |
| name of the field. |
| dataType : :class:`DataType` |
| :class:`DataType` of the field. |
| nullable : bool, optional |
| whether the field can be null (None) or not. |
| metadata : dict, optional |
| a dict from string to simple type that can be converted to JSON automatically |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.types import StringType, StructField |
| >>> (StructField("f1", StringType(), True) |
| ... == StructField("f1", StringType(), True)) |
| True |
| >>> (StructField("f1", StringType(), True) |
| ... == StructField("f2", StringType(), True)) |
| False |
| """ |
| |
| def __init__( |
| self, |
| name: str, |
| dataType: DataType, |
| nullable: bool = True, |
| metadata: Optional[Dict[str, Any]] = None, |
| ): |
| assert isinstance(dataType, DataType), "dataType %s should be an instance of %s" % ( |
| dataType, |
| DataType, |
| ) |
| assert isinstance(name, str), "field name %s should be a string" % (name) |
| self.name = name |
| self.dataType = dataType |
| self.nullable = nullable |
| self.metadata = metadata or {} |
| |
| def simpleString(self) -> str: |
| return "%s:%s" % (self.name, self.dataType.simpleString()) |
| |
| def __repr__(self) -> str: |
| return "StructField('%s', %s, %s)" % (self.name, self.dataType, str(self.nullable)) |
| |
| def jsonValue(self) -> Dict[str, Any]: |
| return { |
| "name": self.name, |
| "type": self.dataType.jsonValue(), |
| "nullable": self.nullable, |
| "metadata": self.metadata, |
| } |
| |
| @classmethod |
| def fromJson(cls, json: Dict[str, Any]) -> "StructField": |
| return StructField( |
| json["name"], |
| _parse_datatype_json_value(json["type"]), |
| json.get("nullable", True), |
| json.get("metadata"), |
| ) |
| |
| def needConversion(self) -> bool: |
| return self.dataType.needConversion() |
| |
| def toInternal(self, obj: T) -> T: |
| return self.dataType.toInternal(obj) |
| |
| def fromInternal(self, obj: T) -> T: |
| return self.dataType.fromInternal(obj) |
| |
| def typeName(self) -> str: # type: ignore[override] |
| raise PySparkTypeError( |
| error_class="INVALID_TYPENAME_CALL", |
| message_parameters={}, |
| ) |
| |
| |
| class StructType(DataType): |
| """Struct type, consisting of a list of :class:`StructField`. |
| |
| This is the data type representing a :class:`Row`. |
| |
| Iterating a :class:`StructType` will iterate over its :class:`StructField`\\s. |
| A contained :class:`StructField` can be accessed by its name or position. |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.types import * |
| >>> struct1 = StructType([StructField("f1", StringType(), True)]) |
| >>> struct1["f1"] |
| StructField('f1', StringType(), True) |
| >>> struct1[0] |
| StructField('f1', StringType(), True) |
| |
| >>> struct1 = StructType([StructField("f1", StringType(), True)]) |
| >>> struct2 = StructType([StructField("f1", StringType(), True)]) |
| >>> struct1 == struct2 |
| True |
| >>> struct1 = StructType([StructField("f1", CharType(10), True)]) |
| >>> struct2 = StructType([StructField("f1", CharType(10), True)]) |
| >>> struct1 == struct2 |
| True |
| >>> struct1 = StructType([StructField("f1", VarcharType(10), True)]) |
| >>> struct2 = StructType([StructField("f1", VarcharType(10), True)]) |
| >>> struct1 == struct2 |
| True |
| >>> struct1 = StructType([StructField("f1", StringType(), True)]) |
| >>> struct2 = StructType([StructField("f1", StringType(), True), |
| ... StructField("f2", IntegerType(), False)]) |
| >>> struct1 == struct2 |
| False |
| |
| The below example demonstrates how to create a DataFrame based on a struct created |
| using :class:`StructType` and :class:`StructField`: |
| |
| >>> data = [("Alice", ["Java", "Scala"]), ("Bob", ["Python", "Scala"])] |
| >>> schema = StructType([ |
| ... StructField("name", StringType()), |
| ... StructField("languagesSkills", ArrayType(StringType())), |
| ... ]) |
| >>> df = spark.createDataFrame(data=data, schema=schema) |
| >>> df.printSchema() |
| root |
| |-- name: string (nullable = true) |
| |-- languagesSkills: array (nullable = true) |
| | |-- element: string (containsNull = true) |
| >>> df.show() |
| +-----+---------------+ |
| | name|languagesSkills| |
| +-----+---------------+ |
| |Alice| [Java, Scala]| |
| | Bob|[Python, Scala]| |
| +-----+---------------+ |
| """ |
| |
| def __init__(self, fields: Optional[List[StructField]] = None): |
| if not fields: |
| self.fields = [] |
| self.names = [] |
| else: |
| self.fields = fields |
| self.names = [f.name for f in fields] |
| assert all( |
| isinstance(f, StructField) for f in fields |
| ), "fields should be a list of StructField" |
| # Precalculated list of fields that need conversion with fromInternal/toInternal functions |
| self._needConversion = [f.needConversion() for f in self] |
| self._needSerializeAnyField = any(self._needConversion) |
| |
| @overload |
| def add( |
| self, |
| field: str, |
| data_type: Union[str, DataType], |
| nullable: bool = True, |
| metadata: Optional[Dict[str, Any]] = None, |
| ) -> "StructType": |
| ... |
| |
| @overload |
| def add(self, field: StructField) -> "StructType": |
| ... |
| |
| def add( |
| self, |
| field: Union[str, StructField], |
| data_type: Optional[Union[str, DataType]] = None, |
| nullable: bool = True, |
| metadata: Optional[Dict[str, Any]] = None, |
| ) -> "StructType": |
| """ |
| Construct a :class:`StructType` by adding new elements to it, to define the schema. |
| The method accepts either: |
| |
| a) A single parameter which is a :class:`StructField` object. |
| b) Between 2 and 4 parameters as (name, data_type, nullable (optional), |
| metadata (optional)). The data_type parameter may be either a string or a |
| :class:`DataType` object. |
| |
| Parameters |
| ---------- |
| field : str or :class:`StructField` |
| Either the name of the field or a :class:`StructField` object |
| data_type : :class:`DataType`, optional |
| If present, the DataType of the :class:`StructField` to create |
| nullable : bool, optional |
| Whether the field to add should be nullable (default True) |
| metadata : dict, optional |
| Any additional metadata (default None) |
| |
| Returns |
| ------- |
| :class:`StructType` |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.types import IntegerType, StringType, StructField, StructType |
| >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) |
| >>> struct2 = StructType([StructField("f1", StringType(), True), |
| ... StructField("f2", StringType(), True, None)]) |
| >>> struct1 == struct2 |
| True |
| >>> struct1 = StructType().add(StructField("f1", StringType(), True)) |
| >>> struct2 = StructType([StructField("f1", StringType(), True)]) |
| >>> struct1 == struct2 |
| True |
| >>> struct1 = StructType().add("f1", "string", True) |
| >>> struct2 = StructType([StructField("f1", StringType(), True)]) |
| >>> struct1 == struct2 |
| True |
| """ |
| if isinstance(field, StructField): |
| self.fields.append(field) |
| self.names.append(field.name) |
| else: |
| if isinstance(field, str) and data_type is None: |
| raise PySparkValueError( |
| error_class="ARGUMENT_REQUIRED", |
| message_parameters={ |
| "arg_name": "data_type", |
| "condition": "passing name of struct_field to create", |
| }, |
| ) |
| |
| if isinstance(data_type, str): |
| data_type_f = _parse_datatype_json_value(data_type) |
| else: |
| data_type_f = data_type |
| self.fields.append(StructField(field, data_type_f, nullable, metadata)) |
| self.names.append(field) |
| # Precalculated list of fields that need conversion with fromInternal/toInternal functions |
| self._needConversion = [f.needConversion() for f in self] |
| self._needSerializeAnyField = any(self._needConversion) |
| return self |
| |
| def __iter__(self) -> Iterator[StructField]: |
| """Iterate the fields""" |
| return iter(self.fields) |
| |
| def __len__(self) -> int: |
| """Return the number of fields.""" |
| return len(self.fields) |
| |
| def __getitem__(self, key: Union[str, int]) -> StructField: |
| """Access fields by name, position, or slice.""" |
| if isinstance(key, str): |
| for field in self: |
| if field.name == key: |
| return field |
| raise PySparkKeyError( |
| error_class="KEY_NOT_EXISTS", message_parameters={"key": str(key)} |
| ) |
| elif isinstance(key, int): |
| try: |
| return self.fields[key] |
| except IndexError: |
| raise PySparkIndexError( |
| error_class="INDEX_OUT_OF_RANGE", |
| message_parameters={"arg_name": "StructType", "index": str(key)}, |
| ) |
| elif isinstance(key, slice): |
| return StructType(self.fields[key]) |
| else: |
| raise PySparkTypeError( |
| error_class="NOT_INT_OR_SLICE_OR_STR", |
| message_parameters={"arg_name": "key", "arg_type": type(key).__name__}, |
| ) |
| |
| def simpleString(self) -> str: |
| return "struct<%s>" % (",".join(f.simpleString() for f in self)) |
| |
| def _as_nullable(self) -> "StructType": |
| fields = [] |
| for field in self.fields: |
| fields.append( |
| StructField( |
| field.name, |
| field.dataType._as_nullable(), |
| nullable=True, |
| metadata=field.metadata, |
| ) |
| ) |
| return StructType(fields) |
| |
| def toNullable(self) -> "StructType": |
| """ |
| Returns the same data type, but with all nullability fields set to true |
| (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). |
| |
| .. versionadded:: 4.0.0 |
| |
| Returns |
| ------- |
| :class:`StructType` |
| |
| Examples |
| -------- |
| Example 1: Simple nullability conversion |
| |
| >>> StructType([StructField("a", IntegerType(), nullable=False)]).toNullable() |
| StructType([StructField('a', IntegerType(), True)]) |
| |
| Example 2: Nested nullability conversion |
| |
| >>> StructType([ |
| ... StructField("a", |
| ... StructType([ |
| ... StructField("b", IntegerType(), nullable=False), |
| ... StructField("c", StructType([ |
| ... StructField("d", IntegerType(), nullable=False) |
| ... ])) |
| ... ]), |
| ... nullable=False) |
| ... ]).toNullable() |
| StructType([StructField('a', StructType([StructField('b', IntegerType(), True), |
| StructField('c', StructType([StructField('d', IntegerType(), True)]), True)]), True)]) |
| """ |
| return self._as_nullable() |
| |
| def __repr__(self) -> str: |
| return "StructType([%s])" % ", ".join(str(field) for field in self) |
| |
| def jsonValue(self) -> Dict[str, Any]: |
| return {"type": self.typeName(), "fields": [f.jsonValue() for f in self]} |
| |
| @classmethod |
| def fromJson(cls, json: Dict[str, Any]) -> "StructType": |
| """ |
| Constructs :class:`StructType` from a schema defined in JSON format. |
| |
| Below is a JSON schema it must adhere to:: |
| |
| { |
| "title":"StructType", |
| "description":"Schema of StructType in json format", |
| "type":"object", |
| "properties":{ |
| "fields":{ |
| "description":"Array of struct fields", |
| "type":"array", |
| "items":{ |
| "type":"object", |
| "properties":{ |
| "name":{ |
| "description":"Name of the field", |
| "type":"string" |
| }, |
| "type":{ |
| "description": "Type of the field. Can either be |
| another nested StructType or primitive type", |
| "type":"object/string" |
| }, |
| "nullable":{ |
| "description":"If nulls are allowed", |
| "type":"boolean" |
| }, |
| "metadata":{ |
| "description":"Additional metadata to supply", |
| "type":"object" |
| }, |
| "required":[ |
| "name", |
| "type", |
| "nullable", |
| "metadata" |
| ] |
| } |
| } |
| } |
| } |
| } |
| |
| Parameters |
| ---------- |
| json : dict or a dict-like object e.g. JSON object |
| This "dict" must have a "fields" key that holds an array of fields, |
| each of which must have specific keys (name, type, nullable, metadata). |
| |
| Returns |
| ------- |
| :class:`StructType` |
| |
| Examples |
| -------- |
| >>> json_str = ''' |
| ... { |
| ... "fields": [ |
| ... { |
| ... "metadata": {}, |
| ... "name": "Person", |
| ... "nullable": true, |
| ... "type": { |
| ... "fields": [ |
| ... { |
| ... "metadata": {}, |
| ... "name": "name", |
| ... "nullable": false, |
| ... "type": "string" |
| ... }, |
| ... { |
| ... "metadata": {}, |
| ... "name": "surname", |
| ... "nullable": false, |
| ... "type": "string" |
| ... } |
| ... ], |
| ... "type": "struct" |
| ... } |
| ... } |
| ... ], |
| ... "type": "struct" |
| ... } |
| ... ''' |
| >>> import json |
| >>> scheme = StructType.fromJson(json.loads(json_str)) |
| >>> scheme.simpleString() |
| 'struct<Person:struct<name:string,surname:string>>' |
| """ |
| return StructType([StructField.fromJson(f) for f in json["fields"]]) |
| |
| def fieldNames(self) -> List[str]: |
| """ |
| Returns all field names in a list. |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.types import StringType, StructField, StructType |
| >>> struct = StructType([StructField("f1", StringType(), True)]) |
| >>> struct.fieldNames() |
| ['f1'] |
| """ |
| return list(self.names) |
| |
| def needConversion(self) -> bool: |
| # We need to convert Row()/namedtuple into tuple() |
| return True |
| |
| def toInternal(self, obj: Tuple) -> Tuple: |
| if obj is None: |
| return |
| |
| if self._needSerializeAnyField: |
| # Only calling toInternal function for fields that need conversion |
| if isinstance(obj, dict): |
| return tuple( |
| f.toInternal(obj.get(n)) if c else obj.get(n) |
| for n, f, c in zip(self.names, self.fields, self._needConversion) |
| ) |
| elif isinstance(obj, (tuple, list)): |
| return tuple( |
| f.toInternal(v) if c else v |
| for f, v, c in zip(self.fields, obj, self._needConversion) |
| ) |
| elif hasattr(obj, "__dict__"): |
| d = obj.__dict__ |
| return tuple( |
| f.toInternal(d.get(n)) if c else d.get(n) |
| for n, f, c in zip(self.names, self.fields, self._needConversion) |
| ) |
| else: |
| raise PySparkValueError( |
| error_class="UNEXPECTED_TUPLE_WITH_STRUCT", |
| message_parameters={"tuple": str(obj)}, |
| ) |
| else: |
| if isinstance(obj, dict): |
| return tuple(obj.get(n) for n in self.names) |
| elif isinstance(obj, (list, tuple)): |
| return tuple(obj) |
| elif hasattr(obj, "__dict__"): |
| d = obj.__dict__ |
| return tuple(d.get(n) for n in self.names) |
| else: |
| raise PySparkValueError( |
| error_class="UNEXPECTED_TUPLE_WITH_STRUCT", |
| message_parameters={"tuple": str(obj)}, |
| ) |
| |
| def fromInternal(self, obj: Tuple) -> "Row": |
| if obj is None: |
| return |
| if isinstance(obj, Row): |
| # it's already converted by pickler |
| return obj |
| |
| values: Union[Tuple, List] |
| if self._needSerializeAnyField: |
| # Only calling fromInternal function for fields that need conversion |
| values = [ |
| f.fromInternal(v) if c else v |
| for f, v, c in zip(self.fields, obj, self._needConversion) |
| ] |
| else: |
| values = obj |
| return _create_row(self.names, values) |
| |
| |
| class VariantType(AtomicType): |
| """ |
| Variant data type, representing semi-structured values. |
| |
| .. versionadded:: 4.0.0 |
| """ |
| |
| def needConversion(self) -> bool: |
| return True |
| |
| def fromInternal(self, obj: Dict) -> Optional["VariantVal"]: |
| if obj is None or not all(key in obj for key in ["value", "metadata"]): |
| return None |
| return VariantVal(obj["value"], obj["metadata"]) |
| |
| |
| class UserDefinedType(DataType): |
| """User-defined type (UDT). |
| |
| .. note:: WARN: Spark Internal Use Only |
| """ |
| |
| @classmethod |
| def typeName(cls) -> str: |
| return cls.__name__.lower() |
| |
| @classmethod |
| def sqlType(cls) -> DataType: |
| """ |
| Underlying SQL storage type for this UDT. |
| """ |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "sqlType()"}, |
| ) |
| |
| @classmethod |
| def module(cls) -> str: |
| """ |
| The Python module of the UDT. |
| """ |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "module()"}, |
| ) |
| |
| @classmethod |
| def scalaUDT(cls) -> str: |
| """ |
| The class name of the paired Scala UDT (could be '', if there |
| is no corresponding one). |
| """ |
| return "" |
| |
| def needConversion(self) -> bool: |
| return True |
| |
| @classmethod |
| def _cachedSqlType(cls) -> DataType: |
| """ |
| Cache the sqlType() on the class, because it is heavily used in `toInternal`. |
| """ |
| if not hasattr(cls, "_cached_sql_type"): |
| cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined] |
| return cls._cached_sql_type # type: ignore[attr-defined] |
| |
| def toInternal(self, obj: Any) -> Any: |
| if obj is not None: |
| return self._cachedSqlType().toInternal(self.serialize(obj)) |
| |
| def fromInternal(self, obj: Any) -> Any: |
| v = self._cachedSqlType().fromInternal(obj) |
| if v is not None: |
| return self.deserialize(v) |
| |
| def serialize(self, obj: Any) -> Any: |
| """ |
| Converts a user-type object into a SQL datum. |
| """ |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "toInternal()"}, |
| ) |
| |
| def deserialize(self, datum: Any) -> Any: |
| """ |
| Converts a SQL datum into a user-type object. |
| """ |
| raise PySparkNotImplementedError( |
| error_class="NOT_IMPLEMENTED", |
| message_parameters={"feature": "fromInternal()"}, |
| ) |
| |
| def simpleString(self) -> str: |
| return "udt" |
| |
| def json(self) -> str: |
| return json.dumps(self.jsonValue(), separators=(",", ":"), sort_keys=True) |
| |
| def jsonValue(self) -> Dict[str, Any]: |
| if self.scalaUDT(): |
| assert self.module() != "__main__", "UDT in __main__ cannot work with ScalaUDT" |
| schema = { |
| "type": "udt", |
| "class": self.scalaUDT(), |
| "pyClass": "%s.%s" % (self.module(), type(self).__name__), |
| "sqlType": self.sqlType().jsonValue(), |
| } |
| else: |
| ser = CloudPickleSerializer() |
| b = ser.dumps(type(self)) |
| schema = { |
| "type": "udt", |
| "pyClass": "%s.%s" % (self.module(), type(self).__name__), |
| "serializedClass": base64.b64encode(b).decode("utf8"), |
| "sqlType": self.sqlType().jsonValue(), |
| } |
| return schema |
| |
| @classmethod |
| def fromJson(cls, json: Dict[str, Any]) -> "UserDefinedType": |
| pyUDT = str(json["pyClass"]) # convert unicode to str |
| split = pyUDT.rfind(".") |
| pyModule = pyUDT[:split] |
| pyClass = pyUDT[split + 1 :] |
| m = __import__(pyModule, globals(), locals(), [pyClass]) |
| if not hasattr(m, pyClass): |
| s = base64.b64decode(json["serializedClass"].encode("utf-8")) |
| UDT = CloudPickleSerializer().loads(s) |
| else: |
| UDT = getattr(m, pyClass) |
| return UDT() |
| |
| def __eq__(self, other: Any) -> bool: |
| return type(self) == type(other) |
| |
| |
| class VariantVal: |
| """ |
| A class to represent a Variant value in Python. |
| |
| .. versionadded:: 4.0.0 |
| |
| Parameters |
| ---------- |
| value : bytes |
| The bytes representing the value component of the Variant. |
| metadata : bytes |
| The bytes representing the metadata component of the Variant. |
| |
| Methods |
| ------- |
| toPython() |
| Convert the VariantVal to a Python data structure. |
| |
| Examples |
| -------- |
| >>> from pyspark.sql.functions import * |
| >>> df = spark.createDataFrame([ {'json': '''{ "a" : 1 }'''} ]) |
| >>> v = df.select(parse_json(df.json).alias("var")).collect()[0].var |
| >>> v.toPython() |
| {'a': 1} |
| """ |
| |
| def __init__(self, value: bytes, metadata: bytes): |
| self.value = value |
| self.metadata = metadata |
| |
| def __str__(self) -> str: |
| return VariantUtils.to_json(self.value, self.metadata) |
| |
| def __repr__(self) -> str: |
| return "VariantVal(%r, %r)" % (self.value, self.metadata) |
| |
| def toPython(self) -> Any: |
| """ |
| Convert the VariantVal to a Python data structure. |
| |
| Returns |
| ------- |
| Any |
| A Python object that represents the Variant. |
| """ |
| return VariantUtils.to_python(self.value, self.metadata) |
| |
| def toJson(self, zone_id: str = "UTC") -> str: |
| """ |
| Convert the VariantVal to a JSON string. The zone ID represents the time zone that the |
| timestamp should be printed in. It defaults to UTC. The list of valid zone IDs can be |
| found by importing the `zoneinfo` module and running :code:`zoneinfo.available_timezones()`. |
| |
| Returns |
| ------- |
| str |
| A JSON string that represents the Variant. |
| """ |
| return VariantUtils.to_json(self.value, self.metadata, zone_id) |
| |
| |
| _atomic_types: List[Type[DataType]] = [ |
| StringType, |
| CharType, |
| VarcharType, |
| BinaryType, |
| BooleanType, |
| DecimalType, |
| FloatType, |
| DoubleType, |
| ByteType, |
| ShortType, |
| IntegerType, |
| LongType, |
| DateType, |
| TimestampType, |
| TimestampNTZType, |
| NullType, |
| VariantType, |
| ] |
| _all_atomic_types: Dict[str, Type[DataType]] = dict((t.typeName(), t) for t in _atomic_types) |
| |
| _complex_types: List[Type[Union[ArrayType, MapType, StructType]]] = [ArrayType, MapType, StructType] |
| _all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = dict( |
| (v.typeName(), v) for v in _complex_types |
| ) |
| |
| _COLLATED_STRING = re.compile(r"string\s+collate\s+([\w_]+|`[\w_]+`)") |
| _LENGTH_CHAR = re.compile(r"char\(\s*(\d+)\s*\)") |
| _LENGTH_VARCHAR = re.compile(r"varchar\(\s*(\d+)\s*\)") |
| _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)") |
| _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") |
| _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?") |
| |
| |
| def _drop_metadata(d: Union[DataType, StructField]) -> Union[DataType, StructField]: |
| assert isinstance(d, (DataType, StructField)) |
| if isinstance(d, StructField): |
| return StructField(d.name, _drop_metadata(d.dataType), d.nullable, None) |
| elif isinstance(d, StructType): |
| return StructType([cast(StructField, _drop_metadata(f)) for f in d.fields]) |
| elif isinstance(d, ArrayType): |
| return ArrayType(_drop_metadata(d.elementType), d.containsNull) |
| elif isinstance(d, MapType): |
| return MapType(_drop_metadata(d.keyType), _drop_metadata(d.valueType), d.valueContainsNull) |
| return d |
| |
| |
| def _parse_datatype_string(s: str) -> DataType: |
| """ |
| Parses the given data type string to a :class:`DataType`. The data type string format equals |
| :class:`DataType.simpleString`, except that the top level struct type can omit |
| the ``struct<>``. Since Spark 2.3, this also supports a schema in a DDL-formatted |
| string and case-insensitive strings. |
| |
| Examples |
| -------- |
| >>> _parse_datatype_string("int ") |
| IntegerType() |
| >>> _parse_datatype_string("INT ") |
| IntegerType() |
| >>> _parse_datatype_string("a: byte, b: decimal( 16 , 8 ) ") |
| StructType([StructField('a', ByteType(), True), StructField('b', DecimalType(16,8), True)]) |
| >>> _parse_datatype_string("a DOUBLE, b STRING") |
| StructType([StructField('a', DoubleType(), True), StructField('b', StringType(), True)]) |
| >>> _parse_datatype_string("a DOUBLE, b CHAR( 50 )") |
| StructType([StructField('a', DoubleType(), True), StructField('b', CharType(50), True)]) |
| >>> _parse_datatype_string("a DOUBLE, b VARCHAR( 50 )") |
| StructType([StructField('a', DoubleType(), True), StructField('b', VarcharType(50), True)]) |
| >>> _parse_datatype_string("a: array< short>") |
| StructType([StructField('a', ArrayType(ShortType(), True), True)]) |
| >>> _parse_datatype_string(" map<string , string > ") |
| MapType(StringType(), StringType(), True) |
| |
| >>> # Error cases |
| >>> _parse_datatype_string("blabla") # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| ParseException:... |
| >>> _parse_datatype_string("a: int,") # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| ParseException:... |
| >>> _parse_datatype_string("array<int") # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| ParseException:... |
| >>> _parse_datatype_string("map<int, boolean>>") # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| ParseException:... |
| """ |
| from py4j.java_gateway import JVMView |
| |
| sc = get_active_spark_context() |
| |
| def from_ddl_schema(type_str: str) -> DataType: |
| return _parse_datatype_json_string( |
| cast(JVMView, sc._jvm).org.apache.spark.sql.types.StructType.fromDDL(type_str).json() |
| ) |
| |
| def from_ddl_datatype(type_str: str) -> DataType: |
| return _parse_datatype_json_string( |
| cast(JVMView, sc._jvm) |
| .org.apache.spark.sql.api.python.PythonSQLUtils.parseDataType(type_str) |
| .json() |
| ) |
| |
| try: |
| # DDL format, "fieldname datatype, fieldname datatype". |
| return from_ddl_schema(s) |
| except Exception as e: |
| try: |
| # For backwards compatibility, "integer", "struct<fieldname: datatype>", etc. |
| return from_ddl_datatype(s) |
| except BaseException: |
| try: |
| # For backwards compatibility, "fieldname: datatype, fieldname: datatype" case. |
| return from_ddl_datatype("struct<%s>" % s.strip()) |
| except BaseException: |
| raise e |
| |
| |
| def _parse_datatype_json_string(json_string: str) -> DataType: |
| """Parses the given data type JSON string. |
| |
| Examples |
| -------- |
| >>> import pickle |
| >>> def check_datatype(datatype): |
| ... pickled = pickle.loads(pickle.dumps(datatype)) |
| ... assert datatype == pickled |
| ... scala_datatype = spark._jsparkSession.parseDataType(datatype.json()) |
| ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) |
| ... assert datatype == python_datatype |
| ... |
| >>> for cls in _all_atomic_types.values(): |
| ... if cls is not VarcharType and cls is not CharType: |
| ... check_datatype(cls()) |
| ... else: |
| ... check_datatype(cls(1)) |
| |
| >>> # Simple ArrayType. |
| >>> simple_arraytype = ArrayType(StringType(), True) |
| >>> check_datatype(simple_arraytype) |
| |
| >>> # Simple MapType. |
| >>> simple_maptype = MapType(StringType(), LongType()) |
| >>> check_datatype(simple_maptype) |
| |
| >>> # Simple StructType. |
| >>> simple_structtype = StructType([ |
| ... StructField("a", DecimalType(), False), |
| ... StructField("b", BooleanType(), True), |
| ... StructField("c", LongType(), True), |
| ... StructField("d", BinaryType(), False)]) |
| >>> check_datatype(simple_structtype) |
| |
| >>> # Complex StructType. |
| >>> complex_structtype = StructType([ |
| ... StructField("simpleArray", simple_arraytype, True), |
| ... StructField("simpleMap", simple_maptype, True), |
| ... StructField("simpleStruct", simple_structtype, True), |
| ... StructField("boolean", BooleanType(), False), |
| ... StructField("chars", CharType(10), False), |
| ... StructField("words", VarcharType(10), False), |
| ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) |
| >>> check_datatype(complex_structtype) |
| |
| >>> # Complex ArrayType. |
| >>> complex_arraytype = ArrayType(complex_structtype, True) |
| >>> check_datatype(complex_arraytype) |
| |
| >>> # Complex MapType. |
| >>> complex_maptype = MapType(complex_structtype, |
| ... complex_arraytype, False) |
| >>> check_datatype(complex_maptype) |
| """ |
| return _parse_datatype_json_value(json.loads(json_string)) |
| |
| |
| def _parse_datatype_json_value(json_value: Union[dict, str]) -> DataType: |
| if not isinstance(json_value, dict): |
| if json_value in _all_atomic_types.keys(): |
| return _all_atomic_types[json_value]() |
| elif json_value == "decimal": |
| return DecimalType() |
| elif _FIXED_DECIMAL.match(json_value): |
| m = _FIXED_DECIMAL.match(json_value) |
| return DecimalType(int(m.group(1)), int(m.group(2))) # type: ignore[union-attr] |
| elif _INTERVAL_DAYTIME.match(json_value): |
| m = _INTERVAL_DAYTIME.match(json_value) |
| inverted_fields = DayTimeIntervalType._inverted_fields |
| first_field = inverted_fields.get(m.group(1)) # type: ignore[union-attr] |
| second_field = inverted_fields.get(m.group(3)) # type: ignore[union-attr] |
| if first_field is not None and second_field is None: |
| return DayTimeIntervalType(first_field) |
| return DayTimeIntervalType(first_field, second_field) |
| elif _INTERVAL_YEARMONTH.match(json_value): |
| m = _INTERVAL_YEARMONTH.match(json_value) |
| inverted_fields = YearMonthIntervalType._inverted_fields |
| first_field = inverted_fields.get(m.group(1)) # type: ignore[union-attr] |
| second_field = inverted_fields.get(m.group(3)) # type: ignore[union-attr] |
| if first_field is not None and second_field is None: |
| return YearMonthIntervalType(first_field) |
| return YearMonthIntervalType(first_field, second_field) |
| elif json_value == "interval": |
| return CalendarIntervalType() |
| elif _COLLATED_STRING.match(json_value): |
| m = _COLLATED_STRING.match(json_value) |
| return StringType(m.group(1)) # type: ignore[union-attr] |
| elif _LENGTH_CHAR.match(json_value): |
| m = _LENGTH_CHAR.match(json_value) |
| return CharType(int(m.group(1))) # type: ignore[union-attr] |
| elif _LENGTH_VARCHAR.match(json_value): |
| m = _LENGTH_VARCHAR.match(json_value) |
| return VarcharType(int(m.group(1))) # type: ignore[union-attr] |
| else: |
| raise PySparkValueError( |
| error_class="CANNOT_PARSE_DATATYPE", |
| message_parameters={"error": str(json_value)}, |
| ) |
| else: |
| tpe = json_value["type"] |
| if tpe in _all_complex_types: |
| return _all_complex_types[tpe].fromJson(json_value) |
| elif tpe == "udt": |
| return UserDefinedType.fromJson(json_value) |
| else: |
| raise PySparkValueError( |
| error_class="UNSUPPORTED_DATA_TYPE", |
| message_parameters={"data_type": str(tpe)}, |
| ) |
| |
| |
| # Mapping Python types to Spark SQL DataType |
| _type_mappings = { |
| type(None): NullType, |
| bool: BooleanType, |
| int: LongType, |
| float: DoubleType, |
| str: StringType, |
| bytearray: BinaryType, |
| decimal.Decimal: DecimalType, |
| datetime.date: DateType, |
| datetime.datetime: TimestampType, # can be TimestampNTZType |
| datetime.time: TimestampType, # can be TimestampNTZType |
| datetime.timedelta: DayTimeIntervalType, |
| bytes: BinaryType, |
| } |
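| # For example, ``_infer_type(1)`` resolves to ``LongType()`` and ``_infer_type(b"x")`` |
| # resolves to ``BinaryType()`` through this table (see ``_infer_type`` below). |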
| |
| # Mapping Python array types to Spark SQL DataType |
| # We should be careful here. The size of these types in Python depends on the C |
| # implementation. We need to make sure that this conversion does not lose any |
| # precision. Also, the JVM only supports signed types; when converting unsigned types, |
| # keep in mind that they require 1 more bit when stored as signed types. |
| # |
| # Reference for C integer size, see: |
| # ISO/IEC 9899:201x specification, chapter 5.2.4.2.1 Sizes of integer types <limits.h>. |
| # Reference for python array typecode, see: |
| # https://docs.python.org/2/library/array.html |
| # https://docs.python.org/3.6/library/array.html |
| # Reference for JVM's supported integral types: |
| # http://docs.oracle.com/javase/specs/jvms/se8/html/jvms-2.html#jvms-2.3.1 |
| |
| _array_signed_int_typecode_ctype_mappings = { |
| "b": ctypes.c_byte, |
| "h": ctypes.c_short, |
| "i": ctypes.c_int, |
| "l": ctypes.c_long, |
| } |
| |
| _array_unsigned_int_typecode_ctype_mappings = { |
| "B": ctypes.c_ubyte, |
| "H": ctypes.c_ushort, |
| "I": ctypes.c_uint, |
| "L": ctypes.c_ulong, |
| } |
| |
| |
| def _int_size_to_type( |
| size: int, |
| ) -> Optional[Union[Type[ByteType], Type[ShortType], Type[IntegerType], Type[LongType]]]: |
| """ |
| Return the Catalyst datatype for an integer type of the given size (in bits). |
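| |
| Examples |
| -------- |
| Illustrative size-to-type mappings: |
| |
| >>> _int_size_to_type(8) is ByteType |
| True |
| >>> _int_size_to_type(64) is LongType |
| True |
| >>> _int_size_to_type(128) is None |
| True |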
| """ |
| if size <= 8: |
| return ByteType |
| elif size <= 16: |
| return ShortType |
| elif size <= 32: |
| return IntegerType |
| elif size <= 64: |
| return LongType |
| else: |
| return None |
| |
| |
| # The list of all supported array typecodes is stored here |
| _array_type_mappings: Dict[str, Type[DataType]] = { |
| # Warning: Actual properties of float and double are not specified by the C standard. |
| # On almost every system supported by both Python and the JVM, they are IEEE 754 |
| # single-precision and IEEE 754 double-precision binary floating-point formats, |
| # and we assume the same thing here for now. |
| "f": FloatType, |
| "d": DoubleType, |
| } |
| |
| # compute array typecode mappings for signed integer types |
| for _typecode in _array_signed_int_typecode_ctype_mappings.keys(): |
| size = ctypes.sizeof(_array_signed_int_typecode_ctype_mappings[_typecode]) * 8 |
| dt = _int_size_to_type(size) |
| if dt is not None: |
| _array_type_mappings[_typecode] = dt |
| |
| # compute array typecode mappings for unsigned integer types |
| for _typecode in _array_unsigned_int_typecode_ctype_mappings.keys(): |
| # JVM does not have unsigned types, so use a signed type that is at least 1 |
| # bit larger to store them |
| size = ctypes.sizeof(_array_unsigned_int_typecode_ctype_mappings[_typecode]) * 8 + 1 |
| dt = _int_size_to_type(size) |
| if dt is not None: |
| _array_type_mappings[_typecode] = dt |
| |
# Type code 'u' in Python's array has been deprecated since version 3.3 and will be
# removed in version 4.0. See: https://docs.python.org/3/library/array.html
| if sys.version_info[0] < 4: |
| _array_type_mappings["u"] = StringType |
| |
| |
| def _from_numpy_type(nt: "np.dtype") -> Optional[DataType]: |
| """Convert NumPy type to Spark data type.""" |
| import numpy as np |
| |
| if nt == np.dtype("int8"): |
| return ByteType() |
| elif nt == np.dtype("int16"): |
| return ShortType() |
| elif nt == np.dtype("int32"): |
| return IntegerType() |
| elif nt == np.dtype("int64"): |
| return LongType() |
| elif nt == np.dtype("float32"): |
| return FloatType() |
| elif nt == np.dtype("float64"): |
| return DoubleType() |
| |
| return None |
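

# For illustration only (a minimal sketch; these calls are not executed here):
#
#   _from_numpy_type(np.dtype("int32"))    # -> IntegerType()
#   _from_numpy_type(np.dtype("float64"))  # -> DoubleType()
#   _from_numpy_type(np.dtype("bool"))     # -> None (not handled by this helper)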
| |
| |
| def _infer_type( |
| obj: Any, |
| infer_dict_as_struct: bool = False, |
| infer_array_from_first_element: bool = False, |
| prefer_timestamp_ntz: bool = False, |
| ) -> DataType: |
| """Infer the DataType from obj""" |
| if obj is None: |
| return NullType() |
| |
| if hasattr(obj, "__UDT__"): |
| return obj.__UDT__ |
| |
| dataType = _type_mappings.get(type(obj)) |
| if dataType is DecimalType: |
| # the precision and scale of `obj` may be different from row to row. |
| return DecimalType(38, 18) |
| if dataType is TimestampType and prefer_timestamp_ntz and obj.tzinfo is None: |
| return TimestampNTZType() |
| if dataType is DayTimeIntervalType: |
| return DayTimeIntervalType() |
| if dataType is YearMonthIntervalType: |
| return YearMonthIntervalType() |
| if dataType is CalendarIntervalType: |
| return CalendarIntervalType() |
| elif dataType is not None: |
| return dataType() |
| |
| if isinstance(obj, dict): |
| if infer_dict_as_struct: |
| struct = StructType() |
| for key, value in obj.items(): |
| if key is not None and value is not None: |
| struct.add( |
| key, |
| _infer_type( |
| value, |
| infer_dict_as_struct, |
| infer_array_from_first_element, |
| prefer_timestamp_ntz, |
| ), |
| True, |
| ) |
| return struct |
| else: |
| for key, value in obj.items(): |
| if key is not None and value is not None: |
| return MapType( |
| _infer_type( |
| key, |
| infer_dict_as_struct, |
| infer_array_from_first_element, |
| prefer_timestamp_ntz, |
| ), |
| _infer_type( |
| value, |
| infer_dict_as_struct, |
| infer_array_from_first_element, |
| prefer_timestamp_ntz, |
| ), |
| True, |
| ) |
| return MapType(NullType(), NullType(), True) |
| elif isinstance(obj, list): |
| if len(obj) > 0: |
            if infer_array_from_first_element:
                return ArrayType(
                    _infer_type(
                        obj[0],
                        infer_dict_as_struct,
                        infer_array_from_first_element,
                        prefer_timestamp_ntz,
                    ),
                    True,
                )
            else:
                merged = reduce(
                    _merge_type,
                    (
                        _infer_type(
                            v,
                            infer_dict_as_struct,
                            infer_array_from_first_element,
                            prefer_timestamp_ntz,
                        )
                        for v in obj
                    ),
                )
                return ArrayType(merged, True)
| return ArrayType(NullType(), True) |
| elif isinstance(obj, array): |
| if obj.typecode in _array_type_mappings: |
| return ArrayType(_array_type_mappings[obj.typecode](), False) |
| else: |
| raise PySparkTypeError( |
| error_class="UNSUPPORTED_DATA_TYPE", |
| message_parameters={"data_type": f"array({obj.typecode})"}, |
| ) |
| else: |
| try: |
            return _infer_schema(
                obj,
                infer_dict_as_struct=infer_dict_as_struct,
                infer_array_from_first_element=infer_array_from_first_element,
                prefer_timestamp_ntz=prefer_timestamp_ntz,
            )
| except TypeError: |
| raise PySparkTypeError( |
| error_class="UNSUPPORTED_DATA_TYPE", |
| message_parameters={"data_type": type(obj).__name__}, |
| ) |
| |
| |
| def _infer_schema( |
| row: Any, |
| names: Optional[List[str]] = None, |
| infer_dict_as_struct: bool = False, |
| infer_array_from_first_element: bool = False, |
| prefer_timestamp_ntz: bool = False, |
| ) -> StructType: |
| """Infer the schema from dict/namedtuple/object""" |
| items: Iterable[Tuple[str, Any]] |
| if isinstance(row, dict): |
| items = sorted(row.items()) |
| |
| elif isinstance(row, (tuple, list)): |
| if hasattr(row, "__fields__"): # Row |
| items = zip(row.__fields__, tuple(row)) |
| elif hasattr(row, "_fields"): # namedtuple |
| items = zip(row._fields, tuple(row)) |
| else: |
| if names is None: |
| names = ["_%d" % i for i in range(1, len(row) + 1)] |
| elif len(names) < len(row): |
| names.extend("_%d" % i for i in range(len(names) + 1, len(row) + 1)) |
| items = zip(names, row) |
| |
| elif hasattr(row, "__dict__"): # object |
| items = sorted(row.__dict__.items()) |
| |
| else: |
| raise PySparkTypeError( |
| error_class="CANNOT_INFER_SCHEMA_FOR_TYPE", |
| message_parameters={"data_type": type(row).__name__}, |
| ) |
| |
| fields = [] |
| for k, v in items: |
| try: |
| fields.append( |
| StructField( |
| k, |
| _infer_type( |
| v, |
| infer_dict_as_struct, |
| infer_array_from_first_element, |
| prefer_timestamp_ntz, |
| ), |
| True, |
| ) |
| ) |
| except TypeError: |
| raise PySparkTypeError( |
| error_class="CANNOT_INFER_TYPE_FOR_FIELD", |
| message_parameters={"field_name": k}, |
| ) |
| return StructType(fields) |
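

# For illustration only (a minimal sketch; not executed here). Dict keys are
# sorted before fields are created, while tuples/lists without field names fall
# back to positional "_1", "_2", ... names:
#
#   _infer_schema({"name": "Alice", "age": 1})
#       # -> StructType([StructField("age", LongType(), True),
#       #                StructField("name", StringType(), True)])
#   _infer_schema((1, "x"))
#       # -> StructType([StructField("_1", LongType(), True),
#       #                StructField("_2", StringType(), True)])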
| |
| |
| def _has_nulltype(dt: DataType) -> bool: |
| """Return whether there is a NullType in `dt` or not""" |
| if isinstance(dt, StructType): |
| return any(_has_nulltype(f.dataType) for f in dt.fields) |
| elif isinstance(dt, ArrayType): |
        return _has_nulltype(dt.elementType)
| elif isinstance(dt, MapType): |
| return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) |
| else: |
| return isinstance(dt, NullType) |
| |
| |
| def _has_type(dt: DataType, dts: Union[type, Tuple[type, ...]]) -> bool: |
| """Return whether there are specified types""" |
| if isinstance(dt, dts): |
| return True |
| elif isinstance(dt, StructType): |
| return any(_has_type(f.dataType, dts) for f in dt.fields) |
| elif isinstance(dt, ArrayType): |
| return _has_type(dt.elementType, dts) |
| elif isinstance(dt, MapType): |
| return _has_type(dt.keyType, dts) or _has_type(dt.valueType, dts) |
| else: |
| return False |
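

# For illustration only (a minimal sketch; not executed here):
#
#   _has_nulltype(ArrayType(NullType()))                    # -> True
#   _has_nulltype(MapType(StringType(), LongType()))        # -> False
#   _has_type(MapType(StringType(), LongType()), LongType)  # -> True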
| |
| |
| @overload |
| def _merge_type(a: StructType, b: StructType, name: Optional[str] = None) -> StructType: |
| ... |
| |
| |
| @overload |
| def _merge_type(a: ArrayType, b: ArrayType, name: Optional[str] = None) -> ArrayType: |
| ... |
| |
| |
| @overload |
| def _merge_type(a: MapType, b: MapType, name: Optional[str] = None) -> MapType: |
| ... |
| |
| |
| @overload |
| def _merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType: |
| ... |
| |
| |
| def _merge_type( |
| a: Union[StructType, ArrayType, MapType, DataType], |
| b: Union[StructType, ArrayType, MapType, DataType], |
| name: Optional[str] = None, |
| ) -> Union[StructType, ArrayType, MapType, DataType]: |
| if name is None: |
| |
| def new_msg(msg: str) -> str: |
| return msg |
| |
| def new_name(n: str) -> str: |
| return "field %s" % n |
| |
| else: |
| |
| def new_msg(msg: str) -> str: |
| return "%s: %s" % (name, msg) |
| |
| def new_name(n: str) -> str: |
| return "field %s in %s" % (n, name) |
| |
| if isinstance(a, NullType): |
| return b |
| elif isinstance(b, NullType): |
| return a |
| elif isinstance(a, TimestampType) and isinstance(b, TimestampNTZType): |
| return a |
| elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType): |
| return b |
| elif isinstance(a, AtomicType) and isinstance(b, StringType): |
| return b |
| elif isinstance(a, StringType) and isinstance(b, AtomicType): |
| return a |
| elif type(a) is not type(b): |
| # TODO: type cast (such as int -> long) |
| raise PySparkTypeError( |
| error_class="CANNOT_MERGE_TYPE", |
| message_parameters={"data_type1": type(a).__name__, "data_type2": type(b).__name__}, |
| ) |
| |
| # same type |
| if isinstance(a, StructType): |
| nfs = dict((f.name, f.dataType) for f in cast(StructType, b).fields) |
| fields = [ |
| StructField( |
| f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()), name=new_name(f.name)) |
| ) |
| for f in a.fields |
| ] |
| names = set([f.name for f in fields]) |
| for n in nfs: |
| if n not in names: |
| fields.append(StructField(n, nfs[n])) |
| return StructType(fields) |
| |
| elif isinstance(a, ArrayType): |
| return ArrayType( |
| _merge_type( |
| a.elementType, cast(ArrayType, b).elementType, name="element in array %s" % name |
| ), |
| True, |
| ) |
| |
| elif isinstance(a, MapType): |
| return MapType( |
| _merge_type(a.keyType, cast(MapType, b).keyType, name="key of map %s" % name), |
| _merge_type(a.valueType, cast(MapType, b).valueType, name="value of map %s" % name), |
| True, |
| ) |
| else: |
| return a |
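

# For illustration only (a minimal sketch of the merge rules above; these calls
# are not executed here):
#
#   _merge_type(LongType(), NullType())    # -> LongType()
#   _merge_type(LongType(), StringType())  # -> StringType()
#   _merge_type(LongType(), DoubleType())  # raises PySparkTypeError (CANNOT_MERGE_TYPE)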
| |
| |
| def _need_converter(dataType: DataType) -> bool: |
| if isinstance(dataType, StructType): |
| return True |
| elif isinstance(dataType, ArrayType): |
| return _need_converter(dataType.elementType) |
| elif isinstance(dataType, MapType): |
| return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) |
| elif isinstance(dataType, NullType): |
| return True |
| else: |
| return False |
| |
| |
| def _create_converter(dataType: DataType) -> Callable: |
| """Create a converter to drop the names of fields in obj""" |
| if not _need_converter(dataType): |
| return lambda x: x |
| |
| if isinstance(dataType, ArrayType): |
| conv = _create_converter(dataType.elementType) |
| return lambda row: [conv(v) for v in row] |
| |
| elif isinstance(dataType, MapType): |
| kconv = _create_converter(dataType.keyType) |
| vconv = _create_converter(dataType.valueType) |
| return lambda row: dict((kconv(k), vconv(v)) for k, v in row.items()) |
| |
| elif isinstance(dataType, NullType): |
| return lambda x: None |
| |
| elif not isinstance(dataType, StructType): |
| return lambda x: x |
| |
| # dataType must be StructType |
| names = [f.name for f in dataType.fields] |
| converters = [_create_converter(f.dataType) for f in dataType.fields] |
| convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) |
| |
| def convert_struct(obj: Any) -> Optional[Tuple]: |
| if obj is None: |
| return None |
| |
| if isinstance(obj, (tuple, list)): |
| if convert_fields: |
| return tuple(conv(v) for v, conv in zip(obj, converters)) |
| else: |
| return tuple(obj) |
| |
| if isinstance(obj, dict): |
| d = obj |
| elif hasattr(obj, "__dict__"): # object |
| d = obj.__dict__ |
| else: |
| raise PySparkTypeError( |
| error_class="UNSUPPORTED_DATA_TYPE", |
| message_parameters={"data_type": type(obj).__name__}, |
| ) |
| |
| if convert_fields: |
| return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) |
| else: |
| return tuple([d.get(name) for name in names]) |
| |
| return convert_struct |
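

# For illustration only (a minimal sketch; not executed here). Given a struct
# schema, the converter turns dict-like rows into positional tuples ordered by
# the schema's field names:
#
#   schema = StructType([StructField("a", LongType()), StructField("b", StringType())])
#   conv = _create_converter(schema)
#   conv({"b": "x", "a": 1})  # -> (1, 'x')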
| |
| |
| _acceptable_types = { |
| BooleanType: (bool,), |
| ByteType: (int,), |
| ShortType: (int,), |
| IntegerType: (int,), |
| LongType: (int,), |
| FloatType: (float,), |
| DoubleType: (float,), |
| DecimalType: (decimal.Decimal,), |
| StringType: (str,), |
| CharType: (str,), |
| VarcharType: (str,), |
| BinaryType: (bytearray, bytes), |
| DateType: (datetime.date, datetime.datetime), |
| TimestampType: (datetime.datetime,), |
| TimestampNTZType: (datetime.datetime,), |
| DayTimeIntervalType: (datetime.timedelta,), |
| ArrayType: (list, tuple, array), |
| MapType: (dict,), |
| StructType: (tuple, list, dict), |
| VariantType: ( |
| bool, |
| int, |
| float, |
| decimal.Decimal, |
| str, |
| bytearray, |
| bytes, |
| datetime.date, |
| datetime.datetime, |
| datetime.timedelta, |
| tuple, |
| list, |
| dict, |
| array, |
| ), |
| } |
| |
| |
| def _make_type_verifier( |
| dataType: DataType, |
| nullable: bool = True, |
| name: Optional[str] = None, |
| ) -> Callable: |
| """ |
| Make a verifier that checks the type of obj against dataType and raises a TypeError if they do |
| not match. |
| |
    This verifier also checks the value of obj against the datatype and raises a ValueError if it
    is not within the allowed range, e.g. using 128 as ByteType will overflow. Note that Python
    float is not checked, so it will become infinity when cast to a Java float if it overflows.
| |
| Examples |
| -------- |
| >>> _make_type_verifier(StructType([]))(None) |
| >>> _make_type_verifier(StringType())("") |
| >>> _make_type_verifier(LongType())(0) |
| >>> _make_type_verifier(LongType())(1 << 64) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| >>> _make_type_verifier(ArrayType(ShortType()))(list(range(3))) |
| >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkTypeError:... |
| >>> _make_type_verifier(MapType(StringType(), IntegerType()))({}) |
| >>> _make_type_verifier(StructType([]))(()) |
| >>> _make_type_verifier(StructType([]))([]) |
| >>> _make_type_verifier(StructType([]))([1]) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| >>> # Check if numeric values are within the allowed range. |
| >>> _make_type_verifier(ByteType())(12) |
| >>> _make_type_verifier(ByteType())(1234) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| >>> _make_type_verifier(ByteType(), False)(None) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| >>> _make_type_verifier( |
| ... ArrayType(ShortType(), False))([1, None]) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| >>> _make_type_verifier( # doctest: +IGNORE_EXCEPTION_DETAIL |
| ... MapType(StringType(), IntegerType()) |
| ... )({None: 1}) |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) |
| >>> _make_type_verifier(schema)((1, None)) # doctest: +IGNORE_EXCEPTION_DETAIL |
| Traceback (most recent call last): |
| ... |
| pyspark.errors.exceptions.base.PySparkValueError:... |
| """ |
| |
| if name is None: |
| |
| def new_msg(msg: str) -> str: |
| return msg |
| |
| def new_name(n: str) -> str: |
| return "field %s" % n |
| |
| else: |
| |
| def new_msg(msg: str) -> str: |
| return "%s: %s" % (name, msg) |
| |
| def new_name(n: str) -> str: |
| return "field %s in %s" % (n, name) |
| |
| def verify_nullability(obj: Any) -> bool: |
| if obj is None: |
| if nullable: |
| return True |
| else: |
| if name is not None: |
| raise PySparkValueError( |
| error_class="FIELD_NOT_NULLABLE_WITH_NAME", |
| message_parameters={ |
| "field_name": str(name), |
| }, |
| ) |
| raise PySparkValueError( |
| error_class="FIELD_NOT_NULLABLE", |
| message_parameters={}, |
| ) |
| else: |
| return False |
| |
| _type = type(dataType) |
| |
| def assert_acceptable_types(obj: Any) -> None: |
| assert _type in _acceptable_types, new_msg( |
| "unknown datatype: %s for object %r" % (dataType, obj) |
| ) |
| |
| def verify_acceptable_types(obj: Any) -> None: |
        # subclasses of the acceptable types cannot be handled by fromInternal on the JVM side
| if type(obj) not in _acceptable_types[_type]: |
| if name is not None: |
| raise PySparkTypeError( |
| error_class="FIELD_DATA_TYPE_UNACCEPTABLE_WITH_NAME", |
| message_parameters={ |
| "field_name": str(name), |
| "data_type": str(dataType), |
| "obj": repr(obj), |
| "obj_type": str(type(obj)), |
| }, |
| ) |
| raise PySparkTypeError( |
| error_class="FIELD_DATA_TYPE_UNACCEPTABLE", |
| message_parameters={ |
| "data_type": str(dataType), |
| "obj": repr(obj), |
| "obj_type": str(type(obj)), |
| }, |
| ) |
| |
| if isinstance(dataType, (StringType, CharType, VarcharType)): |
        # StringType, CharType and VarcharType can accept values of any type
| def verify_value(obj: Any) -> None: |
| pass |
| |
| elif isinstance(dataType, UserDefinedType): |
| verifier = _make_type_verifier(dataType.sqlType(), name=name) |
| |
| def verify_udf(obj: Any) -> None: |
| if not (hasattr(obj, "__UDT__") and obj.__UDT__ == dataType): |
| if name is not None: |
| raise PySparkValueError( |
| error_class="FIELD_TYPE_MISMATCH_WITH_NAME", |
| message_parameters={ |
| "field_name": str(name), |
| "obj": str(obj), |
| "data_type": str(dataType), |
| }, |
| ) |
| raise PySparkValueError( |
| error_class="FIELD_TYPE_MISMATCH", |
| message_parameters={ |
| "obj": str(obj), |
| "data_type": str(dataType), |
| }, |
| ) |
| verifier(dataType.toInternal(obj)) |
| |
| verify_value = verify_udf |
| |
| elif isinstance(dataType, ByteType): |
| |
| def verify_byte(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| lower_bound = -128 |
| upper_bound = 127 |
| if obj < lower_bound or obj > upper_bound: |
| raise PySparkValueError( |
| error_class="VALUE_OUT_OF_BOUNDS", |
| message_parameters={ |
| "arg_name": "obj", |
| "lower_bound": str(lower_bound), |
| "upper_bound": str(upper_bound), |
| "actual": str(obj), |
| }, |
| ) |
| |
| verify_value = verify_byte |
| |
| elif isinstance(dataType, ShortType): |
| |
| def verify_short(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| lower_bound = -32768 |
| upper_bound = 32767 |
| if obj < lower_bound or obj > upper_bound: |
| raise PySparkValueError( |
| error_class="VALUE_OUT_OF_BOUNDS", |
| message_parameters={ |
| "arg_name": "obj", |
| "lower_bound": str(lower_bound), |
| "upper_bound": str(upper_bound), |
| "actual": str(obj), |
| }, |
| ) |
| |
| verify_value = verify_short |
| |
| elif isinstance(dataType, IntegerType): |
| |
| def verify_integer(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| lower_bound = -2147483648 |
| upper_bound = 2147483647 |
| if obj < lower_bound or obj > upper_bound: |
| raise PySparkValueError( |
| error_class="VALUE_OUT_OF_BOUNDS", |
| message_parameters={ |
| "arg_name": "obj", |
| "lower_bound": str(lower_bound), |
| "upper_bound": str(upper_bound), |
| "actual": str(obj), |
| }, |
| ) |
| |
| verify_value = verify_integer |
| |
| elif isinstance(dataType, LongType): |
| |
| def verify_long(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| lower_bound = -9223372036854775808 |
| upper_bound = 9223372036854775807 |
| if obj < lower_bound or obj > upper_bound: |
| raise PySparkValueError( |
| error_class="VALUE_OUT_OF_BOUNDS", |
| message_parameters={ |
| "arg_name": "obj", |
| "lower_bound": str(lower_bound), |
| "upper_bound": str(upper_bound), |
| "actual": str(obj), |
| }, |
| ) |
| |
| verify_value = verify_long |
| |
| elif isinstance(dataType, ArrayType): |
| element_verifier = _make_type_verifier( |
| dataType.elementType, dataType.containsNull, name="element in array %s" % name |
| ) |
| |
| def verify_array(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| for i in obj: |
| element_verifier(i) |
| |
| verify_value = verify_array |
| |
| elif isinstance(dataType, MapType): |
| key_verifier = _make_type_verifier(dataType.keyType, False, name="key of map %s" % name) |
| value_verifier = _make_type_verifier( |
| dataType.valueType, dataType.valueContainsNull, name="value of map %s" % name |
| ) |
| |
| def verify_map(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| for k, v in obj.items(): |
| key_verifier(k) |
| value_verifier(v) |
| |
| verify_value = verify_map |
| |
| elif isinstance(dataType, StructType): |
| verifiers = [] |
| for f in dataType.fields: |
| verifier = _make_type_verifier(f.dataType, f.nullable, name=new_name(f.name)) |
| verifiers.append((f.name, verifier)) |
| |
| def verify_struct(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| |
| if isinstance(obj, dict): |
| for f, verifier in verifiers: |
| verifier(obj.get(f)) |
| elif isinstance(obj, (tuple, list)): |
| if len(obj) != len(verifiers): |
| if name is not None: |
| raise PySparkValueError( |
| error_class="FIELD_STRUCT_LENGTH_MISMATCH_WITH_NAME", |
| message_parameters={ |
| "field_name": str(name), |
| "object_length": str(len(obj)), |
| "field_length": str(len(verifiers)), |
| }, |
| ) |
| raise PySparkValueError( |
| error_class="FIELD_STRUCT_LENGTH_MISMATCH", |
| message_parameters={ |
| "object_length": str(len(obj)), |
| "field_length": str(len(verifiers)), |
| }, |
| ) |
| for v, (_, verifier) in zip(obj, verifiers): |
| verifier(v) |
| elif hasattr(obj, "__dict__"): |
| d = obj.__dict__ |
| for f, verifier in verifiers: |
| verifier(d.get(f)) |
| else: |
| if name is not None: |
| raise PySparkTypeError( |
| error_class="FIELD_DATA_TYPE_UNACCEPTABLE_WITH_NAME", |
| message_parameters={ |
| "field_name": str(name), |
| "data_type": str(dataType), |
| "obj": repr(obj), |
| "obj_type": str(type(obj)), |
| }, |
| ) |
| raise PySparkTypeError( |
| error_class="FIELD_DATA_TYPE_UNACCEPTABLE", |
| message_parameters={ |
| "data_type": str(dataType), |
| "obj": repr(obj), |
| "obj_type": str(type(obj)), |
| }, |
| ) |
| |
| verify_value = verify_struct |
| |
| elif isinstance(dataType, VariantType): |
| |
| def verify_variant(obj: Any) -> None: |
| # The variant data type can take in any type. |
| pass |
| |
| verify_value = verify_variant |
| |
| else: |
| |
| def verify_default(obj: Any) -> None: |
| assert_acceptable_types(obj) |
| verify_acceptable_types(obj) |
| |
| verify_value = verify_default |
| |
| def verify(obj: Any) -> None: |
| if not verify_nullability(obj): |
| verify_value(obj) |
| |
| return verify |
| |
| |
# This is used to unpickle a Row from the JVM
| def _create_row_inbound_converter(dataType: DataType) -> Callable: |
| return lambda *a: dataType.fromInternal(a) |
| |
| |
| def _create_row( |
| fields: Union["Row", List[str]], values: Union[Tuple[Any, ...], List[Any]] |
| ) -> "Row": |
| row = Row(*values) |
| row.__fields__ = fields |
| return row |
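

# For illustration only (a minimal sketch; not executed here):
#
#   _create_row(["name", "age"], ("Alice", 11))  # -> Row(name='Alice', age=11)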
| |
| |
| class Row(tuple): |
| |
| """ |
| A row in :class:`DataFrame`. |
| The fields in it can be accessed: |
| |
| * like attributes (``row.key``) |
| * like dictionary values (``row[key]``) |
| |
| ``key in row`` will search through row keys. |
| |
    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to indicate that the value is
    None or missing; the value should be explicitly set to None in that case.

    .. versionchanged:: 3.0.0
        Rows created from named arguments no longer have
        field names sorted alphabetically; instead, they are ordered in the
        position as entered.
| |
| Examples |
| -------- |
| >>> from pyspark.sql import Row |
| >>> row = Row(name="Alice", age=11) |
| >>> row |
| Row(name='Alice', age=11) |
| >>> row['name'], row['age'] |
| ('Alice', 11) |
| >>> row.name, row.age |
| ('Alice', 11) |
| >>> 'name' in row |
| True |
| >>> 'wrong_key' in row |
| False |
| |
    Row can also be used to create another Row-like class, which can then
    be used to create Row objects, such as
| |
| >>> Person = Row("name", "age") |
| >>> Person |
| <Row('name', 'age')> |
| >>> 'name' in Person |
| True |
| >>> 'wrong_key' in Person |
| False |
| >>> Person("Alice", 11) |
| Row(name='Alice', age=11) |
| |
| This form can also be used to create rows as tuple values, i.e. with unnamed |
| fields. |
| |
| >>> row1 = Row("Alice", 11) |
| >>> row2 = Row(name="Alice", age=11) |
| >>> row1 == row2 |
| True |
| """ |
| |
| @overload |
| def __new__(cls, *args: str) -> "Row": |
| ... |
| |
| @overload |
| def __new__(cls, **kwargs: Any) -> "Row": |
| ... |
| |
| def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": |
| if args and kwargs: |
| raise PySparkValueError( |
| error_class="CANNOT_SET_TOGETHER", |
| message_parameters={"arg_list": "args and kwargs"}, |
| ) |
| if kwargs: |
| # create row objects |
| row = tuple.__new__(cls, list(kwargs.values())) |
| row.__fields__ = list(kwargs.keys()) |
| return row |
| else: |
| # create row class or objects |
| return tuple.__new__(cls, args) |
| |
| def asDict(self, recursive: bool = False) -> Dict[str, Any]: |
| """ |
| Return as a dict |
| |
| Parameters |
| ---------- |
| recursive : bool, optional |
            turns the nested Rows into dicts (default: False).
| |
| Notes |
| ----- |
        If a row contains duplicate field names, e.g., the rows of a join
        between two :class:`DataFrame` that both have fields of the same names,
        one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
        will also return one of the duplicate fields, but the returned value might
        be different from that of ``asDict``.
| |
| Examples |
| -------- |
| >>> from pyspark.sql import Row |
| >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11} |
| True |
| >>> row = Row(key=1, value=Row(name='a', age=2)) |
| >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)} |
| True |
| >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}} |
| True |
| """ |
| if not hasattr(self, "__fields__"): |
| raise PySparkTypeError( |
| error_class="CANNOT_CONVERT_TYPE", |
| message_parameters={ |
| "from_type": "Row", |
| "to_type": "dict", |
| }, |
| ) |
| |
| if recursive: |
| |
| def conv(obj: Any) -> Any: |
| if isinstance(obj, Row): |
| return obj.asDict(True) |
| elif isinstance(obj, list): |
| return [conv(o) for o in obj] |
| elif isinstance(obj, dict): |
| return dict((k, conv(v)) for k, v in obj.items()) |
| else: |
| return obj |
| |
| return dict(zip(self.__fields__, (conv(o) for o in self))) |
| else: |
| return dict(zip(self.__fields__, self)) |
| |
| def __contains__(self, item: Any) -> bool: |
| if hasattr(self, "__fields__"): |
| return item in self.__fields__ |
| else: |
| return super(Row, self).__contains__(item) |
| |
    # let the object act like a class
| def __call__(self, *args: Any) -> "Row": |
| """create new Row object""" |
| if len(args) > len(self): |
| raise PySparkValueError( |
| error_class="TOO_MANY_VALUES", |
| message_parameters={ |
| "expected": str(len(self)), |
| "item": "fields", |
| "actual": str(len(args)), |
| }, |
| ) |
| return _create_row(self, args) |
| |
| def __getitem__(self, item: Any) -> Any: |
| if isinstance(item, (int, slice)): |
| return super(Row, self).__getitem__(item) |
| try: |
| # it will be slow when it has many fields, |
| # but this will not be used in normal cases |
| idx = self.__fields__.index(item) |
| return super(Row, self).__getitem__(idx) |
| except IndexError: |
| raise PySparkKeyError( |
| error_class="KEY_NOT_EXISTS", message_parameters={"key": str(item)} |
| ) |
| except ValueError: |
| raise PySparkValueError(item) |
| |
| def __getattr__(self, item: str) -> Any: |
| if item.startswith("__"): |
| raise PySparkAttributeError( |
| error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": item} |
| ) |
| try: |
| # it will be slow when it has many fields, |
| # but this will not be used in normal cases |
| idx = self.__fields__.index(item) |
| return self[idx] |
| except IndexError: |
| raise PySparkAttributeError( |
| error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": item} |
| ) |
| except ValueError: |
| raise PySparkAttributeError( |
| error_class="ATTRIBUTE_NOT_SUPPORTED", message_parameters={"attr_name": item} |
| ) |
| |
| def __setattr__(self, key: Any, value: Any) -> None: |
| if key != "__fields__": |
| raise PySparkRuntimeError( |
| error_class="READ_ONLY", |
| message_parameters={"object": "Row"}, |
| ) |
| self.__dict__[key] = value |
| |
| def __reduce__( |
| self, |
| ) -> Union[str, Tuple[Any, ...]]: |
| """Returns a tuple so Python knows how to pickle Row.""" |
| if hasattr(self, "__fields__"): |
| return (_create_row, (self.__fields__, tuple(self))) |
| else: |
| return tuple.__reduce__(self) |
| |
| def __repr__(self) -> str: |
| """Printable representation of Row used in Python REPL.""" |
| if hasattr(self, "__fields__"): |
| return "Row(%s)" % ", ".join( |
| "%s=%r" % (k, v) for k, v in zip(self.__fields__, tuple(self)) |
| ) |
| else: |
| return "<Row(%s)>" % ", ".join(repr(field) for field in self) |
| |
| |
| class DateConverter: |
| def can_convert(self, obj: Any) -> bool: |
| return isinstance(obj, datetime.date) |
| |
| def convert(self, obj: datetime.date, gateway_client: "GatewayClient") -> "JavaGateway": |
| from py4j.java_gateway import JavaClass |
| |
| Date = JavaClass("java.sql.Date", gateway_client) |
| return Date.valueOf(obj.strftime("%Y-%m-%d")) |
| |
| |
| class DatetimeConverter: |
| def can_convert(self, obj: Any) -> bool: |
| return isinstance(obj, datetime.datetime) |
| |
| def convert(self, obj: datetime.datetime, gateway_client: "GatewayClient") -> "JavaGateway": |
| from py4j.java_gateway import JavaClass |
| |
| Timestamp = JavaClass("java.sql.Timestamp", gateway_client) |
| seconds = ( |
| calendar.timegm(obj.utctimetuple()) if obj.tzinfo else time.mktime(obj.timetuple()) |
| ) |
| t = Timestamp(int(seconds) * 1000) |
| t.setNanos(obj.microsecond * 1000) |
| return t |
| |
| |
| class DatetimeNTZConverter: |
| def can_convert(self, obj: Any) -> bool: |
| from pyspark.sql.utils import is_timestamp_ntz_preferred |
| |
| return ( |
| isinstance(obj, datetime.datetime) |
| and obj.tzinfo is None |
| and is_timestamp_ntz_preferred() |
| ) |
| |
| def convert(self, obj: datetime.datetime, gateway_client: "GatewayClient") -> "JavaGateway": |
| from py4j.java_gateway import JavaClass |
| |
| seconds = calendar.timegm(obj.utctimetuple()) |
| DateTimeUtils = JavaClass( |
| "org.apache.spark.sql.catalyst.util.DateTimeUtils", |
| gateway_client, |
| ) |
| return DateTimeUtils.microsToLocalDateTime(int(seconds) * 1000000 + obj.microsecond) |
| |
| |
| class DayTimeIntervalTypeConverter: |
| def can_convert(self, obj: Any) -> bool: |
| return isinstance(obj, datetime.timedelta) |
| |
| def convert(self, obj: datetime.timedelta, gateway_client: "GatewayClient") -> "JavaGateway": |
| from py4j.java_gateway import JavaClass |
| |
| IntervalUtils = JavaClass( |
| "org.apache.spark.sql.catalyst.util.IntervalUtils", |
| gateway_client, |
| ) |
| return IntervalUtils.microsToDuration( |
| (math.floor(obj.total_seconds()) * 1000000) + obj.microseconds |
| ) |
| |
| |
| class NumpyScalarConverter: |
| def can_convert(self, obj: Any) -> bool: |
| return has_numpy and isinstance(obj, np.generic) |
| |
| def convert(self, obj: "np.generic", gateway_client: "GatewayClient") -> Any: |
| return obj.item() |
| |
| |
| class NumpyArrayConverter: |
| def _from_numpy_type_to_java_type( |
| self, nt: "np.dtype", gateway: "JavaGateway" |
| ) -> Optional["JavaClass"]: |
| """Convert NumPy type to Py4J Java type.""" |
| if nt in [np.dtype("int8"), np.dtype("int16")]: |
| # Mapping int8 to gateway.jvm.byte causes |
| # TypeError: 'bytes' object does not support item assignment |
| return gateway.jvm.short |
| elif nt == np.dtype("int32"): |
| return gateway.jvm.int |
| elif nt == np.dtype("int64"): |
| return gateway.jvm.long |
| elif nt == np.dtype("float32"): |
| return gateway.jvm.float |
| elif nt == np.dtype("float64"): |
| return gateway.jvm.double |
| elif nt == np.dtype("bool"): |
| return gateway.jvm.boolean |
| |
| return None |
| |
| def can_convert(self, obj: Any) -> bool: |
| return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1 |
| |
| def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGateway": |
| from pyspark import SparkContext |
| |
| gateway = SparkContext._gateway |
| assert gateway is not None |
| plist = obj.tolist() |
| |
| if len(obj) > 0 and isinstance(plist[0], str): |
| jtpe = gateway.jvm.String |
| else: |
| jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway) |
| if jtpe is None: |
| raise PySparkTypeError( |
| error_class="UNSUPPORTED_NUMPY_ARRAY_SCALAR", |
| message_parameters={"dtype": str(obj.dtype)}, |
| ) |
| jarr = gateway.new_array(jtpe, len(obj)) |
| for i in range(len(plist)): |
| jarr[i] = plist[i] |
| return jarr |
| |
| |
| if not is_remote_only(): |
| from py4j.protocol import register_input_converter |
| |
    # datetime is a subclass of date, so we should register DatetimeConverter first
| register_input_converter(DatetimeNTZConverter()) |
| register_input_converter(DatetimeConverter()) |
| register_input_converter(DateConverter()) |
| register_input_converter(DayTimeIntervalTypeConverter()) |
| register_input_converter(NumpyScalarConverter()) |
| # NumPy array satisfies py4j.java_collections.ListConverter, |
| # so prepend NumpyArrayConverter |
| register_input_converter(NumpyArrayConverter(), prepend=True) |
| |
| |
| def _test() -> None: |
| import doctest |
| from pyspark.sql import SparkSession |
| |
| globs = globals() |
| globs["spark"] = SparkSession.builder.getOrCreate() |
| (failure_count, test_count) = doctest.testmod( |
| globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE |
| ) |
| if failure_count: |
| sys.exit(-1) |
| |
| |
| if __name__ == "__main__": |
| _test() |