blob: 438aac83dfc3bed51ac110f58794c8963ba4bc2f [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Union
from apache_beam.coders import PickleCoder
from pyflink.datastream.state import ListState, MapState
from pyflink.fn_execution.coders import from_proto
from pyflink.fn_execution.internal_state import InternalListState, InternalMapState
from pyflink.fn_execution.operation_utils import is_built_in_function, load_aggregate_function
from pyflink.fn_execution.state_impl import RemoteKeyedStateBackend
from pyflink.table import FunctionContext
from pyflink.table.data_view import ListView, MapView, DataView
def extract_data_view_specs_from_accumulator(current_index, accumulator):
# for built in functions we extract the data view specs from their accumulator
i = -1
extracted_specs = []
for field in accumulator:
i += 1
# TODO: infer the coder from the input types and output type of the built-in functions
if isinstance(field, MapView):
extracted_specs.append(MapViewSpec(
"builtInAgg%df%d" % (current_index, i), i, PickleCoder(), PickleCoder()))
elif isinstance(field, ListView):
extracted_specs.append(ListViewSpec(
"builtInAgg%df%d" % (current_index, i), i, PickleCoder()))
return extracted_specs
def extract_data_view_specs(udfs):
extracted_udf_data_view_specs = []
current_index = -1
for udf in udfs:
current_index += 1
udf_data_view_specs_proto = udf.specs
if not udf_data_view_specs_proto:
if is_built_in_function(udf.payload):
built_in_function = load_aggregate_function(udf.payload)
accumulator = built_in_function.create_accumulator()
extracted_udf_data_view_specs.append(
extract_data_view_specs_from_accumulator(current_index, accumulator))
else:
extracted_udf_data_view_specs.append([])
else:
extracted_specs = []
for spec_proto in udf_data_view_specs_proto:
state_id = spec_proto.name
field_index = spec_proto.field_index
if spec_proto.HasField("list_view"):
element_coder = from_proto(spec_proto.list_view.element_type)
extracted_specs.append(ListViewSpec(state_id, field_index, element_coder))
elif spec_proto.HasField("map_view"):
key_coder = from_proto(spec_proto.map_view.key_type)
value_coder = from_proto(spec_proto.map_view.value_type)
extracted_specs.append(
MapViewSpec(state_id, field_index, key_coder, value_coder))
else:
raise Exception("Unsupported data view spec type: " + spec_proto.type)
extracted_udf_data_view_specs.append(extracted_specs)
if all([len(i) == 0 for i in extracted_udf_data_view_specs]):
return []
return extracted_udf_data_view_specs
N = TypeVar('N')
class StateDataView(DataView, Generic[N]):
@abstractmethod
def set_current_namespace(self, namespace: N):
"""
Sets current namespace for state.
"""
pass
class StateListView(ListView, StateDataView[N], ABC):
def __init__(self, list_state: Union[ListState, InternalListState]):
super().__init__()
self._list_state = list_state
def get(self):
return self._list_state.get()
def add(self, value):
self._list_state.add(value)
def add_all(self, values):
self._list_state.add_all(values)
def clear(self):
self._list_state.clear()
def __hash__(self) -> int:
return hash([i for i in self.get()])
class KeyedStateListView(StateListView[N]):
"""
KeyedStateListView is a default implementation of StateListView whose underlying
representation is a keyed state.
"""
def __init__(self, list_state: ListState):
super(KeyedStateListView, self).__init__(list_state)
def set_current_namespace(self, namespace: N):
raise Exception("KeyedStateListView doesn't support set_current_namespace")
class NamespacedStateListView(StateListView[N]):
"""
NamespacedStateListView is a StateListView whose underlying representation is a keyed and
namespaced state. It also supports changing current namespace.
"""
def __init__(self, list_state: InternalListState):
super(NamespacedStateListView, self).__init__(list_state)
def set_current_namespace(self, namespace: N):
self._list_state.set_current_namespace(namespace)
class StateMapView(MapView, StateDataView[N], ABC):
def __init__(self, map_state: Union[MapState, InternalMapState]):
super().__init__()
self._map_state = map_state
def get(self, key):
return self._map_state.get(key)
def put(self, key, value) -> None:
self._map_state.put(key, value)
def put_all(self, dict_value) -> None:
self._map_state.put_all(dict_value)
def remove(self, key) -> None:
self._map_state.remove(key)
def contains(self, key) -> bool:
return self._map_state.contains(key)
def items(self):
return self._map_state.items()
def keys(self):
return self._map_state.keys()
def values(self):
return self._map_state.values()
def is_empty(self) -> bool:
return self._map_state.is_empty()
def clear(self) -> None:
return self._map_state.clear()
class KeyedStateMapView(StateMapView[N]):
"""
KeyedStateMapView is a default implementation of StateMapView whose underlying
representation is a keyed state.
"""
def __init__(self, map_state: MapState):
super(KeyedStateMapView, self).__init__(map_state)
def set_current_namespace(self, namespace: N):
raise Exception("KeyedStateMapView doesn't support set_current_namespace")
class NamespacedStateMapView(StateMapView[N]):
"""
NamespacedStateMapView is a StateMapView whose underlying representation is a keyed and
namespaced state. It also supports changing current namespace.
"""
def __init__(self, map_state: InternalMapState):
super(NamespacedStateMapView, self).__init__(map_state)
def set_current_namespace(self, namespace: N):
self._map_state.set_current_namespace(namespace)
class DataViewSpec(object):
def __init__(self, state_id, field_index):
self.state_id = state_id
self.field_index = field_index
class ListViewSpec(DataViewSpec):
def __init__(self, state_id, field_index, element_coder):
super(ListViewSpec, self).__init__(state_id, field_index)
self.element_coder = element_coder
class MapViewSpec(DataViewSpec):
def __init__(self, state_id, field_index, key_coder, value_coder):
super(MapViewSpec, self).__init__(state_id, field_index)
self.key_coder = key_coder
self.value_coder = value_coder
class StateDataViewStore(ABC):
"""
This interface contains methods for registering StateDataView with a managed store.
"""
def __init__(self,
function_context: FunctionContext,
keyed_state_backend: RemoteKeyedStateBackend):
self._function_context = function_context
self._keyed_state_backend = keyed_state_backend
def get_runtime_context(self):
return self._function_context
@abstractmethod
def get_state_list_view(self, state_name, element_coder):
"""
Creates a state list view.
:param state_name: The name of underlying state of the list view.
:param element_coder: The element coder
:return: a keyed list state
"""
pass
@abstractmethod
def get_state_map_view(self, state_name, key_coder, value_coder):
"""
Creates a state map view.
:param state_name: The name of underlying state of the map view.
:param key_coder: The key coder
:param value_coder: The value coder
:return: a keyed map state
"""
pass
class PerKeyStateDataViewStore(StateDataViewStore):
"""
Default implementation of StateDataViewStore.
"""
def __init__(self,
function_context: FunctionContext,
keyed_state_backend: RemoteKeyedStateBackend):
super(PerKeyStateDataViewStore, self).__init__(function_context, keyed_state_backend)
def get_state_list_view(self, state_name, element_coder):
return KeyedStateListView(
self._keyed_state_backend.get_list_state(state_name, element_coder))
def get_state_map_view(self, state_name, key_coder, value_coder):
return KeyedStateMapView(
self._keyed_state_backend.get_map_state(state_name, key_coder, value_coder))
class PerWindowStateDataViewStore(StateDataViewStore):
"""
An implementation of StateDataViewStore for window aggregates which forwards the state
registration to an underlying RemoteKeyedStateBackend. The state created by this store has the
ability to switch window namespaces.
"""
def __init__(self,
function_context: FunctionContext,
keyed_state_backend: RemoteKeyedStateBackend):
super(PerWindowStateDataViewStore, self).__init__(function_context, keyed_state_backend)
def get_state_list_view(self, state_name, element_coder):
return NamespacedStateListView(
self._keyed_state_backend.get_list_state(state_name, element_coder))
def get_state_map_view(self, state_name, key_coder, value_coder):
return NamespacedStateMapView(
self._keyed_state_backend.get_map_state(state_name, key_coder, value_coder))