[FLINK-19096] [sdk] Rework PersistedStateRegistry registration methods

This reworks the `registerValue` / `registerTable` etc. methods on
PersistedStateRegistry to directly accept a constructed PersistedValue /
PersistedTable instance. This allows us to avoid having to synchronize
those methods and the state class constructors.
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
index e918c41..4733366 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
@@ -37,7 +37,7 @@
   public PersistedRemoteFunctionValues(List<StateSpec> stateSpecs) {
     Objects.requireNonNull(stateSpecs);
     this.managedStates = new HashMap<>(stateSpecs.size());
-    stateSpecs.forEach(spec -> managedStates.put(spec.name(), createStateHandle(spec)));
+    stateSpecs.forEach(this::createAndRegisterValueState);
   }
 
   void forEach(BiConsumer<String, byte[]> consumer) {
@@ -52,8 +52,13 @@
     getStateHandleOrThrow(stateName).clear();
   }
 
-  private PersistedValue<byte[]> createStateHandle(StateSpec stateSpec) {
-    return stateRegistry.registerValue(stateSpec.name(), byte[].class, stateSpec.ttlExpiration());
+  private void createAndRegisterValueState(StateSpec stateSpec) {
+    final String stateName = stateSpec.name();
+
+    final PersistedValue<byte[]> stateValue =
+        PersistedValue.of(stateName, byte[].class, stateSpec.ttlExpiration());
+    stateRegistry.registerValue(stateValue);
+    managedStates.put(stateName, stateValue);
   }
 
   private PersistedValue<byte[]> getStateHandleOrThrow(String stateName) {
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
index 3e1c622..5e5d4fb 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/state/PersistedStatesTest.java
@@ -180,15 +180,19 @@
     @Persisted PersistedStateRegistry provider = new PersistedStateRegistry();
 
     DynamicState() {
-      provider.registerValue("in-constructor-value", String.class);
-      provider.registerTable("in-constructor-table", String.class, Integer.class);
-      provider.registerAppendingBuffer("in-constructor-buffer", String.class);
+      provider.registerValue(PersistedValue.of("in-constructor-value", String.class));
+      provider.registerTable(
+          PersistedTable.of("in-constructor-table", String.class, Integer.class));
+      provider.registerAppendingBuffer(
+          PersistedAppendingBuffer.of("in-constructor-buffer", String.class));
     }
 
     void process() {
-      provider.registerValue("post-constructor-value", String.class);
-      provider.registerTable("post-constructor-table", String.class, Integer.class);
-      provider.registerAppendingBuffer("post-constructor-buffer", String.class);
+      provider.registerValue(PersistedValue.of("post-constructor-value", String.class));
+      provider.registerTable(
+          PersistedTable.of("post-constructor-table", String.class, Integer.class));
+      provider.registerAppendingBuffer(
+          PersistedAppendingBuffer.of("post-constructor-buffer", String.class));
     }
   }
 
diff --git a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
index 391bffe..f606675 100644
--- a/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
+++ b/statefun-sdk/src/main/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistry.java
@@ -21,7 +21,6 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
-import java.util.function.Function;
 import javax.annotation.Nullable;
 import org.apache.flink.statefun.sdk.FunctionType;
 import org.apache.flink.statefun.sdk.StatefulFunction;
@@ -56,121 +55,40 @@
   }
 
   /**
-   * Registers a {@link PersistedValue}, given a state name and the type of the values. If a
-   * registered value already exists for the given name, the previous persisted value is returned.
+   * Registers a {@link PersistedValue}. If a registered state already exists for the specified name
+   * of the value, the registration fails.
    *
-   * @param name the state name to register with.
-   * @param type the type of the value.
+   * @param valueState the value state to register.
    * @param <T> the type of the value.
-   * @return the registered value, or the previous registered value if a registration for the state
-   *     name already exists.
-   * @throws IllegalStateException if a previous registration exists for the given state name, but
-   *     it wasn't registered as a {@link PersistedValue}.
+   * @throws IllegalStateException if a previous registration exists for the given state name.
    */
-  public <T> PersistedValue<T> registerValue(String name, Class<T> type) {
-    return registerValue(name, type, Expiration.none());
+  public <T> void registerValue(PersistedValue<T> valueState) {
+    acceptRegistrationOrThrowIfPresent(valueState.name(), valueState);
   }
 
   /**
-   * Registers a {@link PersistedValue}, given a state name and the type of the values. If a
-   * registered value already exists for the given name, the previous persisted value is returned.
+   * Registers a {@link PersistedTable}. If a registered state already exists for the specified name
+   * of the table, the registration fails.
    *
-   * @param name the state name to register with.
-   * @param type the type of the value.
-   * @param expiration expiration configuration for the registered state.
-   * @param <T> the type of the value.
-   * @return the registered value, or the previous registered value if a registration for the state
-   *     name already exists.
-   * @throws IllegalStateException if a previous registration exists for the given state name, but
-   *     it wasn't registered as a {@link PersistedValue}.
-   */
-  public <T> PersistedValue<T> registerValue(String name, Class<T> type, Expiration expiration) {
-    return getStateOrCreateIfAbsent(
-        PersistedValue.class, name, stateName -> createValue(stateName, type, expiration));
-  }
-
-  /**
-   * Registers a {@link PersistedTable}, given a state name and the type of the keys and values of
-   * the table. If a registered value already exists for the given name, the previous persisted
-   * table is returned.
-   *
-   * @param name the state name to register with.
-   * @param keyType the type of the keys.
-   * @param valueType the type of the values.
+   * @param tableState the table state to register.
    * @param <K> the type of the keys.
    * @param <V> the type of the values.
-   * @return the registered table, or the previous registered table if a registration for the state
-   *     name already exists.
-   * @throws IllegalStateException if a previous registration exists for the given state name, but
-   *     it wasn't registered as a {@link PersistedTable}.
+   * @throws IllegalStateException if a previous registration exists for the given state name.
    */
-  public <K, V> PersistedTable<K, V> registerTable(
-      String name, Class<K> keyType, Class<V> valueType) {
-    return registerTable(name, keyType, valueType, Expiration.none());
+  public <K, V> void registerTable(PersistedTable<K, V> tableState) {
+    acceptRegistrationOrThrowIfPresent(tableState.name(), tableState);
   }
 
   /**
-   * Registers a {@link PersistedTable}, given a state name and the type of the keys and values of
-   * the table. If a registered value already exists for the given name, the previous persisted
-   * table is returned.
+   * Registers a {@link PersistedAppendingBuffer}. If a registered state already exists for the
+   * specified name of the table, the registration fails.
    *
-   * @param name the state name to register with.
-   * @param keyType the type of the keys.
-   * @param valueType the type of the values.
-   * @param expiration expiration configuration for the registered state.
-   * @param <K> the type of the keys.
-   * @param <V> the type of the values.
-   * @return the registered table, or the previous registered table if a registration for the state
-   *     name already exists.
-   * @throws IllegalStateException if a previous registration exists for the given state name, but
-   *     it wasn't registered as a {@link PersistedTable}.
-   */
-  public <K, V> PersistedTable<K, V> registerTable(
-      String name, Class<K> keyType, Class<V> valueType, Expiration expiration) {
-    return getStateOrCreateIfAbsent(
-        PersistedTable.class,
-        name,
-        stateName -> createTable(stateName, keyType, valueType, expiration));
-  }
-
-  /**
-   * Registers a {@link PersistedAppendingBuffer}, given a state name and the type of the buffer
-   * elements. If a registered buffer already exists for the given name, the previous persisted
-   * buffer is returned.
-   *
-   * @param name the state name to register with.
-   * @param elementType the type of the buffer elements.
+   * @param bufferState the appending buffer to register.
    * @param <E> the type of the buffer elements.
-   * @return the registered buffer, or the previous registered buffer if a registration for the
-   *     state name already exists.
-   * @throws IllegalStateException if a previous registration exists for the given state name, but
-   *     it wasn't registered as a {@link PersistedAppendingBuffer}.
+   * @throws IllegalStateException if a previous registration exists for the given state name.
    */
-  public <E> PersistedAppendingBuffer<E> registerAppendingBuffer(
-      String name, Class<E> elementType) {
-    return registerAppendingBuffer(name, elementType, Expiration.none());
-  }
-
-  /**
-   * Registers a {@link PersistedAppendingBuffer}, given a state name and the type of the buffer
-   * elements. If a registered buffer already exists for the given name, the previous persisted
-   * buffer is returned.
-   *
-   * @param name the state name to register with.
-   * @param elementType the type of the buffer elements.
-   * @param expiration expiration configuration for the registered state.
-   * @param <E> the type of the buffer elements.
-   * @return the registered buffer, or the previous registered buffer if a registration for the
-   *     state name already exists.
-   * @throws IllegalStateException if a previous registration exists for the given state name, but
-   *     it wasn't registered as a {@link PersistedAppendingBuffer}.
-   */
-  public <E> PersistedAppendingBuffer<E> registerAppendingBuffer(
-      String name, Class<E> elementType, Expiration expiration) {
-    return getStateOrCreateIfAbsent(
-        PersistedAppendingBuffer.class,
-        name,
-        stateName -> createAppendingBuffer(stateName, elementType, expiration));
+  public <E> void registerAppendingBuffer(PersistedAppendingBuffer<E> bufferState) {
+    acceptRegistrationOrThrowIfPresent(bufferState.name(), bufferState);
   }
 
   /**
@@ -197,39 +115,19 @@
     registeredStates.values().forEach(state -> stateBinder.bind(state, functionType));
   }
 
-  private <T> PersistedValue<T> createValue(String name, Class<T> type, Expiration expiration) {
-    final PersistedValue<T> value = PersistedValue.of(name, type, expiration);
-    stateBinder.bindValue(value, functionType);
-    return value;
-  }
-
-  private <K, V> PersistedTable<K, V> createTable(
-      String name, Class<K> keyType, Class<V> valueType, Expiration expiration) {
-    final PersistedTable<K, V> table = PersistedTable.of(name, keyType, valueType, expiration);
-    stateBinder.bindTable(table, functionType);
-    return table;
-  }
-
-  private <E> PersistedAppendingBuffer<E> createAppendingBuffer(
-      String name, Class<E> elementType, Expiration expiration) {
-    final PersistedAppendingBuffer<E> buffer =
-        PersistedAppendingBuffer.of(name, elementType, expiration);
-    stateBinder.bindAppendingBuffer(buffer, functionType);
-    return buffer;
-  }
-
-  @SuppressWarnings("unchecked")
-  private <ST> ST getStateOrCreateIfAbsent(
-      Class<?> statePrimitiveType, String name, Function<String, ST> createFunction) {
-    final ST state = (ST) registeredStates.computeIfAbsent(name, createFunction::apply);
-    if (state.getClass() != statePrimitiveType) {
+  private void acceptRegistrationOrThrowIfPresent(String stateName, Object newStateObject) {
+    final Object previousRegistration = registeredStates.get(stateName);
+    if (previousRegistration != null) {
       throw new IllegalStateException(
-          "Unexpected state primitive type. The state was registered with type: "
-              + state.getClass()
-              + ", but was attempting to access it again as type: "
-              + statePrimitiveType);
+          String.format(
+              "State name '%s' was registered twice; previous registered state object with the same name was a %s, attempting to register a new %s under the same name.",
+              stateName,
+              previousRegistration.getClass().getName(),
+              newStateObject.getClass().getName()));
     }
-    return state;
+
+    registeredStates.put(stateName, newStateObject);
+    stateBinder.bind(newStateObject, functionType);
   }
 
   private static final class NonFaultTolerantStateBinder extends StateBinder {
diff --git a/statefun-sdk/src/test/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistryTest.java b/statefun-sdk/src/test/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistryTest.java
index 0c91178..7858b6b 100644
--- a/statefun-sdk/src/test/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistryTest.java
+++ b/statefun-sdk/src/test/java/org/apache/flink/statefun/sdk/state/PersistedStateRegistryTest.java
@@ -26,16 +26,16 @@
   public void exampleUsage() {
     final PersistedStateRegistry registry = new PersistedStateRegistry();
 
-    registry.registerValue("value", String.class);
-    registry.registerTable("table", String.class, Integer.class);
-    registry.registerAppendingBuffer("buffer", String.class);
+    registry.registerValue(PersistedValue.of("value", String.class));
+    registry.registerTable(PersistedTable.of("table", String.class, Integer.class));
+    registry.registerAppendingBuffer(PersistedAppendingBuffer.of("buffer", String.class));
   }
 
   @Test(expected = IllegalStateException.class)
-  public void reaccessAsWrongStatePrimitiveType() {
+  public void duplicateRegistration() {
     final PersistedStateRegistry registry = new PersistedStateRegistry();
 
-    registry.registerValue("my-state", String.class);
-    registry.registerAppendingBuffer("my-state", String.class);
+    registry.registerValue(PersistedValue.of("my-state", String.class));
+    registry.registerTable(PersistedTable.of("my-state", String.class, Integer.class));
   }
 }