// blob: 2825f7568068ada5ad191efb0ee611f7d0cac873 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.statefun.flink.state.processor.operator;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import org.apache.flink.statefun.flink.common.SetContextClassLoader;
import org.apache.flink.statefun.flink.core.state.FlinkStateBinder;
import org.apache.flink.statefun.flink.core.state.PersistedStates;
import org.apache.flink.statefun.flink.core.state.State;
import org.apache.flink.statefun.flink.state.processor.StateBootstrapFunction;
import org.apache.flink.statefun.flink.state.processor.StateBootstrapFunctionProvider;
import org.apache.flink.statefun.sdk.FunctionType;
/**
 * A registry that handles {@link StateBootstrapFunctionProvider} registrations.
 *
 * <p>The registry has two phases. First, providers are added with {@link #register(FunctionType,
 * StateBootstrapFunctionProvider)}. Then {@link #initialize(State)} instantiates one bootstrap
 * function per registered provider and binds it to Flink state; from that point on {@link
 * #getBootstrapFunction(FunctionType)} may be used and further registrations are rejected.
 *
 * <p>Only the registered providers are part of the serialized form; the instantiated functions
 * are held in a {@code transient} field, so a deserialized registry is uninitialized.
 */
public final class StateBootstrapFunctionRegistry implements Serializable {

  private static final long serialVersionUID = 1L;

  /** State bootstrap function providers registered by the user. */
  private final Map<SerializableFunctionType, StateBootstrapFunctionProvider>
      stateBootstrapFunctionProviders = new HashMap<>();

  /**
   * Registry of instantiated, state-bound bootstrap functions. This is created only after {@link
   * #initialize(State)} is invoked during runtime; it is {@code null} until then.
   */
  private transient Map<FunctionType, StateBootstrapFunction> registry;

  /**
   * Registers a {@link StateBootstrapFunctionProvider}.
   *
   * <p>A rejected registration leaves the registry unchanged: in particular, a duplicate
   * registration does not replace the provider that was registered first.
   *
   * @param functionType the type of the function that is being bootstrapped.
   * @param stateBootstrapFunctionProvider provider of the bootstrap function.
   * @throws IllegalStateException if the registry has already been initialized.
   * @throws NullPointerException if either argument is {@code null}.
   * @throws IllegalArgumentException if a provider was already registered for the function type.
   */
  public void register(
      FunctionType functionType, StateBootstrapFunctionProvider stateBootstrapFunctionProvider) {
    if (isInitialized()) {
      throw new IllegalStateException(
          "Cannot register bootstrap function providers after the registry is initialized.");
    }
    Objects.requireNonNull(functionType, "functionType");
    Objects.requireNonNull(stateBootstrapFunctionProvider, "stateBootstrapFunctionProvider");

    // putIfAbsent (rather than put) keeps the existing mapping intact when the function type is
    // already taken, so throwing below does not leave the first registration overwritten.
    final StateBootstrapFunctionProvider previous =
        stateBootstrapFunctionProviders.putIfAbsent(
            SerializableFunctionType.fromNonSerializable(functionType),
            stateBootstrapFunctionProvider);

    if (previous != null) {
      throw new IllegalArgumentException(
          String.format(
              "A StateBootstrapFunctionProvider for function type %s was previously defined.",
              functionType));
    }
  }

  /**
   * Returns number of registrations.
   *
   * @return number of registrations.
   */
  public int numRegistrations() {
    return stateBootstrapFunctionProviders.size();
  }

  /**
   * Initializes the registry. This instantiates all registered state bootstrap functions, and binds
   * them with Flink state.
   *
   * <p>The registry only becomes initialized if every function was instantiated and bound
   * successfully; if any provider or binding throws, the previous state of the registry is kept.
   * Invoking this method again replaces the instantiated functions with freshly created ones.
   *
   * @param stateAccessor accessor for Flink state to bind bootstrap functions with.
   */
  void initialize(State stateAccessor) {
    final Map<FunctionType, StateBootstrapFunction> boundFunctions =
        new HashMap<>(stateBootstrapFunctionProviders.size());

    for (Map.Entry<SerializableFunctionType, StateBootstrapFunctionProvider> entry :
        stateBootstrapFunctionProviders.entrySet()) {
      final FunctionType functionType = entry.getKey().toNonSerializable();
      final StateBootstrapFunction bootstrapFunction =
          entry.getValue().bootstrapFunctionOfType(functionType);
      final FlinkStateBinder stateBinder = new FlinkStateBinder(stateAccessor, functionType);

      boundFunctions.put(functionType, bindState(bootstrapFunction, stateBinder));
    }

    // Publish the map only now: assigning the field before the loop would make isInitialized()
    // report true for a half-populated registry if one of the steps above threw.
    this.registry = boundFunctions;
  }

  /**
   * Retrieves the bootstrap function for a given function type.
   *
   * @param functionType the function type to look up.
   * @return the state-bound bootstrap function, or {@code null} if none was registered for the
   *     given type.
   * @throws IllegalStateException if the registry has not been initialized yet.
   */
  @Nullable
  StateBootstrapFunction getBootstrapFunction(FunctionType functionType) {
    if (!isInitialized()) {
      throw new IllegalStateException("The registry must be initialized first.");
    }
    return registry.get(functionType);
  }

  /**
   * Binds the persisted state declared by the given bootstrap function using the given binder.
   * The binding runs inside a {@link SetContextClassLoader} scope created from the bootstrap
   * function, which is closed again before this method returns.
   *
   * @param bootstrapFunction the function to bind state for.
   * @param stateBinder the binder to bind the function's state with.
   * @return the same {@code bootstrapFunction} instance that was passed in.
   */
  private static StateBootstrapFunction bindState(
      StateBootstrapFunction bootstrapFunction, FlinkStateBinder stateBinder) {
    try (SetContextClassLoader ignored = new SetContextClassLoader(bootstrapFunction)) {
      PersistedStates.findAndBind(bootstrapFunction, stateBinder);
      return bootstrapFunction;
    }
  }

  /** Returns whether {@link #initialize(State)} has completed on this instance. */
  private boolean isInitialized() {
    return registry != null;
  }

  /**
   * Serializable representation of a {@link FunctionType} (its namespace and name), used as the
   * key of the serialized provider map.
   *
   * <p>NOTE(review): presumably needed because {@link FunctionType} itself is not {@link
   * Serializable}; confirm against the SDK.
   */
  private static final class SerializableFunctionType implements Serializable {

    private static final long serialVersionUID = 1L;

    private final String namespace;
    private final String name;

    /** Captures the namespace and name of the given function type. */
    static SerializableFunctionType fromNonSerializable(FunctionType functionType) {
      return new SerializableFunctionType(functionType.namespace(), functionType.name());
    }

    private SerializableFunctionType(String namespace, String name) {
      this.namespace = Objects.requireNonNull(namespace);
      this.name = Objects.requireNonNull(name);
    }

    /** Creates a new {@link FunctionType} from the captured namespace and name. */
    private FunctionType toNonSerializable() {
      return new FunctionType(namespace, name);
    }

    @Override
    public int hashCode() {
      return Objects.hash(namespace, name);
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      SerializableFunctionType that = (SerializableFunctionType) o;
      return namespace.equals(that.namespace) && name.equals(that.name);
    }
  }
}