[FLINK-19102] [core, sdk] Make StateBinder a per-FunctionType entity
The goal of this commit is to make StateBinders know the FunctionType,
with each state binder instance specifically binding state for a single
FunctionType.
The side effect of this is also moving FunctionType out of
PersistedStateRegistry, as that information is encapsulated to the
StateBinder it receives at runtime.
This closes #137.
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
index 9c42cdf..abce51a 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
@@ -35,7 +35,6 @@
import org.apache.flink.statefun.flink.core.metrics.FlinkMetricsFactory;
import org.apache.flink.statefun.flink.core.metrics.MetricsFactory;
import org.apache.flink.statefun.flink.core.state.FlinkState;
-import org.apache.flink.statefun.flink.core.state.FlinkStateBinder;
import org.apache.flink.statefun.flink.core.state.State;
import org.apache.flink.statefun.flink.core.types.DynamicallyRegisteredTypes;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
@@ -94,7 +93,6 @@
container.add("applying-context", ApplyingContext.class, ReusableContext.class);
container.add(LocalSink.class);
container.add("function-loader", FunctionLoader.class, PredefinedFunctionLoader.class);
- container.add(FlinkStateBinder.class);
container.add(Reductions.class);
container.add(LocalFunctionGroup.class);
container.add("metrics-factory", MetricsFactory.class, new FlinkMetricsFactory(metricGroup));
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java
index 0187e57..af512d0 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java
@@ -27,11 +27,12 @@
import org.apache.flink.statefun.flink.core.metrics.MetricsFactory;
import org.apache.flink.statefun.flink.core.state.FlinkStateBinder;
import org.apache.flink.statefun.flink.core.state.PersistedStates;
+import org.apache.flink.statefun.flink.core.state.State;
import org.apache.flink.statefun.sdk.FunctionType;
final class StatefulFunctionRepository implements FunctionRepository {
private final ObjectOpenHashMap<FunctionType, StatefulFunction> instances;
- private final FlinkStateBinder stateBinder;
+ private final State flinkState;
private final FunctionLoader functionLoader;
private final MetricsFactory metricsFactory;
private final MessageFactory messageFactory;
@@ -40,12 +41,12 @@
StatefulFunctionRepository(
@Label("function-loader") FunctionLoader functionLoader,
@Label("metrics-factory") MetricsFactory metricsFactory,
- MessageFactory messageFactory,
- FlinkStateBinder stateBinder) {
+ @Label("state") State state,
+ MessageFactory messageFactory) {
this.instances = new ObjectOpenHashMap<>();
- this.stateBinder = Objects.requireNonNull(stateBinder);
this.functionLoader = Objects.requireNonNull(functionLoader);
this.metricsFactory = Objects.requireNonNull(metricsFactory);
+ this.flinkState = Objects.requireNonNull(state);
this.messageFactory = Objects.requireNonNull(messageFactory);
}
@@ -62,7 +63,8 @@
org.apache.flink.statefun.sdk.StatefulFunction statefulFunction =
functionLoader.load(functionType);
try (SetContextClassLoader ignored = new SetContextClassLoader(statefulFunction)) {
- PersistedStates.findAndBind(functionType, statefulFunction, stateBinder);
+ FlinkStateBinder stateBinderForType = new FlinkStateBinder(flinkState, functionType);
+ PersistedStates.findAndBind(statefulFunction, stateBinderForType);
FunctionTypeMetrics metrics = metricsFactory.forType(functionType);
return new StatefulFunction(statefulFunction, metrics, messageFactory);
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java
index 839e601..6a0421d 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java
@@ -19,7 +19,6 @@
import java.util.Objects;
import org.apache.flink.statefun.flink.core.di.Inject;
-import org.apache.flink.statefun.flink.core.di.Label;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.state.Accessor;
import org.apache.flink.statefun.sdk.state.ApiExtension;
@@ -30,30 +29,35 @@
import org.apache.flink.statefun.sdk.state.StateBinder;
import org.apache.flink.statefun.sdk.state.TableAccessor;
+/**
+ * A {@link StateBinder} that binds persisted state objects to Flink state for a specific {@link
+ * FunctionType}.
+ */
public final class FlinkStateBinder extends StateBinder {
private final State state;
+ private final FunctionType functionType;
@Inject
- public FlinkStateBinder(@Label("state") State state) {
+ public FlinkStateBinder(State state, FunctionType functionType) {
this.state = Objects.requireNonNull(state);
+ this.functionType = Objects.requireNonNull(functionType);
}
@Override
- public void bindValue(PersistedValue<?> persistedValue, FunctionType functionType) {
+ public void bindValue(PersistedValue<?> persistedValue) {
Accessor<?> accessor = state.createFlinkStateAccessor(functionType, persistedValue);
setAccessorRaw(persistedValue, accessor);
}
@Override
- public void bindTable(PersistedTable<?, ?> persistedTable, FunctionType functionType) {
+ public void bindTable(PersistedTable<?, ?> persistedTable) {
TableAccessor<?, ?> accessor =
state.createFlinkStateTableAccessor(functionType, persistedTable);
setAccessorRaw(persistedTable, accessor);
}
@Override
- public void bindAppendingBuffer(
- PersistedAppendingBuffer<?> persistedAppendingBuffer, FunctionType functionType) {
+ public void bindAppendingBuffer(PersistedAppendingBuffer<?> persistedAppendingBuffer) {
AppendingBufferAccessor<?> accessor =
state.createFlinkStateAppendingBufferAccessor(functionType, persistedAppendingBuffer);
setAccessorRaw(persistedAppendingBuffer, accessor);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java
index 6e57c5b..968e4b5 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java
@@ -26,7 +26,6 @@
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
-import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.annotations.Persisted;
import org.apache.flink.statefun.sdk.state.ApiExtension;
import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
@@ -36,15 +35,14 @@
public final class PersistedStates {
- public static void findAndBind(
- FunctionType functionType, @Nullable Object instance, FlinkStateBinder stateBinder) {
+ public static void findAndBind(@Nullable Object instance, FlinkStateBinder stateBinder) {
List<?> states = findReflectively(instance);
for (Object persisted : states) {
if (persisted instanceof PersistedStateRegistry) {
PersistedStateRegistry stateRegistry = (PersistedStateRegistry) persisted;
- ApiExtension.bindPersistedStateRegistry(stateRegistry, stateBinder, functionType);
+ ApiExtension.bindPersistedStateRegistry(stateRegistry, stateBinder);
} else {
- stateBinder.bind(persisted, functionType);
+ stateBinder.bind(persisted);
}
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java
index 30e9657..f107aeb 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java
@@ -17,8 +17,6 @@
*/
package org.apache.flink.statefun.sdk.state;
-import org.apache.flink.statefun.sdk.FunctionType;
-
public class ApiExtension {
public static <T> void setPersistedValueAccessor(
PersistedValue<T> persistedValue, Accessor<T> accessor) {
@@ -36,9 +34,7 @@
}
public static void bindPersistedStateRegistry(
- PersistedStateRegistry persistedStateRegistry,
- StateBinder stateBinder,
- FunctionType functionType) {
- persistedStateRegistry.bind(stateBinder, functionType);
+ PersistedStateRegistry persistedStateRegistry, StateBinder stateBinder) {
+ persistedStateRegistry.bind(stateBinder);
}
}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
index 5e5d4fb..adddcf5 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
@@ -47,52 +47,50 @@
private final FakeState state = new FakeState();
// object under test
- private final FlinkStateBinder binderUnderTest = new FlinkStateBinder(state);
+ private final FlinkStateBinder binderUnderTest =
+ new FlinkStateBinder(state, TestUtils.FUNCTION_TYPE);
@Test
public void exampleUsage() {
- PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new SanityClass(), binderUnderTest);
+ PersistedStates.findAndBind(new SanityClass(), binderUnderTest);
assertThat(state.boundNames, hasItems("name", "last"));
}
@Test(expected = IllegalStateException.class)
public void nullValueField() {
- PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new NullValueClass(), binderUnderTest);
+ PersistedStates.findAndBind(new NullValueClass(), binderUnderTest);
}
@Test
public void nonAnnotatedClass() {
- PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new IgnoreNonAnnotated(), binderUnderTest);
+ PersistedStates.findAndBind(new IgnoreNonAnnotated(), binderUnderTest);
assertTrue(state.boundNames.isEmpty());
}
@Test
public void extendedClass() {
- PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new ChildClass(), binderUnderTest);
+ PersistedStates.findAndBind(new ChildClass(), binderUnderTest);
assertThat(state.boundNames, hasItems("parent", "child"));
}
@Test(expected = IllegalArgumentException.class)
public void staticPersistedFieldsAreNotAllowed() {
- PersistedStates.findAndBind(
- TestUtils.FUNCTION_TYPE, new StaticPersistedValue(), binderUnderTest);
+ PersistedStates.findAndBind(new StaticPersistedValue(), binderUnderTest);
}
@Test
public void bindPersistedTable() {
- PersistedStates.findAndBind(
- TestUtils.FUNCTION_TYPE, new PersistedTableValue(), binderUnderTest);
+ PersistedStates.findAndBind(new PersistedTableValue(), binderUnderTest);
assertThat(state.boundNames, hasItems("table"));
}
@Test
public void bindPersistedAppendingBuffer() {
- PersistedStates.findAndBind(
- TestUtils.FUNCTION_TYPE, new PersistedAppendingBufferState(), binderUnderTest);
+ PersistedStates.findAndBind(new PersistedAppendingBufferState(), binderUnderTest);
assertThat(state.boundNames, hasItems("buffer"));
}
@@ -100,7 +98,7 @@
@Test
public void bindDynamicState() {
DynamicState dynamicState = new DynamicState();
- PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, dynamicState, binderUnderTest);
+ PersistedStates.findAndBind(dynamicState, binderUnderTest);
dynamicState.process();
@@ -117,7 +115,7 @@
@Test
public void bindComposedState() {
- PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new OuterClass(), binderUnderTest);
+ PersistedStates.findAndBind(new OuterClass(), binderUnderTest);
assertThat(state.boundNames, hasItems("inner"));
}
diff --git a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java
index 304a4df..2825f75 100644
--- a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java
+++ b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java
@@ -93,14 +93,14 @@
void initialize(State stateAccessor) {
this.registry = new HashMap<>(stateBootstrapFunctionProviders.size());
- final FlinkStateBinder stateBinder = new FlinkStateBinder(stateAccessor);
for (Map.Entry<SerializableFunctionType, StateBootstrapFunctionProvider> entry :
stateBootstrapFunctionProviders.entrySet()) {
final FunctionType functionType = entry.getKey().toNonSerializable();
final StateBootstrapFunction bootstrapFunction =
entry.getValue().bootstrapFunctionOfType(functionType);
+ final FlinkStateBinder stateBinder = new FlinkStateBinder(stateAccessor, functionType);
- registry.put(functionType, bindState(functionType, bootstrapFunction, stateBinder));
+ registry.put(functionType, bindState(bootstrapFunction, stateBinder));
}
}
@@ -115,11 +115,9 @@
}
private static StateBootstrapFunction bindState(
- FunctionType functionType,
- StateBootstrapFunction bootstrapFunction,
- FlinkStateBinder stateBinder) {
+ StateBootstrapFunction bootstrapFunction, FlinkStateBinder stateBinder) {
try (SetContextClassLoader ignored = new SetContextClassLoader(bootstrapFunction)) {
- PersistedStates.findAndBind(functionType, bootstrapFunction, stateBinder);
+ PersistedStates.findAndBind(bootstrapFunction, stateBinder);
return bootstrapFunction;
}
}
diff --git a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
index b806df1..bce8364 100644
--- a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
+++ b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
@@ -21,8 +21,6 @@
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
-import javax.annotation.Nullable;
-import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.StatefulFunction;
import org.apache.flink.statefun.sdk.annotations.ForRuntime;
import org.apache.flink.statefun.sdk.annotations.Persisted;
@@ -44,12 +42,6 @@
private StateBinder stateBinder;
- /**
- * The type of the function that this registry is bound to. This is {@code NULL} if this registry
- * is not bounded.
- */
- @Nullable private FunctionType functionType;
-
public PersistedStateRegistry() {
this.stateBinder = new NonFaultTolerantStateBinder();
}
@@ -96,23 +88,24 @@
* will also be bound to the system.
*
* @param stateBinder the new fault-tolerant state binder to use.
- * @param functionType the type of the function that this registry is being bound to.
* @throws IllegalStateException if the registry was attempted to be bound more than once.
*/
@ForRuntime
- void bind(StateBinder stateBinder, FunctionType functionType) {
- if (this.functionType != null) {
+ void bind(StateBinder stateBinder) {
+ if (isBound()) {
throw new IllegalStateException(
- "This registry was already bound to function type: "
- + this.functionType
- + ", attempting to rebind to function type: "
- + functionType);
+ "This registry was already bound to state binder: "
+ + this.stateBinder.getClass().getName()
+ + ", attempting to rebind to state binder: "
+ + stateBinder.getClass().getName());
}
this.stateBinder = Objects.requireNonNull(stateBinder);
- this.functionType = Objects.requireNonNull(functionType);
+ registeredStates.values().forEach(stateBinder::bind);
+ }
- registeredStates.values().forEach(state -> stateBinder.bind(state, functionType));
+ private boolean isBound() {
+ return stateBinder != null && !(stateBinder instanceof NonFaultTolerantStateBinder);
}
private void acceptRegistrationOrThrowIfPresent(String stateName, Object newStateObject) {
@@ -125,18 +118,17 @@
}
registeredStates.put(stateName, newStateObject);
- stateBinder.bind(newStateObject, functionType);
+ stateBinder.bind(newStateObject);
}
private static final class NonFaultTolerantStateBinder extends StateBinder {
@Override
- public void bindValue(PersistedValue<?> persistedValue, FunctionType functionType) {}
+ public void bindValue(PersistedValue<?> persistedValue) {}
@Override
- public void bindTable(PersistedTable<?, ?> persistedTable, FunctionType functionType) {}
+ public void bindTable(PersistedTable<?, ?> persistedTable) {}
@Override
- public void bindAppendingBuffer(
- PersistedAppendingBuffer<?> persistedAppendingBuffer, FunctionType functionType) {}
+ public void bindAppendingBuffer(PersistedAppendingBuffer<?> persistedAppendingBuffer) {}
}
}
diff --git a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java
index 1cfcb3e..01560cc 100644
--- a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java
+++ b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java
@@ -18,23 +18,20 @@
package org.apache.flink.statefun.sdk.state;
-import org.apache.flink.statefun.sdk.FunctionType;
-
public abstract class StateBinder {
- public abstract void bindValue(PersistedValue<?> persistedValue, FunctionType functionType);
+ public abstract void bindValue(PersistedValue<?> persistedValue);
- public abstract void bindTable(PersistedTable<?, ?> persistedTable, FunctionType functionType);
+ public abstract void bindTable(PersistedTable<?, ?> persistedTable);
- public abstract void bindAppendingBuffer(
- PersistedAppendingBuffer<?> persistedAppendingBuffer, FunctionType functionType);
+ public abstract void bindAppendingBuffer(PersistedAppendingBuffer<?> persistedAppendingBuffer);
- public final void bind(Object stateObject, FunctionType functionType) {
+ public final void bind(Object stateObject) {
if (stateObject instanceof PersistedValue) {
- bindValue((PersistedValue<?>) stateObject, functionType);
+ bindValue((PersistedValue<?>) stateObject);
} else if (stateObject instanceof PersistedTable) {
- bindTable((PersistedTable<?, ?>) stateObject, functionType);
+ bindTable((PersistedTable<?, ?>) stateObject);
} else if (stateObject instanceof PersistedAppendingBuffer) {
- bindAppendingBuffer((PersistedAppendingBuffer<?>) stateObject, functionType);
+ bindAppendingBuffer((PersistedAppendingBuffer<?>) stateObject);
} else {
throw new IllegalArgumentException("Unknown persisted state object " + stateObject);
}