#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import decimal
import datetime
import json
import struct
from array import array
from decimal import Decimal
from typing import Any, Callable, Dict, List, Tuple
from pyspark.errors import (
PySparkNotImplementedError,
PySparkValueError,
)
from zoneinfo import ZoneInfo
class VariantUtils:
"""
A utility class for VariantVal.
Adapted from library at: org.apache.spark.types.variant.VariantUtil
"""
BASIC_TYPE_BITS = 2
BASIC_TYPE_MASK = 0x3
TYPE_INFO_MASK = 0x3F
# The inclusive maximum value of the type info value. It is the size limit of `SHORT_STR`.
MAX_SHORT_STR_SIZE = 0x3F
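# For example, a header byte of 0x0d decodes to basic_type = 0x0d & 0x3 = 1 (SHORT_STR) and
# type_info = (0x0d >> 2) & 0x3F = 3, i.e. a short string of length 3.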
# Below are all the possible basic type values.
# Primitive value. The type info value must be one of the values in the below section.
PRIMITIVE = 0
# Short string value. The type info value is the string size, which must be in `[0,
# MAX_SHORT_STR_SIZE]`.
# The string content bytes directly follow the header byte.
SHORT_STR = 1
# Object value. The content contains a size, a list of field ids, a list of field offsets, and
# the actual field data. The length of the id list is `size`, while the length of the offset
# list is `size + 1`, where the last offset represents the total size of the field data. The
# fields in an object must be sorted by the field name in alphabetical order. Duplicate field
# names in one object are not allowed.
# We use 5 bits in the type info to specify the integer type of the object header: it should
# be 0_b4_b3b2_b1b0 (MSB is 0), where:
# - b4 specifies the type of size. When it is 0/1, `size` is a little-endian 1/4-byte
# unsigned integer.
# - b3b2/b1b0 specifies the integer type of id and offset. When the 2 bits are 0/1/2, the
# list contains 1/2/3-byte little-endian unsigned integers.
OBJECT = 2
# Array value. The content contains a size, a list of field offsets, and the actual element
# data. It is similar to an object without the id list. The length of the offset list
# is `size + 1`, where the last offset represents the total size of the element data.
# Its type info should be: 000_b2_b1b0:
# - b2 specifies the type of size.
# - b1b0 specifies the integer type of offset.
ARRAY = 3
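# For example, the array [1, 2] (two INT1 elements) can be encoded as
# b"\x03\x02\x00\x02\x04\x0c\x01\x0c\x02":
#   0x03           -> ARRAY header (1-byte size, 1-byte offsets)
#   0x02           -> size = 2
#   0x00 0x02 0x04 -> offsets (the last one is the total size of the element data)
#   0x0c 0x01      -> INT1 value 1
#   0x0c 0x02      -> INT1 value 2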
# Below are all the possible type info values for `PRIMITIVE`.
# JSON Null value. Empty content.
NULL = 0
# True value. Empty content.
TRUE = 1
# False value. Empty content.
FALSE = 2
# 1-byte little-endian signed integer.
INT1 = 3
# 2-byte little-endian signed integer.
INT2 = 4
# 4-byte little-endian signed integer.
INT4 = 5
# 8-byte little-endian signed integer.
INT8 = 6
# 8-byte IEEE double.
DOUBLE = 7
# 4-byte decimal. Content is 1-byte scale + 4-byte little-endian signed integer.
DECIMAL4 = 8
# 8-byte decimal. Content is 1-byte scale + 8-byte little-endian signed integer.
DECIMAL8 = 9
# 16-byte decimal. Content is 1-byte scale + 16-byte little-endian signed integer.
DECIMAL16 = 10
# Date value. Content is 4-byte little-endian signed integer that represents the number of days
# from the Unix epoch.
DATE = 11
# Timestamp value. Content is 8-byte little-endian signed integer that represents the number of
# microseconds elapsed since the Unix epoch, 1970-01-01 00:00:00 UTC. This is a timezone-aware
# field and when reading into a Python datetime object defaults to the UTC timezone.
TIMESTAMP = 12
# Timestamp_ntz value. It has the same content as `TIMESTAMP` but should always be interpreted
# as if the local time zone is UTC.
TIMESTAMP_NTZ = 13
# 4-byte IEEE float.
FLOAT = 14
# Binary value. The content is (4-byte little-endian unsigned integer representing the binary
# size) + (size bytes of binary content).
BINARY = 15
# Long string value. The content is (4-byte little-endian unsigned integer representing the
# string size) + (size bytes of string content).
LONG_STR = 16
# Year-month interval value. The content is one byte representing the start and end field values
# (1 bit each, starting at the least significant bit) and a 4-byte little-endian signed integer
# representing the interval value in months.
YEAR_MONTH_INTERVAL = 19
# Day-time interval value. The content is one byte representing the start and end field values
# (2 bits each, starting at the least significant bit) and an 8-byte little-endian signed integer
# representing the interval value in microseconds.
DAY_TIME_INTERVAL = 20
U32_SIZE = 4
EPOCH = datetime.datetime(
year=1970, month=1, day=1, hour=0, minute=0, second=0, tzinfo=datetime.timezone.utc
)
EPOCH_NTZ = datetime.datetime(year=1970, month=1, day=1, hour=0, minute=0, second=0)
MAX_DECIMAL4_PRECISION = 9
MAX_DECIMAL4_VALUE = 10**MAX_DECIMAL4_PRECISION
MAX_DECIMAL8_PRECISION = 18
MAX_DECIMAL8_VALUE = 10**MAX_DECIMAL8_PRECISION
MAX_DECIMAL16_PRECISION = 38
MAX_DECIMAL16_VALUE = 10**MAX_DECIMAL16_PRECISION
# There is no PySpark equivalent of the SQL year-month interval type. This class acts as a
# placeholder for this type.
class _PlaceholderYearMonthIntervalInternalType:
pass
@classmethod
def to_json(cls, value: bytes, metadata: bytes, zone_id: str = "UTC") -> str:
"""
Convert the VariantVal to a JSON string. The `zone_id` parameter denotes the time zone that
timestamp fields should be parsed in. It defaults to "UTC". The list of valid zone IDs can
be found by importing the `zoneinfo` module and running `zoneinfo.available_timezones()`.
:return: JSON string
"""
return cls._to_json(value, metadata, 0, zone_id)
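# Example with hand-constructed bytes following the format described above: the value
# b"\x0dabc" encodes the short string "abc" and the metadata b"\x01\x00\x00" encodes an empty
# dictionary, so VariantUtils.to_json(b"\x0dabc", b"\x01\x00\x00") returns '"abc"'.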
@classmethod
def to_python(cls, value: bytes, metadata: bytes) -> Any:
"""
Convert the VariantVal to a nested Python object of Python data types.
:return: Python representation of the Variant nested structure
"""
return cls._to_python(value, metadata, 0)
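# Example with hand-constructed bytes, reusing the [1, 2] array encoding shown above:
# VariantUtils.to_python(b"\x03\x02\x00\x02\x04\x0c\x01\x0c\x02", b"\x01\x00\x00") returns [1, 2].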
@classmethod
def _read_long(cls, data: bytes, pos: int, num_bytes: int, signed: bool) -> int:
cls._check_index(pos, len(data))
cls._check_index(pos + num_bytes - 1, len(data))
return int.from_bytes(data[pos : pos + num_bytes], byteorder="little", signed=signed)
@classmethod
def _check_index(cls, pos: int, length: int) -> None:
if pos < 0 or pos >= length:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_type_info(cls, value: bytes, pos: int) -> Tuple[int, int]:
"""
Returns the (basic_type, type_info) pair from the given position in the value.
"""
basic_type = value[pos] & VariantUtils.BASIC_TYPE_MASK
type_info = (value[pos] >> VariantUtils.BASIC_TYPE_BITS) & VariantUtils.TYPE_INFO_MASK
return (basic_type, type_info)
@classmethod
def _get_day_time_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int]:
"""
Returns the (start_field, end_field) pair for a variant representing a day-time interval
value stored at a given position in the value.
"""
cls._check_index(pos, len(value))
start_field = value[pos] & 0x3
end_field = (value[pos] >> 2) & 0x3
if end_field < start_field:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
return (start_field, end_field)
@classmethod
def _get_year_month_interval_fields(cls, value: bytes, pos: int) -> Tuple[int, int]:
"""
Returns the (start_field, end_field) pair for a variant representing a year-month interval
value stored at a given position in the value.
"""
cls._check_index(pos, len(value))
start_field = value[pos] & 0x1
end_field = (value[pos] >> 1) & 0x1
if end_field < start_field:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
return (start_field, end_field)
@classmethod
def _get_metadata_key(cls, metadata: bytes, id: int) -> str:
"""
Returns the key string from the dictionary in the metadata, corresponding to `id`.
"""
cls._check_index(0, len(metadata))
offset_size = ((metadata[0] >> 6) & 0x3) + 1
dict_size = cls._read_long(metadata, 1, offset_size, signed=False)
if id >= dict_size:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
string_start = 1 + (dict_size + 2) * offset_size
offset = cls._read_long(metadata, 1 + (id + 1) * offset_size, offset_size, signed=False)
next_offset = cls._read_long(
metadata, 1 + (id + 2) * offset_size, offset_size, signed=False
)
if offset > next_offset:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
cls._check_index(string_start + next_offset - 1, len(metadata))
return metadata[string_start + offset : (string_start + next_offset)].decode("utf-8")
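# For example, the metadata b"\x01\x01\x00\x01a" encodes a dictionary with the single key "a":
#   0x01      -> header (version 1, 1-byte offsets)
#   0x01      -> dictionary size = 1
#   0x00 0x01 -> offsets into the string data (the last one is the total string size)
#   b"a"      -> string data
# so _get_metadata_key(b"\x01\x01\x00\x01a", 0) returns "a".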
@classmethod
def _get_boolean(cls, value: bytes, pos: int) -> bool:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE or (
type_info != VariantUtils.TRUE and type_info != VariantUtils.FALSE
):
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
return type_info == VariantUtils.TRUE
@classmethod
def _get_long(cls, value: bytes, pos: int) -> int:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if type_info == VariantUtils.INT1:
return cls._read_long(value, pos + 1, 1, signed=True)
elif type_info == VariantUtils.INT2:
return cls._read_long(value, pos + 1, 2, signed=True)
elif type_info == VariantUtils.INT4 or type_info == VariantUtils.DATE:
return cls._read_long(value, pos + 1, 4, signed=True)
elif type_info == VariantUtils.INT8:
return cls._read_long(value, pos + 1, 8, signed=True)
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_date(cls, value: bytes, pos: int) -> datetime.date:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if type_info == VariantUtils.DATE:
days_since_epoch = cls._read_long(value, pos + 1, 4, signed=True)
return datetime.date.fromordinal(VariantUtils.EPOCH.toordinal() + days_since_epoch)
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_timestamp(cls, value: bytes, pos: int, zone_id: str) -> datetime.datetime:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if type_info == VariantUtils.TIMESTAMP_NTZ:
microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
return VariantUtils.EPOCH_NTZ + datetime.timedelta(
microseconds=microseconds_since_epoch
)
if type_info == VariantUtils.TIMESTAMP:
microseconds_since_epoch = cls._read_long(value, pos + 1, 8, signed=True)
return (
VariantUtils.EPOCH + datetime.timedelta(microseconds=microseconds_since_epoch)
).astimezone(ZoneInfo(zone_id))
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_yminterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
"""
Returns the (months, start_field, end_field) tuple from a year-month interval value at a
given position in a variant.
"""
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if type_info == VariantUtils.YEAR_MONTH_INTERVAL:
months = cls._read_long(value, pos + 2, 4, signed=True)
start_field, end_field = cls._get_year_month_interval_fields(value, pos + 1)
return (months, start_field, end_field)
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_dtinterval_info(cls, value: bytes, pos: int) -> Tuple[int, int, int]:
"""
Returns the (micros, start_field, end_field) tuple from a day-time interval value at a given
position in a variant.
"""
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if type_info == VariantUtils.DAY_TIME_INTERVAL:
micros = cls._read_long(value, pos + 2, 8, signed=True)
start_field, end_field = cls._get_day_time_interval_fields(value, pos + 1)
return (micros, start_field, end_field)
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_string(cls, value: bytes, pos: int) -> str:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type == VariantUtils.SHORT_STR or (
basic_type == VariantUtils.PRIMITIVE and type_info == VariantUtils.LONG_STR
):
start = 0
length = 0
if basic_type == VariantUtils.SHORT_STR:
start = pos + 1
length = type_info
else:
start = pos + 1 + VariantUtils.U32_SIZE
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
cls._check_index(start + length - 1, len(value))
return value[start : start + length].decode("utf-8")
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_double(cls, value: bytes, pos: int) -> float:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if type_info == VariantUtils.FLOAT:
cls._check_index(pos + 4, len(value))
return struct.unpack("<f", value[pos + 1 : pos + 5])[0]
elif type_info == VariantUtils.DOUBLE:
cls._check_index(pos + 8, len(value))
return struct.unpack("<d", value[pos + 1 : pos + 9])[0]
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _check_decimal(cls, unscaled: int, scale: int, max_unscaled: int, max_scale: int) -> None:
# max_unscaled == 10**max_scale, but we pass a literal parameter to avoid redundant
# computation.
if unscaled >= max_unscaled or unscaled <= -max_unscaled or scale > max_scale:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _get_decimal(cls, value: bytes, pos: int) -> decimal.Decimal:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
scale = value[pos + 1]
unscaled = 0
if type_info == VariantUtils.DECIMAL4:
unscaled = cls._read_long(value, pos + 2, 4, signed=True)
cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL4_VALUE, cls.MAX_DECIMAL4_PRECISION)
elif type_info == VariantUtils.DECIMAL8:
unscaled = cls._read_long(value, pos + 2, 8, signed=True)
cls._check_decimal(unscaled, scale, cls.MAX_DECIMAL8_VALUE, cls.MAX_DECIMAL8_PRECISION)
elif type_info == VariantUtils.DECIMAL16:
cls._check_index(pos + 17, len(value))
unscaled = int.from_bytes(value[pos + 2 : pos + 18], byteorder="little", signed=True)
cls._check_decimal(
unscaled, scale, cls.MAX_DECIMAL16_VALUE, cls.MAX_DECIMAL16_PRECISION
)
else:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
return decimal.Decimal(unscaled) * (decimal.Decimal(10) ** (-scale))
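# For example, the DECIMAL4 value b"\x20\x02\x39\x30\x00\x00" (header 0x20, scale 2,
# little-endian unscaled value 12345) decodes to Decimal("123.45").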
@classmethod
def _get_binary(cls, value: bytes, pos: int) -> bytes:
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.PRIMITIVE or type_info != VariantUtils.BINARY:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
start = pos + 1 + VariantUtils.U32_SIZE
length = cls._read_long(value, pos + 1, VariantUtils.U32_SIZE, signed=False)
cls._check_index(start + length - 1, len(value))
return bytes(value[start : start + length])
@classmethod
def _get_type(cls, value: bytes, pos: int) -> Any:
"""
Returns the Python type of the Variant at the given position.
"""
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type == VariantUtils.SHORT_STR:
return str
elif basic_type == VariantUtils.OBJECT:
return dict
elif basic_type == VariantUtils.ARRAY:
return array
elif type_info == VariantUtils.NULL:
return type(None)
elif type_info == VariantUtils.TRUE or type_info == VariantUtils.FALSE:
return bool
elif (
type_info == VariantUtils.INT1
or type_info == VariantUtils.INT2
or type_info == VariantUtils.INT4
or type_info == VariantUtils.INT8
):
return int
elif type_info == VariantUtils.DOUBLE or type_info == VariantUtils.FLOAT:
return float
elif (
type_info == VariantUtils.DECIMAL4
or type_info == VariantUtils.DECIMAL8
or type_info == VariantUtils.DECIMAL16
):
return decimal.Decimal
elif type_info == VariantUtils.BINARY:
return bytes
elif type_info == VariantUtils.DATE:
return datetime.date
elif type_info == VariantUtils.TIMESTAMP or type_info == VariantUtils.TIMESTAMP_NTZ:
return datetime.datetime
elif type_info == VariantUtils.LONG_STR:
return str
elif type_info == VariantUtils.DAY_TIME_INTERVAL:
return datetime.timedelta
elif type_info == VariantUtils.YEAR_MONTH_INTERVAL:
return cls._PlaceholderYearMonthIntervalInternalType
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _to_year_month_interval_ansi_string(
cls, months: int, start_field: int, end_field: int
) -> str:
"""
Used to convert months representing a year-month interval with given start and end
fields to its ANSI SQL string representation.
"""
YEAR = 0
MONTHS_PER_YEAR = 12
sign = ""
abs_months = months
if months < 0:
sign = "-"
abs_months = -abs_months
year = sign + str(abs_months // MONTHS_PER_YEAR)
year_and_month = year + "-" + str(abs_months % MONTHS_PER_YEAR)
format_builder = ["INTERVAL '"]
if start_field == end_field:
if start_field == YEAR:
format_builder.append(year + "' YEAR")
else:
format_builder.append(str(months) + "' MONTH")
else:
format_builder.append(year_and_month + "' YEAR TO MONTH")
return "".join(format_builder)
@classmethod
def _to_day_time_interval_ansi_string(
cls, micros: int, start_field: int, end_field: int
) -> str:
"""
Used to convert microseconds representing a day-time interval with given start and end
fields to its ANSI SQL string representation.
"""
DAY = 0
HOUR = 1
MINUTE = 2
SECOND = 3
MIN_LONG_VALUE = -9223372036854775808
MAX_LONG_VALUE = 9223372036854775807
MICROS_PER_SECOND = 1000 * 1000
MICROS_PER_MINUTE = MICROS_PER_SECOND * 60
MICROS_PER_HOUR = MICROS_PER_MINUTE * 60
MICROS_PER_DAY = MICROS_PER_HOUR * 24
MAX_SECOND = MAX_LONG_VALUE // MICROS_PER_SECOND
MAX_MINUTE = MAX_LONG_VALUE // MICROS_PER_MINUTE
MAX_HOUR = MAX_LONG_VALUE // MICROS_PER_HOUR
MAX_DAY = MAX_LONG_VALUE // MICROS_PER_DAY
def field_to_string(field: int) -> str:
if field == DAY:
return "DAY"
elif field == HOUR:
return "HOUR"
elif field == MINUTE:
return "MINUTE"
elif field == SECOND:
return "SECOND"
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
if end_field < start_field:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
sign = ""
rest = micros
from_str = field_to_string(start_field).upper()
to_str = field_to_string(end_field).upper()
prefix = "INTERVAL '"
postfix = f"' {from_str}" if (start_field == end_field) else f"' {from_str} TO {to_str}"
if micros < 0:
if micros == MIN_LONG_VALUE:
# Special handling of the minimum `Long` value because negating it overflows `Long`.
# seconds = 106751991 * (24 * 60 * 60) + 4 * 60 * 60 + 54 = 9223372036854
# microseconds = -9223372036854000000 - 775808 == Long.MinValue
base_str = "-106751991 04:00:54.775808000"
first_str = "-" + (
str(MAX_DAY)
if (start_field == DAY)
else (
str(MAX_HOUR)
if (start_field == HOUR)
else (
str(MAX_MINUTE)
if (start_field == MINUTE)
else str(MAX_SECOND) + ".775808"
)
)
)
if start_field == end_field:
return prefix + first_str + postfix
else:
substr_start = (
10 if (start_field == DAY) else (13 if (start_field == HOUR) else 16)
)
substr_end = (
13 if (end_field == HOUR) else (16 if (end_field == MINUTE) else 26)
)
return prefix + first_str + base_str[substr_start:substr_end] + postfix
else:
sign = "-"
rest = -rest
format_builder = [sign]
format_args = []
if start_field == DAY:
format_builder.append(str(rest // MICROS_PER_DAY))
rest %= MICROS_PER_DAY
elif start_field == HOUR:
format_builder.append("%02d")
format_args.append(rest // MICROS_PER_HOUR)
rest %= MICROS_PER_HOUR
elif start_field == MINUTE:
format_builder.append("%02d")
format_args.append(rest // MICROS_PER_MINUTE)
rest %= MICROS_PER_MINUTE
elif start_field == SECOND:
lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
format_builder.append(
lead_zero + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
)
if start_field < HOUR and HOUR <= end_field:
format_builder.append(" %02d")
format_args.append(rest // MICROS_PER_HOUR)
rest %= MICROS_PER_HOUR
if start_field < MINUTE and MINUTE <= end_field:
format_builder.append(":%02d")
format_args.append(rest // MICROS_PER_MINUTE)
rest %= MICROS_PER_MINUTE
if start_field < SECOND and SECOND <= end_field:
lead_zero = "0" if (rest < 10 * MICROS_PER_SECOND) else ""
format_builder.append(
":" + lead_zero + (Decimal(rest) / Decimal(1000000)).normalize().to_eng_string()
)
return prefix + ("".join(format_builder) % tuple(format_args)) + postfix
@classmethod
def _to_json(cls, value: bytes, metadata: bytes, pos: int, zone_id: str) -> str:
variant_type = cls._get_type(value, pos)
if variant_type == dict:
def handle_object(key_value_pos_list: List[Tuple[str, int]]) -> str:
key_value_list = [
json.dumps(key) + ":" + cls._to_json(value, metadata, value_pos, zone_id)
for (key, value_pos) in key_value_pos_list
]
return "{" + ",".join(key_value_list) + "}"
return cls._handle_object(value, metadata, pos, handle_object)
elif variant_type == array:
def handle_array(value_pos_list: List[int]) -> str:
value_list = [
cls._to_json(value, metadata, value_pos, zone_id)
for value_pos in value_pos_list
]
return "[" + ",".join(value_list) + "]"
return cls._handle_array(value, pos, handle_array)
elif variant_type == datetime.timedelta:
micros, start_field, end_field = cls._get_dtinterval_info(value, pos)
return '"' + cls._to_day_time_interval_ansi_string(micros, start_field, end_field) + '"'
elif variant_type == cls._PlaceholderYearMonthIntervalInternalType:
months, start_field, end_field = cls._get_yminterval_info(value, pos)
return (
'"' + cls._to_year_month_interval_ansi_string(months, start_field, end_field) + '"'
)
else:
value = cls._get_scalar(variant_type, value, metadata, pos, zone_id)
if value is None:
return "null"
if type(value) == bool:
return "true" if value else "false"
if type(value) == str:
return json.dumps(value)
if type(value) == bytes:
# decoding simply converts byte array to string
return '"' + base64.b64encode(value).decode("utf-8") + '"'
if type(value) == datetime.date or type(value) == datetime.datetime:
return '"' + str(value) + '"'
return str(value)
@classmethod
def _to_python(cls, value: bytes, metadata: bytes, pos: int) -> Any:
variant_type = cls._get_type(value, pos)
if variant_type == dict:
def handle_object(key_value_pos_list: List[Tuple[str, int]]) -> Dict[str, Any]:
key_value_list = [
(key, cls._to_python(value, metadata, value_pos))
for (key, value_pos) in key_value_pos_list
]
return dict(key_value_list)
return cls._handle_object(value, metadata, pos, handle_object)
elif variant_type == array:
def handle_array(value_pos_list: List[int]) -> List[Any]:
value_list = [
cls._to_python(value, metadata, value_pos) for value_pos in value_pos_list
]
return value_list
return cls._handle_array(value, pos, handle_array)
elif variant_type == datetime.timedelta:
# day-time intervals are represented using timedelta in a trivial manner
return datetime.timedelta(microseconds=cls._get_dtinterval_info(value, pos)[0])
elif variant_type == cls._PlaceholderYearMonthIntervalInternalType:
raise PySparkNotImplementedError(
errorClass="NOT_IMPLEMENTED",
messageParameters={"feature": "VariantUtils.YEAR_MONTH_INTERVAL"},
)
else:
return cls._get_scalar(variant_type, value, metadata, pos, zone_id="UTC")
@classmethod
def _get_scalar(
cls, variant_type: Any, value: bytes, metadata: bytes, pos: int, zone_id: str
) -> Any:
if isinstance(None, variant_type):
return None
elif variant_type == bool:
return cls._get_boolean(value, pos)
elif variant_type == int:
return cls._get_long(value, pos)
elif variant_type == str:
return cls._get_string(value, pos)
elif variant_type == float:
return cls._get_double(value, pos)
elif variant_type == decimal.Decimal:
return cls._get_decimal(value, pos)
elif variant_type == bytes:
return cls._get_binary(value, pos)
elif variant_type == datetime.date:
return cls._get_date(value, pos)
elif variant_type == datetime.datetime:
return cls._get_timestamp(value, pos, zone_id)
else:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
@classmethod
def _handle_object(
cls, value: bytes, metadata: bytes, pos: int, func: Callable[[List[Tuple[str, int]]], Any]
) -> Any:
"""
Parses the variant object at position `pos`.
Calls `func` with a list of (key, value position) pairs of the object.
"""
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.OBJECT:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
large_size = ((type_info >> 4) & 0x1) != 0
size_bytes = VariantUtils.U32_SIZE if large_size else 1
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
id_size = ((type_info >> 2) & 0x3) + 1
offset_size = ((type_info) & 0x3) + 1
id_start = pos + 1 + size_bytes
offset_start = id_start + num_fields * id_size
data_start = offset_start + (num_fields + 1) * offset_size
key_value_pos_list = []
for i in range(num_fields):
id = cls._read_long(value, id_start + id_size * i, id_size, signed=False)
offset = cls._read_long(
value, offset_start + offset_size * i, offset_size, signed=False
)
value_pos = data_start + offset
key_value_pos_list.append((cls._get_metadata_key(metadata, id), value_pos))
return func(key_value_pos_list)
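# For example, with metadata b"\x01\x01\x00\x01a" (single key "a"), the object value
# b"\x02\x01\x00\x00\x02\x0c\x01" decodes as {"a": 1}:
#   0x02      -> OBJECT header (1-byte size, ids and offsets)
#   0x01      -> one field
#   0x00      -> field id 0 ("a")
#   0x00 0x02 -> field offsets (the last one is the total size of the field data)
#   0x0c 0x01 -> INT1 value 1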
@classmethod
def _handle_array(cls, value: bytes, pos: int, func: Callable[[List[int]], Any]) -> Any:
"""
Parses the variant array at position `pos`.
Calls `func` with a list of element positions of the array.
"""
cls._check_index(pos, len(value))
basic_type, type_info = cls._get_type_info(value, pos)
if basic_type != VariantUtils.ARRAY:
raise PySparkValueError(errorClass="MALFORMED_VARIANT", messageParameters={})
large_size = ((type_info >> 2) & 0x1) != 0
size_bytes = VariantUtils.U32_SIZE if large_size else 1
num_fields = cls._read_long(value, pos + 1, size_bytes, signed=False)
offset_size = (type_info & 0x3) + 1
offset_start = pos + 1 + size_bytes
data_start = offset_start + (num_fields + 1) * offset_size
value_pos_list = []
for i in range(num_fields):
offset = cls._read_long(
value, offset_start + offset_size * i, offset_size, signed=False
)
element_pos = data_start + offset
value_pos_list.append(element_pos)
return func(value_pos_list)