[FLINK-19102] [core, sdk] Make StateBinder a per-FunctionType entity

The goal of this commit is to make StateBinders know the FunctionType,
with each state binder instance specifically binding state for a single
FunctionType.

The side effect of this is also moving FunctionType out of
PersistedStateRegistry, as that information is encapsulated to the
StateBinder it receives at runtime.

This closes #137.
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
index 9c42cdf..abce51a 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
@@ -35,7 +35,6 @@
 import org.apache.flink.statefun.flink.core.metrics.FlinkMetricsFactory;
 import org.apache.flink.statefun.flink.core.metrics.MetricsFactory;
 import org.apache.flink.statefun.flink.core.state.FlinkState;
-import org.apache.flink.statefun.flink.core.state.FlinkStateBinder;
 import org.apache.flink.statefun.flink.core.state.State;
 import org.apache.flink.statefun.flink.core.types.DynamicallyRegisteredTypes;
 import org.apache.flink.statefun.sdk.io.EgressIdentifier;
@@ -94,7 +93,6 @@
     container.add("applying-context", ApplyingContext.class, ReusableContext.class);
     container.add(LocalSink.class);
     container.add("function-loader", FunctionLoader.class, PredefinedFunctionLoader.class);
-    container.add(FlinkStateBinder.class);
     container.add(Reductions.class);
     container.add(LocalFunctionGroup.class);
     container.add("metrics-factory", MetricsFactory.class, new FlinkMetricsFactory(metricGroup));
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java
index 0187e57..af512d0 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/StatefulFunctionRepository.java
@@ -27,11 +27,12 @@
 import org.apache.flink.statefun.flink.core.metrics.MetricsFactory;
 import org.apache.flink.statefun.flink.core.state.FlinkStateBinder;
 import org.apache.flink.statefun.flink.core.state.PersistedStates;
+import org.apache.flink.statefun.flink.core.state.State;
 import org.apache.flink.statefun.sdk.FunctionType;
 
 final class StatefulFunctionRepository implements FunctionRepository {
   private final ObjectOpenHashMap<FunctionType, StatefulFunction> instances;
-  private final FlinkStateBinder stateBinder;
+  private final State flinkState;
   private final FunctionLoader functionLoader;
   private final MetricsFactory metricsFactory;
   private final MessageFactory messageFactory;
@@ -40,12 +41,12 @@
   StatefulFunctionRepository(
       @Label("function-loader") FunctionLoader functionLoader,
       @Label("metrics-factory") MetricsFactory metricsFactory,
-      MessageFactory messageFactory,
-      FlinkStateBinder stateBinder) {
+      @Label("state") State state,
+      MessageFactory messageFactory) {
     this.instances = new ObjectOpenHashMap<>();
-    this.stateBinder = Objects.requireNonNull(stateBinder);
     this.functionLoader = Objects.requireNonNull(functionLoader);
     this.metricsFactory = Objects.requireNonNull(metricsFactory);
+    this.flinkState = Objects.requireNonNull(state);
     this.messageFactory = Objects.requireNonNull(messageFactory);
   }
 
@@ -62,7 +63,8 @@
     org.apache.flink.statefun.sdk.StatefulFunction statefulFunction =
         functionLoader.load(functionType);
     try (SetContextClassLoader ignored = new SetContextClassLoader(statefulFunction)) {
-      PersistedStates.findAndBind(functionType, statefulFunction, stateBinder);
+      FlinkStateBinder stateBinderForType = new FlinkStateBinder(flinkState, functionType);
+      PersistedStates.findAndBind(statefulFunction, stateBinderForType);
       FunctionTypeMetrics metrics = metricsFactory.forType(functionType);
       return new StatefulFunction(statefulFunction, metrics, messageFactory);
     }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java
index 839e601..6a0421d 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/FlinkStateBinder.java
@@ -19,7 +19,6 @@
 
 import java.util.Objects;
 import org.apache.flink.statefun.flink.core.di.Inject;
-import org.apache.flink.statefun.flink.core.di.Label;
 import org.apache.flink.statefun.sdk.FunctionType;
 import org.apache.flink.statefun.sdk.state.Accessor;
 import org.apache.flink.statefun.sdk.state.ApiExtension;
@@ -30,30 +29,35 @@
 import org.apache.flink.statefun.sdk.state.StateBinder;
 import org.apache.flink.statefun.sdk.state.TableAccessor;
 
+/**
+ * A {@link StateBinder} that binds persisted state objects to Flink state for a specific {@link
+ * FunctionType}.
+ */
 public final class FlinkStateBinder extends StateBinder {
   private final State state;
+  private final FunctionType functionType;
 
   @Inject
-  public FlinkStateBinder(@Label("state") State state) {
+  public FlinkStateBinder(State state, FunctionType functionType) {
     this.state = Objects.requireNonNull(state);
+    this.functionType = Objects.requireNonNull(functionType);
   }
 
   @Override
-  public void bindValue(PersistedValue<?> persistedValue, FunctionType functionType) {
+  public void bindValue(PersistedValue<?> persistedValue) {
     Accessor<?> accessor = state.createFlinkStateAccessor(functionType, persistedValue);
     setAccessorRaw(persistedValue, accessor);
   }
 
   @Override
-  public void bindTable(PersistedTable<?, ?> persistedTable, FunctionType functionType) {
+  public void bindTable(PersistedTable<?, ?> persistedTable) {
     TableAccessor<?, ?> accessor =
         state.createFlinkStateTableAccessor(functionType, persistedTable);
     setAccessorRaw(persistedTable, accessor);
   }
 
   @Override
-  public void bindAppendingBuffer(
-      PersistedAppendingBuffer<?> persistedAppendingBuffer, FunctionType functionType) {
+  public void bindAppendingBuffer(PersistedAppendingBuffer<?> persistedAppendingBuffer) {
     AppendingBufferAccessor<?> accessor =
         state.createFlinkStateAppendingBufferAccessor(functionType, persistedAppendingBuffer);
     setAccessorRaw(persistedAppendingBuffer, accessor);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java
index 6e57c5b..968e4b5 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/state/PersistedStates.java
@@ -26,7 +26,6 @@
 import java.util.stream.Stream;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
-import org.apache.flink.statefun.sdk.FunctionType;
 import org.apache.flink.statefun.sdk.annotations.Persisted;
 import org.apache.flink.statefun.sdk.state.ApiExtension;
 import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
@@ -36,15 +35,14 @@
 
 public final class PersistedStates {
 
-  public static void findAndBind(
-      FunctionType functionType, @Nullable Object instance, FlinkStateBinder stateBinder) {
+  public static void findAndBind(@Nullable Object instance, FlinkStateBinder stateBinder) {
     List<?> states = findReflectively(instance);
     for (Object persisted : states) {
       if (persisted instanceof PersistedStateRegistry) {
         PersistedStateRegistry stateRegistry = (PersistedStateRegistry) persisted;
-        ApiExtension.bindPersistedStateRegistry(stateRegistry, stateBinder, functionType);
+        ApiExtension.bindPersistedStateRegistry(stateRegistry, stateBinder);
       } else {
-        stateBinder.bind(persisted, functionType);
+        stateBinder.bind(persisted);
       }
     }
   }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java
index 30e9657..f107aeb 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/sdk/state/ApiExtension.java
@@ -17,8 +17,6 @@
  */
 package org.apache.flink.statefun.sdk.state;
 
-import org.apache.flink.statefun.sdk.FunctionType;
-
 public class ApiExtension {
   public static <T> void setPersistedValueAccessor(
       PersistedValue<T> persistedValue, Accessor<T> accessor) {
@@ -36,9 +34,7 @@
   }
 
   public static void bindPersistedStateRegistry(
-      PersistedStateRegistry persistedStateRegistry,
-      StateBinder stateBinder,
-      FunctionType functionType) {
-    persistedStateRegistry.bind(stateBinder, functionType);
+      PersistedStateRegistry persistedStateRegistry, StateBinder stateBinder) {
+    persistedStateRegistry.bind(stateBinder);
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
index 5e5d4fb..adddcf5 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
@@ -47,52 +47,50 @@
   private final FakeState state = new FakeState();
 
   // object under test
-  private final FlinkStateBinder binderUnderTest = new FlinkStateBinder(state);
+  private final FlinkStateBinder binderUnderTest =
+      new FlinkStateBinder(state, TestUtils.FUNCTION_TYPE);
 
   @Test
   public void exampleUsage() {
-    PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new SanityClass(), binderUnderTest);
+    PersistedStates.findAndBind(new SanityClass(), binderUnderTest);
 
     assertThat(state.boundNames, hasItems("name", "last"));
   }
 
   @Test(expected = IllegalStateException.class)
   public void nullValueField() {
-    PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new NullValueClass(), binderUnderTest);
+    PersistedStates.findAndBind(new NullValueClass(), binderUnderTest);
   }
 
   @Test
   public void nonAnnotatedClass() {
-    PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new IgnoreNonAnnotated(), binderUnderTest);
+    PersistedStates.findAndBind(new IgnoreNonAnnotated(), binderUnderTest);
 
     assertTrue(state.boundNames.isEmpty());
   }
 
   @Test
   public void extendedClass() {
-    PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new ChildClass(), binderUnderTest);
+    PersistedStates.findAndBind(new ChildClass(), binderUnderTest);
 
     assertThat(state.boundNames, hasItems("parent", "child"));
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void staticPersistedFieldsAreNotAllowed() {
-    PersistedStates.findAndBind(
-        TestUtils.FUNCTION_TYPE, new StaticPersistedValue(), binderUnderTest);
+    PersistedStates.findAndBind(new StaticPersistedValue(), binderUnderTest);
   }
 
   @Test
   public void bindPersistedTable() {
-    PersistedStates.findAndBind(
-        TestUtils.FUNCTION_TYPE, new PersistedTableValue(), binderUnderTest);
+    PersistedStates.findAndBind(new PersistedTableValue(), binderUnderTest);
 
     assertThat(state.boundNames, hasItems("table"));
   }
 
   @Test
   public void bindPersistedAppendingBuffer() {
-    PersistedStates.findAndBind(
-        TestUtils.FUNCTION_TYPE, new PersistedAppendingBufferState(), binderUnderTest);
+    PersistedStates.findAndBind(new PersistedAppendingBufferState(), binderUnderTest);
 
     assertThat(state.boundNames, hasItems("buffer"));
   }
@@ -100,7 +98,7 @@
   @Test
   public void bindDynamicState() {
     DynamicState dynamicState = new DynamicState();
-    PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, dynamicState, binderUnderTest);
+    PersistedStates.findAndBind(dynamicState, binderUnderTest);
 
     dynamicState.process();
 
@@ -117,7 +115,7 @@
 
   @Test
   public void bindComposedState() {
-    PersistedStates.findAndBind(TestUtils.FUNCTION_TYPE, new OuterClass(), binderUnderTest);
+    PersistedStates.findAndBind(new OuterClass(), binderUnderTest);
 
     assertThat(state.boundNames, hasItems("inner"));
   }
diff --git a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java
index 304a4df..2825f75 100644
--- a/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java
+++ b/statefun-flink/statefun-flink-state-processor/src/main/java/org/apache/flink/statefun/flink/state/processor/operator/StateBootstrapFunctionRegistry.java
@@ -93,14 +93,14 @@
   void initialize(State stateAccessor) {
     this.registry = new HashMap<>(stateBootstrapFunctionProviders.size());
 
-    final FlinkStateBinder stateBinder = new FlinkStateBinder(stateAccessor);
     for (Map.Entry<SerializableFunctionType, StateBootstrapFunctionProvider> entry :
         stateBootstrapFunctionProviders.entrySet()) {
       final FunctionType functionType = entry.getKey().toNonSerializable();
       final StateBootstrapFunction bootstrapFunction =
           entry.getValue().bootstrapFunctionOfType(functionType);
+      final FlinkStateBinder stateBinder = new FlinkStateBinder(stateAccessor, functionType);
 
-      registry.put(functionType, bindState(functionType, bootstrapFunction, stateBinder));
+      registry.put(functionType, bindState(bootstrapFunction, stateBinder));
     }
   }
 
@@ -115,11 +115,9 @@
   }
 
   private static StateBootstrapFunction bindState(
-      FunctionType functionType,
-      StateBootstrapFunction bootstrapFunction,
-      FlinkStateBinder stateBinder) {
+      StateBootstrapFunction bootstrapFunction, FlinkStateBinder stateBinder) {
     try (SetContextClassLoader ignored = new SetContextClassLoader(bootstrapFunction)) {
-      PersistedStates.findAndBind(functionType, bootstrapFunction, stateBinder);
+      PersistedStates.findAndBind(bootstrapFunction, stateBinder);
       return bootstrapFunction;
     }
   }
diff --git a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
index b806df1..bce8364 100644
--- a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
+++ b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
@@ -21,8 +21,6 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
-import javax.annotation.Nullable;
-import org.apache.flink.statefun.sdk.FunctionType;
 import org.apache.flink.statefun.sdk.StatefulFunction;
 import org.apache.flink.statefun.sdk.annotations.ForRuntime;
 import org.apache.flink.statefun.sdk.annotations.Persisted;
@@ -44,12 +42,6 @@
 
   private StateBinder stateBinder;
 
-  /**
-   * The type of the function that this registry is bound to. This is {@code NULL} if this registry
-   * is not bounded.
-   */
-  @Nullable private FunctionType functionType;
-
   public PersistedStateRegistry() {
     this.stateBinder = new NonFaultTolerantStateBinder();
   }
@@ -96,23 +88,24 @@
    * will also be bound to the system.
    *
    * @param stateBinder the new fault-tolerant state binder to use.
-   * @param functionType the type of the function that this registry is being bound to.
    * @throws IllegalStateException if the registry was attempted to be bound more than once.
    */
   @ForRuntime
-  void bind(StateBinder stateBinder, FunctionType functionType) {
-    if (this.functionType != null) {
+  void bind(StateBinder stateBinder) {
+    if (isBound()) {
       throw new IllegalStateException(
-          "This registry was already bound to function type: "
-              + this.functionType
-              + ", attempting to rebind to function type: "
-              + functionType);
+          "This registry was already bound to state binder: "
+              + this.stateBinder.getClass().getName()
+              + ", attempting to rebind to state binder: "
+              + stateBinder.getClass().getName());
     }
 
     this.stateBinder = Objects.requireNonNull(stateBinder);
-    this.functionType = Objects.requireNonNull(functionType);
+    registeredStates.values().forEach(stateBinder::bind);
+  }
 
-    registeredStates.values().forEach(state -> stateBinder.bind(state, functionType));
+  private boolean isBound() {
+    return stateBinder != null && !(stateBinder instanceof NonFaultTolerantStateBinder);
   }
 
   private void acceptRegistrationOrThrowIfPresent(String stateName, Object newStateObject) {
@@ -125,18 +118,17 @@
     }
 
     registeredStates.put(stateName, newStateObject);
-    stateBinder.bind(newStateObject, functionType);
+    stateBinder.bind(newStateObject);
   }
 
   private static final class NonFaultTolerantStateBinder extends StateBinder {
     @Override
-    public void bindValue(PersistedValue<?> persistedValue, FunctionType functionType) {}
+    public void bindValue(PersistedValue<?> persistedValue) {}
 
     @Override
-    public void bindTable(PersistedTable<?, ?> persistedTable, FunctionType functionType) {}
+    public void bindTable(PersistedTable<?, ?> persistedTable) {}
 
     @Override
-    public void bindAppendingBuffer(
-        PersistedAppendingBuffer<?> persistedAppendingBuffer, FunctionType functionType) {}
+    public void bindAppendingBuffer(PersistedAppendingBuffer<?> persistedAppendingBuffer) {}
   }
 }
diff --git a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java
index 1cfcb3e..01560cc 100644
--- a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java
+++ b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/StateBinder.java
@@ -18,23 +18,20 @@
 
 package org.apache.flink.statefun.sdk.state;
 
-import org.apache.flink.statefun.sdk.FunctionType;
-
 public abstract class StateBinder {
-  public abstract void bindValue(PersistedValue<?> persistedValue, FunctionType functionType);
+  public abstract void bindValue(PersistedValue<?> persistedValue);
 
-  public abstract void bindTable(PersistedTable<?, ?> persistedTable, FunctionType functionType);
+  public abstract void bindTable(PersistedTable<?, ?> persistedTable);
 
-  public abstract void bindAppendingBuffer(
-      PersistedAppendingBuffer<?> persistedAppendingBuffer, FunctionType functionType);
+  public abstract void bindAppendingBuffer(PersistedAppendingBuffer<?> persistedAppendingBuffer);
 
-  public final void bind(Object stateObject, FunctionType functionType) {
+  public final void bind(Object stateObject) {
     if (stateObject instanceof PersistedValue) {
-      bindValue((PersistedValue<?>) stateObject, functionType);
+      bindValue((PersistedValue<?>) stateObject);
     } else if (stateObject instanceof PersistedTable) {
-      bindTable((PersistedTable<?, ?>) stateObject, functionType);
+      bindTable((PersistedTable<?, ?>) stateObject);
     } else if (stateObject instanceof PersistedAppendingBuffer) {
-      bindAppendingBuffer((PersistedAppendingBuffer<?>) stateObject, functionType);
+      bindAppendingBuffer((PersistedAppendingBuffer<?>) stateObject);
     } else {
       throw new IllegalArgumentException("Unknown persisted state object " + stateObject);
     }