blob: 8ef5607219bd8d3df76f07f8b6a1335fe03137c9 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
from abc import ABC, abstractmethod
import pyarrow as pa
import pytz
from pyflink.common.typeinfo import TypeInformation, BasicTypeInfo, BasicType, DateTypeInfo, \
TimeTypeInfo, TimestampTypeInfo, PrimitiveArrayTypeInfo, BasicArrayTypeInfo, TupleTypeInfo, \
MapTypeInfo, ListTypeInfo, RowTypeInfo, PickledBytesTypeInfo, ObjectArrayTypeInfo, \
ExternalTypeInfo
from pyflink.fn_execution import flink_fn_execution_pb2
from pyflink.table.types import TinyIntType, SmallIntType, IntType, BigIntType, BooleanType, \
FloatType, DoubleType, VarCharType, VarBinaryType, DecimalType, DateType, TimeType, \
LocalZonedTimestampType, RowType, RowField, to_arrow_type, TimestampType, ArrayType
try:
from pyflink.fn_execution import coder_impl_fast as coder_impl
except:
from pyflink.fn_execution import coder_impl_slow as coder_impl
# Names exported by ``from <this module> import *``: a selection of the field coder classes
# defined below.
__all__ = ['FlattenRowCoder', 'RowCoder', 'BigIntCoder', 'TinyIntCoder', 'BooleanCoder',
'SmallIntCoder', 'IntCoder', 'FloatCoder', 'DoubleCoder', 'BinaryCoder', 'CharCoder',
'DateCoder', 'TimeCoder', 'TimestampCoder', 'LocalZonedTimestampCoder', 'InstantCoder',
'GenericArrayCoder', 'PrimitiveArrayCoder', 'MapCoder', 'DecimalCoder',
'BigDecimalCoder', 'TupleCoder', 'TimeWindowCoder', 'CountWindowCoder',
'PickleCoder', 'CloudPickleCoder', 'DataViewFilterCoder']
#########################################################################
# Top-level coder: ValueCoder & IterableCoder
#########################################################################
# LengthPrefixBaseCoder is the top level coder and the other coders will be used as the field coder
class LengthPrefixBaseCoder(ABC):
    """
    Base class of the top-level coders; each instance wraps exactly one field coder.

    Subclasses provide :meth:`get_impl`. The class methods translate a ``CoderInfoDescriptor``
    proto into the matching top-level coder together with the field coder it wraps.
    """

    def __init__(self, field_coder: 'FieldCoder'):
        self._field_coder = field_coder

    @abstractmethod
    def get_impl(self):
        """Return the implementation object of this coder."""
        pass

    @classmethod
    def from_coder_info_descriptor_proto(cls, coder_info_descriptor_proto):
        """
        Create the top-level coder described by ``coder_info_descriptor_proto``.

        :param coder_info_descriptor_proto: the ``CoderInfoDescriptor`` proto message
        :return: a :class:`ValueCoder` if the proto mode is ``SINGLE``, otherwise an
                 :class:`IterableCoder`
        """
        wrapped_coder = cls._to_field_coder(coder_info_descriptor_proto)
        single_mode = flink_fn_execution_pb2.CoderInfoDescriptor.SINGLE
        if coder_info_descriptor_proto.mode == single_mode:
            return ValueCoder(wrapped_coder)
        return IterableCoder(wrapped_coder,
                             coder_info_descriptor_proto.separated_with_end_message)

    @classmethod
    def _to_field_coder(cls, coder_info_descriptor_proto):
        """
        Create the field coder for whichever data type field is set on the descriptor.

        :raises ValueError: if none of the known data type fields is set
        """
        descriptor = coder_info_descriptor_proto

        if descriptor.HasField('flatten_row_type'):
            flatten_fields = descriptor.flatten_row_type.schema.fields
            return FlattenRowCoder([from_proto(f.type) for f in flatten_fields])

        if descriptor.HasField('row_type'):
            row_fields = descriptor.row_type.schema.fields
            return RowCoder([from_proto(f.type) for f in row_fields],
                            [f.name for f in row_fields])

        if descriptor.HasField('arrow_type'):
            # The timezone is read from the environment before the schema is converted.
            tz = pytz.timezone(os.environ['table.exec.timezone'])
            arrow_row_type = cls._to_row_type(descriptor.arrow_type.schema)
            return ArrowCoder(cls._to_arrow_schema(arrow_row_type), arrow_row_type, tz)

        if descriptor.HasField('over_window_arrow_type'):
            tz = pytz.timezone(os.environ['table.exec.timezone'])
            window_row_type = cls._to_row_type(descriptor.over_window_arrow_type.schema)
            return OverWindowArrowCoder(
                cls._to_arrow_schema(window_row_type), window_row_type, tz)

        if descriptor.HasField('raw_type'):
            return from_type_info_proto(descriptor.raw_type.type_info)

        raise ValueError("Unexpected coder type %s" % coder_info_descriptor_proto)

    @classmethod
    def _to_arrow_schema(cls, row_type):
        """Convert a :class:`RowType` into a pyarrow schema, one arrow field per row field."""
        arrow_fields = []
        for name, data_type in zip(row_type.field_names(), row_type.field_types()):
            arrow_fields.append(pa.field(name, to_arrow_type(data_type), data_type._nullable))
        return pa.schema(arrow_fields)

    @classmethod
    def _to_data_type(cls, field_type):
        """
        Convert the proto representation of a field type into a table data type.

        :raises ValueError: if the type name of ``field_type`` is not handled here
        """
        schema = flink_fn_execution_pb2.Schema
        kind = field_type.type_name

        # Types whose constructor takes nothing but the nullability flag.
        nullable_only_types = {
            schema.TINYINT: TinyIntType,
            schema.SMALLINT: SmallIntType,
            schema.INT: IntType,
            schema.BIGINT: BigIntType,
            schema.BOOLEAN: BooleanType,
            schema.FLOAT: FloatType,
            schema.DOUBLE: DoubleType,
            schema.DATE: DateType,
        }
        simple_type = nullable_only_types.get(kind)
        if simple_type is not None:
            return simple_type(field_type.nullable)

        if kind == schema.VARCHAR:
            return VarCharType(0x7fffffff, field_type.nullable)
        if kind == schema.VARBINARY:
            return VarBinaryType(0x7fffffff, field_type.nullable)
        if kind == schema.DECIMAL:
            decimal_info = field_type.decimal_info
            return DecimalType(decimal_info.precision, decimal_info.scale, field_type.nullable)
        if kind == schema.TIME:
            return TimeType(field_type.time_info.precision, field_type.nullable)
        if kind == schema.LOCAL_ZONED_TIMESTAMP:
            return LocalZonedTimestampType(field_type.local_zoned_timestamp_info.precision,
                                           field_type.nullable)
        if kind == schema.TIMESTAMP:
            return TimestampType(field_type.timestamp_info.precision, field_type.nullable)
        if kind == schema.BASIC_ARRAY:
            element_type = cls._to_data_type(field_type.collection_element_type)
            return ArrayType(element_type, field_type.nullable)
        if kind == schema.TypeName.ROW:
            nested_fields = [RowField(f.name, cls._to_data_type(f.type), f.description)
                             for f in field_type.row_schema.fields]
            return RowType(nested_fields, field_type.nullable)

        raise ValueError("field_type %s is not supported." % field_type)

    @classmethod
    def _to_row_type(cls, row_schema):
        """Convert a schema proto into a :class:`RowType` with one field per schema field."""
        row_fields = [RowField(f.name, cls._to_data_type(f.type)) for f in row_schema.fields]
        return RowType(row_fields)
class IterableCoder(LengthPrefixBaseCoder):
    """
    Top-level coder for iterable data; every element is handled by the wrapped field coder.
    """

    def __init__(self, field_coder: 'FieldCoder', separated_with_end_message):
        super().__init__(field_coder)
        self._separated_with_end_message = separated_with_end_message

    def get_impl(self):
        """Create the iterable coder implementation on top of the field coder's one."""
        element_impl = self._field_coder.get_impl()
        return coder_impl.IterableCoderImpl(element_impl, self._separated_with_end_message)
class ValueCoder(LengthPrefixBaseCoder):
    """
    Top-level coder for a single value, which is handled by the wrapped field coder.
    """

    def __init__(self, field_coder: 'FieldCoder'):
        super().__init__(field_coder)

    def get_impl(self):
        """Create the value coder implementation on top of the field coder's one."""
        value_impl = self._field_coder.get_impl()
        return coder_impl.ValueCoderImpl(value_impl)
#########################################################################
# Low-level coder: FieldCoder
#########################################################################
class FieldCoder(ABC):
    """
    Base class of the low-level coders that are wrapped by the top-level coders.

    Two field coders compare equal when they are of exactly the same class; subclasses that
    carry parameters override ``__eq__`` to compare those as well.
    """

    def get_impl(self) -> coder_impl.FieldCoderImpl:
        """
        Return the implementation object of this coder.

        The base implementation does nothing and returns ``None``; subclasses override it.
        """
        pass

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        # A class that defines __eq__ without __hash__ gets __hash__ set to None and becomes
        # unhashable. Equality here depends on the class only, so the hash does as well.
        return hash(type(self))
class FlattenRowCoder(FieldCoder):
    """
    Coder for Row. The decoded result will be flattened as a list of column values of a row instead
    of a row object.
    """

    def __init__(self, field_coders):
        self._field_coders = field_coders

    def get_impl(self):
        """Create the implementation from the implementations of all field coders."""
        return coder_impl.FlattenRowCoderImpl([c.get_impl() for c in self._field_coders])

    def __repr__(self):
        return 'FlattenRowCoder[%s]' % ', '.join(str(c) for c in self._field_coders)

    def __eq__(self, other: 'FlattenRowCoder'):
        # all() is required here: a bare list of comparison results is truthy whenever it is
        # non-empty (even if every comparison failed) and falsy for two empty coders.
        return (self.__class__ == other.__class__
                and len(self._field_coders) == len(other._field_coders)
                and all(mine == theirs
                        for mine, theirs in zip(self._field_coders, other._field_coders)))

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # The field coders are kept in a list, which cannot be hashed directly. Hash on the
        # classes of the field coders instead: equal coders always have field coders of the
        # same classes, and this does not require the field coders themselves to be hashable.
        return hash((self.__class__, tuple(c.__class__ for c in self._field_coders)))
class ArrowCoder(FieldCoder):
    """
    Arrow based coder; it carries the arrow schema, the row type and the timezone that are
    handed over to the implementation.
    """

    def __init__(self, schema, row_type, timezone):
        self._schema, self._row_type, self._timezone = schema, row_type, timezone

    def get_impl(self):
        """Create the arrow coder implementation from schema, row type and timezone."""
        impl_args = (self._schema, self._row_type, self._timezone)
        return coder_impl.ArrowCoderImpl(*impl_args)

    def __repr__(self):
        return 'ArrowCoder[%s]' % self._schema
class OverWindowArrowCoder(FieldCoder):
    """
    Coder for batch pandas over window aggregation. It delegates to an :class:`ArrowCoder`
    created from the given schema, row type and timezone.
    """

    def __init__(self, schema, row_type, timezone):
        self._arrow_coder = ArrowCoder(schema, row_type, timezone)

    def get_impl(self):
        """Wrap the implementation of the inner arrow coder."""
        inner_impl = self._arrow_coder.get_impl()
        return coder_impl.OverWindowArrowCoderImpl(inner_impl)

    def __repr__(self):
        return 'OverWindowArrowCoder[%s]' % self._arrow_coder
class RowCoder(FieldCoder):
    """
    Coder for Row, built from one coder per field plus the field names.
    """

    def __init__(self, field_coders, field_names):
        self._field_coders = field_coders
        self._field_names = field_names

    def get_impl(self):
        """Create the implementation from the field coder implementations and field names."""
        return coder_impl.RowCoderImpl([c.get_impl() for c in self._field_coders],
                                       self._field_names)

    def __repr__(self):
        return 'RowCoder[%s]' % ', '.join(str(c) for c in self._field_coders)

    def __eq__(self, other: 'RowCoder'):
        # The lengths are compared first so that coders of different arity are simply unequal
        # (no IndexError, no prefix match). all() is required: a bare list of comparison
        # results is truthy whenever it is non-empty, whatever the results are.
        return (self.__class__ == other.__class__
                and self._field_names == other._field_names
                and len(self._field_coders) == len(other._field_coders)
                and all(mine == theirs
                        for mine, theirs in zip(self._field_coders, other._field_coders)))

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # The field coders are kept in a list, which cannot be hashed directly. Hash on the
        # field names and the classes of the field coders: equal coders agree on both, and
        # this does not require the field coders themselves to be hashable.
        return hash((self.__class__,
                     tuple(self._field_names),
                     tuple(c.__class__ for c in self._field_coders)))
class CollectionCoder(FieldCoder):
    """
    Base coder for collection; all elements are handled by a single element coder.
    """

    def __init__(self, elem_coder):
        self._elem_coder = elem_coder

    def is_deterministic(self):
        # NOTE(review): this requires the element coder to provide is_deterministic();
        # confirm that every coder used as element coder actually defines it.
        return self._elem_coder.is_deterministic()

    def __eq__(self, other: 'CollectionCoder'):
        return (self.__class__ == other.__class__
                and self._elem_coder == other._elem_coder)

    def __repr__(self):
        return '%s[%s]' % (self.__class__.__name__, repr(self._elem_coder))

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Hash on the coder classes instead of on the element coder object, so that hashing
        # works even if the element coder is not hashable. Equal collection coders always
        # share their own class and the class of their element coder.
        return hash((self.__class__, self._elem_coder.__class__))
class GenericArrayCoder(CollectionCoder):
    """
    Collection coder for generic arrays, e.g. basic arrays or object arrays.
    """

    def __init__(self, elem_coder):
        super().__init__(elem_coder)

    def get_impl(self):
        """Create the generic array implementation around the element coder's one."""
        element_impl = self._elem_coder.get_impl()
        return coder_impl.GenericArrayCoderImpl(element_impl)
class PrimitiveArrayCoder(CollectionCoder):
    """
    Collection coder for primitive arrays.
    """

    def __init__(self, elem_coder):
        super().__init__(elem_coder)

    def get_impl(self):
        """Create the primitive array implementation around the element coder's one."""
        element_impl = self._elem_coder.get_impl()
        return coder_impl.PrimitiveArrayCoderImpl(element_impl)
class MapCoder(FieldCoder):
    """
    Coder for Map, built from a key coder and a value coder.
    """

    def __init__(self, key_coder, value_coder):
        self._key_coder = key_coder
        self._value_coder = value_coder

    def get_impl(self):
        """Create the implementation from the key and value coder implementations."""
        return coder_impl.MapCoderImpl(self._key_coder.get_impl(), self._value_coder.get_impl())

    def is_deterministic(self):
        # NOTE(review): this requires both nested coders to provide is_deterministic();
        # confirm that every coder used as key or value coder actually defines it.
        return self._key_coder.is_deterministic() and self._value_coder.is_deterministic()

    def __repr__(self):
        return 'MapCoder[%s]' % ','.join([repr(self._key_coder), repr(self._value_coder)])

    def __eq__(self, other: 'MapCoder'):
        return (self.__class__ == other.__class__
                and self._key_coder == other._key_coder
                and self._value_coder == other._value_coder)

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # hash() of a list always raises TypeError, so a tuple is used. It holds the coder
        # classes rather than the coder objects, which keeps hashing independent of whether
        # the nested coders are hashable; equal map coders always agree on these classes.
        return hash((self.__class__, self._key_coder.__class__, self._value_coder.__class__))
class BigIntCoder(FieldCoder):
    """
    Field coder for 8 bytes long values.
    """

    def get_impl(self):
        """Create a new ``BigIntCoderImpl``."""
        impl = coder_impl.BigIntCoderImpl()
        return impl
class TinyIntCoder(FieldCoder):
    """
    Field coder for Byte values.
    """

    def get_impl(self):
        """Create a new ``TinyIntCoderImpl``."""
        impl = coder_impl.TinyIntCoderImpl()
        return impl
class BooleanCoder(FieldCoder):
    """
    Field coder for Boolean values.
    """

    def get_impl(self):
        """Create a new ``BooleanCoderImpl``."""
        impl = coder_impl.BooleanCoderImpl()
        return impl
class SmallIntCoder(FieldCoder):
    """
    Field coder for Short values.
    """

    def get_impl(self):
        """Create a new ``SmallIntCoderImpl``."""
        impl = coder_impl.SmallIntCoderImpl()
        return impl
class IntCoder(FieldCoder):
    """
    Field coder for 4 bytes int values.
    """

    def get_impl(self):
        """Create a new ``IntCoderImpl``."""
        impl = coder_impl.IntCoderImpl()
        return impl
class FloatCoder(FieldCoder):
    """
    Field coder for Float values.
    """

    def get_impl(self):
        """Create a new ``FloatCoderImpl``."""
        impl = coder_impl.FloatCoderImpl()
        return impl
class DoubleCoder(FieldCoder):
    """
    Field coder for Double values.
    """

    def get_impl(self):
        """Create a new ``DoubleCoderImpl``."""
        impl = coder_impl.DoubleCoderImpl()
        return impl
class DecimalCoder(FieldCoder):
    """
    Coder for Decimal with an explicit precision and scale.
    """

    def __init__(self, precision, scale):
        self.precision = precision
        self.scale = scale

    def get_impl(self):
        """Create a new ``DecimalCoderImpl`` for the configured precision and scale."""
        return coder_impl.DecimalCoderImpl(self.precision, self.scale)

    def __eq__(self, other: 'DecimalCoder'):
        return (self.__class__ == other.__class__ and
                self.precision == other.precision and
                self.scale == other.scale)

    def __hash__(self):
        # Overriding __eq__ alone would leave the class unhashable; hash on exactly the
        # values that __eq__ compares.
        return hash((self.__class__, self.precision, self.scale))
class BigDecimalCoder(FieldCoder):
    """
    Field coder for basic decimals, which need neither a precision nor a scale.
    """

    def get_impl(self):
        """Create a new ``BigDecimalCoderImpl``."""
        impl = coder_impl.BigDecimalCoderImpl()
        return impl
class BinaryCoder(FieldCoder):
    """
    Field coder for byte arrays.
    """

    def get_impl(self):
        """Create a new ``BinaryCoderImpl``."""
        impl = coder_impl.BinaryCoderImpl()
        return impl
class CharCoder(FieldCoder):
    """
    Field coder for character strings.
    """

    def get_impl(self):
        """Create a new ``CharCoderImpl``."""
        impl = coder_impl.CharCoderImpl()
        return impl
class DateCoder(FieldCoder):
    """
    Field coder for Date values.
    """

    def get_impl(self):
        """Create a new ``DateCoderImpl``."""
        impl = coder_impl.DateCoderImpl()
        return impl
class TimeCoder(FieldCoder):
    """
    Field coder for Time values.
    """

    def get_impl(self):
        """Create a new ``TimeCoderImpl``."""
        impl = coder_impl.TimeCoderImpl()
        return impl
class TimestampCoder(FieldCoder):
    """
    Coder for Timestamp with a given precision.
    """

    def __init__(self, precision):
        self.precision = precision

    def get_impl(self):
        """Create a new ``TimestampCoderImpl`` for the configured precision."""
        return coder_impl.TimestampCoderImpl(self.precision)

    def __eq__(self, other: 'TimestampCoder'):
        return self.__class__ == other.__class__ and self.precision == other.precision

    def __hash__(self):
        # Overriding __eq__ alone would leave the class unhashable; hash on exactly the
        # values that __eq__ compares.
        return hash((self.__class__, self.precision))
class LocalZonedTimestampCoder(FieldCoder):
    """
    Coder for LocalZonedTimestamp with a given precision and timezone.
    """

    def __init__(self, precision, timezone):
        self.precision = precision
        self.timezone = timezone

    def get_impl(self):
        """Create a new ``LocalZonedTimestampCoderImpl`` for the precision and timezone."""
        return coder_impl.LocalZonedTimestampCoderImpl(self.precision, self.timezone)

    def __eq__(self, other: 'LocalZonedTimestampCoder'):
        return (self.__class__ == other.__class__ and
                self.precision == other.precision and
                self.timezone == other.timezone)

    def __hash__(self):
        # Overriding __eq__ alone would leave the class unhashable. The timezone is left out
        # of the hash on purpose so that hashing never depends on the timezone object being
        # hashable; equal coders still hash equally because they share class and precision.
        return hash((self.__class__, self.precision))
class InstantCoder(FieldCoder):
    """
    Field coder for Instant values.
    """

    def get_impl(self) -> coder_impl.FieldCoderImpl:
        """Create a new ``InstantCoderImpl``."""
        impl = coder_impl.InstantCoderImpl()
        return impl
class CloudPickleCoder(FieldCoder):
    """
    Field coder that encodes python objects with cloudpickle.
    """

    def get_impl(self):
        """Create a new ``CloudPickleCoderImpl``."""
        impl = coder_impl.CloudPickleCoderImpl()
        return impl
class PickleCoder(FieldCoder):
    """
    Field coder that encodes python objects with pickle.
    """

    def get_impl(self):
        """Create a new ``PickleCoderImpl``."""
        impl = coder_impl.PickleCoderImpl()
        return impl
class TupleCoder(FieldCoder):
    """
    Coder for Tuple, built from one coder per tuple field.
    """

    def __init__(self, field_coders):
        self._field_coders = field_coders

    def get_impl(self):
        """Create the implementation from the implementations of all field coders."""
        return coder_impl.TupleCoderImpl([c.get_impl() for c in self._field_coders])

    def __repr__(self):
        return 'TupleCoder[%s]' % ', '.join(str(c) for c in self._field_coders)

    def __eq__(self, other: 'TupleCoder'):
        # The lengths are compared first so that coders of different arity are simply unequal
        # (no IndexError, no prefix match). all() is required: a bare list of comparison
        # results is truthy whenever it is non-empty, whatever the results are.
        return (self.__class__ == other.__class__
                and len(self._field_coders) == len(other._field_coders)
                and all(mine == theirs
                        for mine, theirs in zip(self._field_coders, other._field_coders)))

    def __hash__(self):
        # Overriding __eq__ alone would leave the class unhashable. Hash on the classes of
        # the field coders: equal coders always agree on them, and this does not require the
        # field coders themselves to be hashable.
        return hash((self.__class__, tuple(c.__class__ for c in self._field_coders)))
class TimeWindowCoder(FieldCoder):
    """
    Field coder for TimeWindow.
    """

    def get_impl(self):
        """Create a new ``TimeWindowCoderImpl``."""
        impl = coder_impl.TimeWindowCoderImpl()
        return impl
class CountWindowCoder(FieldCoder):
    """
    Field coder for CountWindow.
    """

    def get_impl(self):
        """Create a new ``CountWindowCoderImpl``."""
        impl = coder_impl.CountWindowCoderImpl()
        return impl
class DataViewFilterCoder(FieldCoder):
    """
    Field coder for data view filter; it is configured with the UDF data view specs.
    """

    def __init__(self, udf_data_view_specs):
        self._udf_data_view_specs = udf_data_view_specs

    def get_impl(self):
        """Create a new ``DataViewFilterCoderImpl`` for the stored data view specs."""
        specs = self._udf_data_view_specs
        return coder_impl.DataViewFilterCoderImpl(specs)
# Alias of the ``Schema`` proto message; the type name constants used below and in
# from_proto() are looked up on it.
type_name = flink_fn_execution_pb2.Schema
# Field types whose coder needs no parameters. Every coder is instantiated once here, so
# from_proto() returns the same shared instance each time one of these types is requested.
_type_name_mappings = {
type_name.TINYINT: TinyIntCoder(),
type_name.SMALLINT: SmallIntCoder(),
type_name.INT: IntCoder(),
type_name.BIGINT: BigIntCoder(),
type_name.BOOLEAN: BooleanCoder(),
type_name.FLOAT: FloatCoder(),
type_name.DOUBLE: DoubleCoder(),
type_name.BINARY: BinaryCoder(),
type_name.VARBINARY: BinaryCoder(),
type_name.CHAR: CharCoder(),
type_name.VARCHAR: CharCoder(),
type_name.DATE: DateCoder(),
type_name.TIME: TimeCoder(),
}
def from_proto(field_type):
    """
    Build the :class:`FieldCoder` that matches the protocol representation of a field type.

    :param field_type: the protocol representation of the field type
    :return: the corresponding coder
    :raises ValueError: if the field type is not supported
    """
    kind = field_type.type_name

    # Parameter-free types are served straight from the lookup table.
    shared_coder = _type_name_mappings.get(kind)
    if shared_coder is not None:
        return shared_coder

    if kind == type_name.ROW:
        return RowCoder([from_proto(f.type) for f in field_type.row_schema.fields],
                        [f.name for f in field_type.row_schema.fields])

    if kind == type_name.TIMESTAMP:
        return TimestampCoder(field_type.timestamp_info.precision)

    if kind == type_name.LOCAL_ZONED_TIMESTAMP:
        # The timezone is taken from the environment of the running process.
        tz = pytz.timezone(os.environ['table.exec.timezone'])
        return LocalZonedTimestampCoder(field_type.local_zoned_timestamp_info.precision, tz)

    if kind == type_name.BASIC_ARRAY:
        element_coder = from_proto(field_type.collection_element_type)
        return GenericArrayCoder(element_coder)

    if kind == type_name.MAP:
        map_info = field_type.map_info
        return MapCoder(from_proto(map_info.key_type), from_proto(map_info.value_type))

    if kind == type_name.DECIMAL:
        decimal_info = field_type.decimal_info
        return DecimalCoder(decimal_info.precision, decimal_info.scale)

    raise ValueError("field_type %s is not supported." % field_type)
# Coder lookup for data stream type information.
# ``type_info_name`` is an alias of the ``TypeInfo`` proto message; the type name constants
# used below and in from_type_info_proto() are looked up on it.
type_info_name = flink_fn_execution_pb2.TypeInfo
# Type infos whose coder can be picked without inspecting nested type information. Every
# coder is instantiated once here, so from_type_info_proto() returns the same shared
# instance each time one of these types is requested.
_type_info_name_mappings = {
type_info_name.STRING: CharCoder(),
type_info_name.BYTE: TinyIntCoder(),
type_info_name.BOOLEAN: BooleanCoder(),
type_info_name.SHORT: SmallIntCoder(),
type_info_name.INT: IntCoder(),
type_info_name.LONG: BigIntCoder(),
type_info_name.FLOAT: FloatCoder(),
type_info_name.DOUBLE: DoubleCoder(),
type_info_name.CHAR: CharCoder(),
type_info_name.BIG_INT: BigIntCoder(),
type_info_name.BIG_DEC: BigDecimalCoder(),
type_info_name.SQL_DATE: DateCoder(),
type_info_name.SQL_TIME: TimeCoder(),
type_info_name.SQL_TIMESTAMP: TimestampCoder(3),
type_info_name.PICKLED_BYTES: CloudPickleCoder(),
type_info_name.INSTANT: InstantCoder()
}
def from_type_info_proto(type_info):
    """
    Build the :class:`FieldCoder` that matches the protocol representation of a type info.

    :param type_info: the protocol representation of the type information
    :return: the corresponding coder
    :raises ValueError: if the type info is not supported
    """
    kind = type_info.type_name
    try:
        return _type_info_name_mappings[kind]
    except KeyError:
        # Not in the lookup table: build a composite coder from the nested type information.
        if kind == type_info_name.ROW:
            return RowCoder(
                [from_type_info_proto(f.field_type) for f in type_info.row_type_info.fields],
                [f.field_name for f in type_info.row_type_info.fields])

        if kind == type_info_name.PRIMITIVE_ARRAY:
            # Primitive byte arrays are handled by the binary coder.
            if type_info.collection_element_type.type_name == type_info_name.BYTE:
                return BinaryCoder()
            return PrimitiveArrayCoder(from_type_info_proto(type_info.collection_element_type))

        generic_array_kinds = (type_info_name.BASIC_ARRAY,
                               type_info_name.OBJECT_ARRAY,
                               type_info_name.LIST)
        if kind in generic_array_kinds:
            return GenericArrayCoder(from_type_info_proto(type_info.collection_element_type))

        if kind == type_info_name.TUPLE:
            tuple_field_types = type_info.tuple_type_info.field_types
            return TupleCoder([from_type_info_proto(t) for t in tuple_field_types])

        if kind == type_info_name.MAP:
            map_type_info = type_info.map_type_info
            return MapCoder(from_type_info_proto(map_type_info.key_type),
                            from_type_info_proto(map_type_info.value_type))

        raise ValueError("Unsupported type_info %s." % type_info)
# Coders for the basic types; from_type_info() indexes this table with the ``_basic_type``
# of a BasicTypeInfo. Every coder is instantiated once here, so the same shared instance is
# returned each time.
_basic_type_info_mappings = {
BasicType.BYTE: TinyIntCoder(),
BasicType.BOOLEAN: BooleanCoder(),
BasicType.SHORT: SmallIntCoder(),
BasicType.INT: IntCoder(),
BasicType.LONG: BigIntCoder(),
BasicType.BIG_INT: BigIntCoder(),
BasicType.FLOAT: FloatCoder(),
BasicType.DOUBLE: DoubleCoder(),
BasicType.STRING: CharCoder(),
BasicType.CHAR: CharCoder(),
BasicType.BIG_DEC: BigDecimalCoder(),
BasicType.INSTANT: InstantCoder()
}
def from_type_info(type_info: TypeInformation) -> FieldCoder:
    """
    Build the :class:`FieldCoder` for a :class:`TypeInformation` object.

    The checks run in a fixed order and the first matching type information class decides.

    :param type_info: the type information to find a coder for
    :return: the corresponding coder
    :raises ValueError: if the type information is not supported
    """
    if isinstance(type_info, PickledBytesTypeInfo):
        return PickleCoder()

    if isinstance(type_info, BasicTypeInfo):
        return _basic_type_info_mappings[type_info._basic_type]

    if isinstance(type_info, DateTypeInfo):
        return DateCoder()

    if isinstance(type_info, TimeTypeInfo):
        return TimeCoder()

    if isinstance(type_info, TimestampTypeInfo):
        return TimestampCoder(3)

    if isinstance(type_info, PrimitiveArrayTypeInfo):
        element = type_info._element_type
        # Primitive byte arrays are handled by the binary coder.
        is_byte_array = isinstance(element, BasicTypeInfo) and element._basic_type == BasicType.BYTE
        if is_byte_array:
            return BinaryCoder()
        return PrimitiveArrayCoder(from_type_info(element))

    if isinstance(type_info, (BasicArrayTypeInfo, ObjectArrayTypeInfo)):
        return GenericArrayCoder(from_type_info(type_info._element_type))

    if isinstance(type_info, ListTypeInfo):
        return GenericArrayCoder(from_type_info(type_info.elem_type))

    if isinstance(type_info, MapTypeInfo):
        key_coder = from_type_info(type_info._key_type_info)
        value_coder = from_type_info(type_info._value_type_info)
        return MapCoder(key_coder, value_coder)

    if isinstance(type_info, TupleTypeInfo):
        return TupleCoder([from_type_info(t) for t in type_info.get_field_types()])

    if isinstance(type_info, RowTypeInfo):
        row_field_coders = [from_type_info(t) for t in type_info.get_field_types()]
        row_field_names = list(type_info.get_field_names())
        return RowCoder(row_field_coders, row_field_names)

    if isinstance(type_info, ExternalTypeInfo):
        # Use the coder of the type information wrapped by the external type info.
        return from_type_info(type_info._type_info)

    raise ValueError("Unsupported type_info %s." % type_info)