[FLINK-31173] Fix wrong typeinfo in ProxyOperatorStateBackend
This closes #216.
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
index a655886..58fc57a 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
@@ -21,6 +21,11 @@
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ListTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.iteration.utils.ReflectionUtils;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
import org.apache.flink.runtime.state.OperatorStateBackend;
@@ -50,20 +55,35 @@
@Override
public <K, V> BroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, V> stateDescriptor)
throws Exception {
- MapStateDescriptor<K, V> newDescriptor =
- new MapStateDescriptor<>(
- stateNamePrefix.prefix(stateDescriptor.getName()),
- stateDescriptor.getKeySerializer(),
- stateDescriptor.getValueSerializer());
+ // Re-create the descriptor under the prefixed name, carrying over either the serializers
+ // or the type information, depending on what the caller's descriptor currently holds.
+ final String prefixedName = stateNamePrefix.prefix(stateDescriptor.getName());
+ final MapStateDescriptor<K, V> newDescriptor;
+ if (!stateDescriptor.isSerializerInitialized()) {
+     // Serializers are not initialized: forward the key/value type information instead.
+     MapTypeInfo<K, V> typeInfo = getMapTypeInfo(stateDescriptor);
+     newDescriptor =
+             new MapStateDescriptor<>(
+                     prefixedName, typeInfo.getKeyTypeInfo(), typeInfo.getValueTypeInfo());
+ } else {
+     newDescriptor =
+             new MapStateDescriptor<>(
+                     prefixedName,
+                     stateDescriptor.getKeySerializer(),
+                     stateDescriptor.getValueSerializer());
+ }
return wrappedBackend.getBroadcastState(newDescriptor);
}
@Override
public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception {
ListStateDescriptor<S> newDescriptor =
- new ListStateDescriptor<>(
- stateNamePrefix.prefix(stateDescriptor.getName()),
- stateDescriptor.getElementSerializer());
+ stateDescriptor.isSerializerInitialized() // choose the constructor matching what the descriptor holds
+ ? new ListStateDescriptor<>( // serializer initialized: reuse it under the prefixed name
+ stateNamePrefix.prefix(stateDescriptor.getName()),
+ stateDescriptor.getElementSerializer())
+ : new ListStateDescriptor<>( // serializer not initialized: pass the element type info instead
+ stateNamePrefix.prefix(stateDescriptor.getName()),
+ getElementTypeInfo(stateDescriptor)); // read reflectively from the descriptor's typeInfo field
+ // Delegate to the wrapped backend using the descriptor with the prefixed name.
return wrappedBackend.getListState(newDescriptor);
}
@@ -71,9 +91,13 @@
public <S> ListState<S> getUnionListState(ListStateDescriptor<S> stateDescriptor)
throws Exception {
ListStateDescriptor<S> newDescriptor =
- new ListStateDescriptor<S>(
- stateNamePrefix.prefix(stateDescriptor.getName()),
- stateDescriptor.getElementSerializer());
+ stateDescriptor.isSerializerInitialized() // choose the constructor matching what the descriptor holds
+ ? new ListStateDescriptor<>( // serializer initialized: reuse it under the prefixed name
+ stateNamePrefix.prefix(stateDescriptor.getName()),
+ stateDescriptor.getElementSerializer())
+ : new ListStateDescriptor<>( // serializer not initialized: pass the element type info instead
+ stateNamePrefix.prefix(stateDescriptor.getName()),
+ getElementTypeInfo(stateDescriptor)); // read reflectively from the descriptor's typeInfo field
return wrappedBackend.getUnionListState(newDescriptor);
}
@@ -125,4 +149,16 @@
throws Exception {
return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
}
+
+ /**
+  * Extracts the element type information from a list state descriptor.
+  *
+  * <p>Used when the descriptor's serializer is not initialized: the descriptor's {@code
+  * typeInfo} field (looked up on {@link StateDescriptor}) is read reflectively and treated as
+  * a {@link ListTypeInfo}.
+  *
+  * @param stateDescriptor the list state descriptor to inspect
+  * @return the type information of the list's elements
+  */
+ private <S> TypeInformation<S> getElementTypeInfo(ListStateDescriptor<S> stateDescriptor) {
+     // Typing the local as ListTypeInfo<S> avoids raw types and unchecked conversions here,
+     // so no @SuppressWarnings annotation is required.
+     ListTypeInfo<S> listTypeInfo =
+             ReflectionUtils.getFieldValue(stateDescriptor, StateDescriptor.class, "typeInfo");
+     return listTypeInfo.getElementTypeInfo();
+ }
+
+ /**
+  * Reads the {@code typeInfo} field of the given map state descriptor reflectively (looked up
+  * on {@link StateDescriptor}) and returns it as a {@link MapTypeInfo}.
+  *
+  * @param stateDescriptor the map state descriptor to inspect
+  * @return the map type information held by the descriptor
+  */
+ private <K, V> MapTypeInfo<K, V> getMapTypeInfo(MapStateDescriptor<K, V> stateDescriptor) {
+     final MapTypeInfo<K, V> mapTypeInfo =
+             ReflectionUtils.getFieldValue(stateDescriptor, StateDescriptor.class, "typeInfo");
+     return mapTypeInfo;
+ }
}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
index 825aa31..a1b5609 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
@@ -19,6 +19,10 @@
package org.apache.flink.test.iteration;
import org.apache.flink.api.common.functions.JoinFunction;
+import org.apache.flink.api.common.state.BroadcastState;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
@@ -26,12 +30,14 @@
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -55,6 +61,8 @@
import org.junit.Rule;
import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
@@ -73,6 +81,7 @@
private SharedReference<BlockingQueue<OutputRecord<Integer>>> collectedOutputRecord;
private SharedReference<BlockingQueue<Long>> collectedWatermarks;
+ private SharedReference<BlockingQueue<Long>> collectedOutputs;
@Before
public void setup() throws Exception {
@@ -81,6 +90,7 @@
collectedOutputRecord = sharedObjects.add(new LinkedBlockingQueue<>());
collectedWatermarks = sharedObjects.add(new LinkedBlockingQueue<>());
+ collectedOutputs = sharedObjects.add(new LinkedBlockingQueue<>());
}
@After
@@ -136,6 +146,33 @@
.forEachRemaining(x -> assertEquals(Long.MAX_VALUE, (long) x));
}
+ /**
+  * Runs a per-round bounded iteration whose body registers list, union-list and broadcast
+  * operator state, and checks that exactly three records, all equal to 1, reach the sink.
+  */
+ @Test
+ public void testPerRoundIterationWithState() throws Exception {
+     StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+     env.setParallelism(1);
+
+     // Keep the creation order of the two sources: replayed data stream first, variable second.
+     DataStream<Long> replayedStream = env.fromElements(1L);
+     DataStream<Long> initialVariable = env.fromElements(1L);
+
+     IterationConfig perRoundConfig =
+             IterationConfig.newBuilder()
+                     .setOperatorLifeCycle(OperatorLifeCycle.PER_ROUND)
+                     .build();
+     DataStreamList iterationOutputs =
+             Iterations.iterateBoundedStreamsUntilTermination(
+                     DataStreamList.of(initialVariable),
+                     ReplayableDataStreamList.replay(replayedStream),
+                     perRoundConfig,
+                     new PerRoundIterationBodyWithState());
+
+     iterationOutputs.<Long>get(0).addSink(new LongSink(collectedOutputs));
+     JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+     miniCluster.executeJobBlocking(jobGraph);
+
+     // Drain everything the sink collected and verify both the count and each value.
+     List<Long> drained = new ArrayList<>(3);
+     collectedOutputs.get().drainTo(drained);
+     assertEquals(3, drained.size());
+     drained.forEach(value -> assertEquals(1L, (long) value));
+ }
+
+
private static JobGraph createPerRoundJobGraph(
int numSources,
int numRecordsPerSource,
@@ -229,6 +266,56 @@
}
}
+ /**
+  * Iteration body that pipes the single variable stream through a {@link MapWithState}
+  * operator, uses the operator's output both as the feedback and as the iteration output, and
+  * derives the termination criteria from it via {@code TerminateOnMaxIter(2)}.
+  */
+ private static class PerRoundIterationBodyWithState implements IterationBody {
+
+     @Override
+     public IterationBodyResult process(
+             DataStreamList variableStreams, DataStreamList dataStreams) {
+         DataStream<Long> input = variableStreams.get(0);
+
+         // The stateful operator is the only transformation inside the iteration body.
+         DataStream<Long> statefulOutput =
+                 input.transform("mapWithState", Types.LONG, new MapWithState());
+
+         DataStream<Integer> criteria =
+                 statefulOutput.<Long>flatMap(new TerminateOnMaxIter(2)).returns(Types.INT);
+
+         DataStreamList feedbackStreams = DataStreamList.of(statefulOutput);
+         DataStreamList outputStreams = DataStreamList.of(statefulOutput);
+         return new IterationBodyResult(feedbackStreams, outputStreams, criteria);
+     }
+ }
+
+
+ /**
+  * Operator that records every incoming value in a list state, a union list state and a
+  * broadcast state, then forwards the element unchanged. All three states are registered
+  * with descriptors built from type information rather than serializers.
+  */
+ private static class MapWithState extends AbstractStreamOperator<Long>
+         implements OneInputStreamOperator<Long, Long> {
+     private ListState<Long> listState;
+     private ListState<Long> unionState;
+     private BroadcastState<Long, Long> broadcastState;
+
+     @Override
+     public void initializeState(StateInitializationContext context) throws Exception {
+         super.initializeState(context);
+
+         // Build the type-info based descriptors first, then register them one by one.
+         ListStateDescriptor<Long> listDescriptor =
+                 new ListStateDescriptor<>("longState", Types.LONG);
+         ListStateDescriptor<Long> unionDescriptor =
+                 new ListStateDescriptor<>("unionState", Types.LONG);
+         MapStateDescriptor<Long, Long> broadcastDescriptor =
+                 new MapStateDescriptor<>("broadcastState", Types.LONG, Types.LONG);
+
+         listState = context.getOperatorStateStore().getListState(listDescriptor);
+         unionState = context.getOperatorStateStore().getUnionListState(unionDescriptor);
+         broadcastState = context.getOperatorStateStore().getBroadcastState(broadcastDescriptor);
+     }
+
+     @Override
+     public void processElement(StreamRecord<Long> element) throws Exception {
+         long value = element.getValue();
+         // Touch every state kind so each one is actually written during the round.
+         listState.add(value);
+         unionState.add(value);
+         broadcastState.put(value, value);
+         output.collect(element);
+     }
+ }
+
+
private static class LongSink implements SinkFunction<Long> {
private final SharedReference<BlockingQueue<Long>> collectedLong;