| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| from pemja import findClass |
| |
| from pyflink.common.typeinfo import (TypeInformation, Types, BasicTypeInfo, BasicType, |
| PrimitiveArrayTypeInfo, BasicArrayTypeInfo, |
| ObjectArrayTypeInfo, MapTypeInfo) |
| from pyflink.datastream.state import (StateDescriptor, ValueStateDescriptor, |
| ReducingStateDescriptor, |
| AggregatingStateDescriptor, ListStateDescriptor, |
| MapStateDescriptor, StateTtlConfig) |
| |
| # Java type information classes, looked up via pemja's findClass at module import |
| JTypes = findClass('org.apache.flink.api.common.typeinfo.Types') |
| JPrimitiveArrayTypeInfo = findClass('org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo') |
| JBasicArrayTypeInfo = findClass('org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo') |
| JPickledByteArrayTypeInfo = findClass('org.apache.flink.streaming.api.typeinfo.python.' |
| 'PickledByteArrayTypeInfo') |
| JMapTypeInfo = findClass('org.apache.flink.api.java.typeutils.MapTypeInfo') |
| |
| # Java state descriptor classes instantiated by to_java_state_descriptor |
| JValueStateDescriptor = findClass('org.apache.flink.api.common.state.ValueStateDescriptor') |
| JListStateDescriptor = findClass('org.apache.flink.api.common.state.ListStateDescriptor') |
| JMapStateDescriptor = findClass('org.apache.flink.api.common.state.MapStateDescriptor') |
| |
| # Java StateTtlConfig, its nested enums and java.time.Duration, used by to_java_state_ttl_config |
| JStateTtlConfig = findClass('org.apache.flink.api.common.state.StateTtlConfig') |
| JDuration = findClass('java.time.Duration') |
| JUpdateType = findClass('org.apache.flink.api.common.state.StateTtlConfig$UpdateType') |
| JStateVisibility = findClass('org.apache.flink.api.common.state.StateTtlConfig$StateVisibility') |
| |
| |
# Pairs of (BasicType member name, field name on the Java Types class), checked in this order.
# Several Python basic types intentionally share one Java type (e.g. BYTE/SHORT/INT/LONG -> LONG).
_BASIC_TYPE_JAVA_FIELDS = (
    ('STRING', 'STRING'),
    ('BYTE', 'LONG'),
    ('BOOLEAN', 'BOOLEAN'),
    ('SHORT', 'LONG'),
    ('INT', 'LONG'),
    ('LONG', 'LONG'),
    ('FLOAT', 'DOUBLE'),
    ('DOUBLE', 'DOUBLE'),
    ('CHAR', 'STRING'),
    ('BIG_INT', 'BIG_INT'),
    ('BIG_DEC', 'BIG_DEC'),
    ('INSTANT', 'INSTANT'),
)

# Names of the Types factory methods accepted as primitive array element types, in check order.
# The matching Java constant is named "<NAME>_PRIMITIVE_ARRAY_TYPE_INFO".
_PRIMITIVE_ARRAY_ELEMENTS = ('BOOLEAN', 'BYTE', 'SHORT', 'INT', 'LONG', 'FLOAT', 'DOUBLE', 'CHAR')

# Basic arrays accept the same element types plus STRING; the matching Java constant is named
# "<NAME>_ARRAY_TYPE_INFO".
_BASIC_ARRAY_ELEMENTS = _PRIMITIVE_ARRAY_ELEMENTS + ('STRING',)


def _java_array_typeinfo(element_type, element_names, j_array_class, field_suffix, error_message):
    """
    Returns the constant of ``j_array_class`` that corresponds to ``element_type``.

    :param element_type: The Python element type information of the array.
    :param element_names: Names of the ``Types`` factory methods to compare against, in order.
    :param j_array_class: The Java class holding the array type information constants.
    :param field_suffix: Suffix appended to the element name to form the Java constant name.
    :param error_message: Message of the TypeError raised when no element name matches.
    :return: The matching Java array type information.
    """
    for element_name in element_names:
        if element_type == getattr(Types, element_name)():
            return getattr(j_array_class, element_name + field_suffix)
    raise TypeError(error_message)


def to_java_typeinfo(type_info: TypeInformation):
    """
    Converts a Python TypeInformation into the corresponding Java TypeInformation object.

    Basic types, primitive arrays, basic arrays, object arrays and maps are translated to
    dedicated Java type information (object arrays and maps recursively convert their nested
    types). Every other kind of type information falls back to
    ``PickledByteArrayTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO``.

    :param type_info: The Python type information to convert.
    :return: The Java type information object.
    :raises TypeError: If a basic type or an array element type is not one of the handled ones.
    """
    if isinstance(type_info, BasicTypeInfo):
        basic_type = type_info._basic_type
        for member_name, j_field_name in _BASIC_TYPE_JAVA_FIELDS:
            if basic_type == getattr(BasicType, member_name):
                return getattr(JTypes, j_field_name)
        raise TypeError("Invalid BasicType %s." % basic_type)

    if isinstance(type_info, PrimitiveArrayTypeInfo):
        return _java_array_typeinfo(
            type_info._element_type,
            _PRIMITIVE_ARRAY_ELEMENTS,
            JPrimitiveArrayTypeInfo,
            '_PRIMITIVE_ARRAY_TYPE_INFO',
            "Invalid element type for a primitive array.")

    if isinstance(type_info, BasicArrayTypeInfo):
        return _java_array_typeinfo(
            type_info._element_type,
            _BASIC_ARRAY_ELEMENTS,
            JBasicArrayTypeInfo,
            '_ARRAY_TYPE_INFO',
            "Invalid element type for a basic array.")

    if isinstance(type_info, ObjectArrayTypeInfo):
        j_element_typeinfo = to_java_typeinfo(type_info._element_type)
        return JTypes.OBJECT_ARRAY(j_element_typeinfo)

    if isinstance(type_info, MapTypeInfo):
        j_key = to_java_typeinfo(type_info._key_type_info)
        j_value = to_java_typeinfo(type_info._value_type_info)
        return JMapTypeInfo(j_key, j_value)

    # Anything not handled above is represented as pickled bytes on the Java side.
    return JPickledByteArrayTypeInfo.PICKLED_BYTE_ARRAY_TYPE_INFO
| |
| |
def to_java_state_ttl_config(ttl_config: StateTtlConfig):
    """
    Converts a Python StateTtlConfig into a Java StateTtlConfig.

    The TTL duration, update type, state visibility and cleanup strategies (background cleanup,
    full snapshot cleanup, incremental cleanup and RocksDB compaction filter cleanup) are copied
    onto a Java ``StateTtlConfig`` builder.

    :param ttl_config: The Python TTL configuration to convert.
    :return: The Java StateTtlConfig produced by the builder.
    """
    ttl_millis = ttl_config.get_ttl().to_milliseconds()
    builder = JStateTtlConfig.newBuilder(JDuration.ofMillis(ttl_millis))

    # The Python and Java enums use identical constant names, so the Java constant is looked up
    # by the name of the matching Python one. A value matching none of the names leaves the
    # builder's own setting untouched.
    py_update_type = ttl_config.get_update_type()
    for constant in ('Disabled', 'OnCreateAndWrite', 'OnReadAndWrite'):
        if py_update_type == getattr(StateTtlConfig.UpdateType, constant):
            builder.setUpdateType(getattr(JUpdateType, constant))
            break

    py_visibility = ttl_config.get_state_visibility()
    for constant in ('ReturnExpiredIfNotCleanedUp', 'NeverReturnExpired'):
        if py_visibility == getattr(StateTtlConfig.StateVisibility, constant):
            builder.setStateVisibility(getattr(JStateVisibility, constant))
            break

    strategies = ttl_config.get_cleanup_strategies()

    # Background cleanup is only touched when the Python config has it turned off.
    if not strategies.is_cleanup_in_background():
        builder.disableCleanupInBackground()

    if strategies.in_full_snapshot():
        builder.cleanupFullSnapshot()

    incremental = strategies.get_incremental_cleanup_strategy()
    if incremental:
        cleanup_size = incremental.get_cleanup_size()
        for_every_record = incremental.run_cleanup_for_every_record()
        builder.cleanupIncrementally(cleanup_size, for_every_record)

    rocksdb_filter = strategies.get_rocksdb_compact_filter_cleanup_strategy()
    if rocksdb_filter:
        builder.cleanupInRocksdbCompactFilter(rocksdb_filter.get_query_time_after_num_entries())

    return builder.build()
| |
| |
def to_java_state_descriptor(state_descriptor: StateDescriptor):
    """
    Builds the Java state descriptor that mirrors the given Python state descriptor.

    Value, reducing and aggregating descriptors are all represented by a Java
    ``ValueStateDescriptor``; list and map descriptors are represented by the Java
    ``ListStateDescriptor`` and ``MapStateDescriptor``. If the Python descriptor carries a TTL
    config, it is converted and enabled on the Java descriptor.

    :param state_descriptor: The Python state descriptor to convert.
    :return: The Java state descriptor object.
    :raises TypeError: If the descriptor is not one of the handled descriptor kinds.
    """
    if isinstance(state_descriptor,
                  (ValueStateDescriptor, ReducingStateDescriptor, AggregatingStateDescriptor)):
        value_type_info = to_java_typeinfo(state_descriptor.type_info)
        j_state_descriptor = JValueStateDescriptor(state_descriptor.name, value_type_info)

    elif isinstance(state_descriptor, ListStateDescriptor):
        # Only the element type of the list type information is handed to the Java descriptor.
        element_type_info = to_java_typeinfo(state_descriptor.type_info.elem_type)
        j_state_descriptor = JListStateDescriptor(state_descriptor.name, element_type_info)

    elif isinstance(state_descriptor, MapStateDescriptor):
        key_type_info = to_java_typeinfo(state_descriptor.type_info._key_type_info)
        value_type_info = to_java_typeinfo(state_descriptor.type_info._value_type_info)
        j_state_descriptor = JMapStateDescriptor(
            state_descriptor.name, key_type_info, value_type_info)

    else:
        # TypeError is a subclass of Exception, so callers catching Exception keep working.
        raise TypeError("Unsupported state_descriptor {0}".format(state_descriptor))

    ttl_config = state_descriptor._ttl_config
    if ttl_config:
        j_state_descriptor.enableTimeToLive(to_java_state_ttl_config(ttl_config))

    return j_state_descriptor