[FLINK-20265] [core] Allow PersistedRemoteFunctionValues to register states based on protocol PersistedValueSpecs
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
index 4733366..4977e2e 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/PersistedRemoteFunctionValues.java
@@ -18,13 +18,17 @@
package org.apache.flink.statefun.flink.core.reqreply;
+import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import org.apache.flink.statefun.flink.core.httpfn.StateSpec;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.ExpirationSpec;
+import org.apache.flink.statefun.flink.core.polyglot.generated.FromFunction.PersistedValueSpec;
import org.apache.flink.statefun.sdk.annotations.Persisted;
+import org.apache.flink.statefun.sdk.state.Expiration;
import org.apache.flink.statefun.sdk.state.PersistedStateRegistry;
import org.apache.flink.statefun.sdk.state.PersistedValue;
@@ -52,6 +56,49 @@
getStateHandleOrThrow(stateName).clear();
}
+ /**
+ * Registers states that were indicated to be missing by remote functions via the remote
+ * invocation protocol.
+ *
+ * <p>A state is registered with the provided specification only if it wasn't registered already
+ * under the same name (identified by {@link PersistedValueSpec#getStateName()}). This means that
+ * you cannot change the specifications of an already registered state name, e.g. state TTL
+ * expiration configuration cannot be changed.
+ *
+ * @param protocolPersistedValueSpecs list of specifications for the indicated missing states.
+ */
+ void registerStates(List<PersistedValueSpec> protocolPersistedValueSpecs) {
+ protocolPersistedValueSpecs.forEach(this::createAndRegisterValueStateIfAbsent);
+ }
+
+ private void createAndRegisterValueStateIfAbsent(PersistedValueSpec protocolPersistedValueSpec) {
+ final String stateName = protocolPersistedValueSpec.getStateName();
+
+ if (!managedStates.containsKey(stateName)) {
+ final PersistedValue<byte[]> stateValue =
+ PersistedValue.of(
+ stateName,
+ byte[].class,
+ sdkTtlExpiration(protocolPersistedValueSpec.getExpirationSpec()));
+ stateRegistry.registerValue(stateValue);
+ managedStates.put(stateName, stateValue);
+ }
+ }
+
+ private static Expiration sdkTtlExpiration(ExpirationSpec protocolExpirationSpec) {
+ final long expirationTtlMillis = protocolExpirationSpec.getExpireAfterMillis();
+
+ switch (protocolExpirationSpec.getMode()) {
+ case AFTER_INVOKE:
+ return Expiration.expireAfterReadingOrWriting(Duration.ofMillis(expirationTtlMillis));
+ case AFTER_WRITE:
+ return Expiration.expireAfterWriting(Duration.ofMillis(expirationTtlMillis));
+ default:
+ case NONE:
+ return Expiration.none();
+ }
+ }
+
private void createAndRegisterValueState(StateSpec stateSpec) {
final String stateName = stateSpec.name();
@@ -65,7 +112,9 @@
final PersistedValue<byte[]> handle = managedStates.get(stateName);
if (handle == null) {
throw new IllegalStateException(
- "Accessing a non-existing remote function state: " + stateName);
+ "Accessing a non-existing function state: "
+ + stateName
+ + ". This can happen if you forgot to declare this state using the language SDKs.");
}
return handle;
}