[FLINK-31173] Fix wrong typeinfo in ProxyOperatorStateBackend

This closes #216.
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
index a655886..58fc57a 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/proxy/state/ProxyOperatorStateBackend.java
@@ -21,6 +21,11 @@
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ListTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.iteration.utils.ReflectionUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.OperatorStateBackend;
@@ -50,20 +55,35 @@
     @Override
     public <K, V> BroadcastState<K, V> getBroadcastState(MapStateDescriptor<K, V> stateDescriptor)
             throws Exception {
-        MapStateDescriptor<K, V> newDescriptor =
-                new MapStateDescriptor<>(
-                        stateNamePrefix.prefix(stateDescriptor.getName()),
-                        stateDescriptor.getKeySerializer(),
-                        stateDescriptor.getValueSerializer());
+        MapStateDescriptor<K, V> newDescriptor;
+        if (stateDescriptor.isSerializerInitialized()) {
+            newDescriptor =
+                    new MapStateDescriptor<>(
+                            stateNamePrefix.prefix(stateDescriptor.getName()),
+                            stateDescriptor.getKeySerializer(),
+                            stateDescriptor.getValueSerializer());
+        } else {
+            MapTypeInfo<K, V> mapTypeInfo = getMapTypeInfo(stateDescriptor);
+            newDescriptor =
+                    new MapStateDescriptor<>(
+                            stateNamePrefix.prefix(stateDescriptor.getName()),
+                            mapTypeInfo.getKeyTypeInfo(),
+                            mapTypeInfo.getValueTypeInfo());
+        }
         return wrappedBackend.getBroadcastState(newDescriptor);
     }
 
     @Override
     public <S> ListState<S> getListState(ListStateDescriptor<S> stateDescriptor) throws Exception {
         ListStateDescriptor<S> newDescriptor =
-                new ListStateDescriptor<>(
-                        stateNamePrefix.prefix(stateDescriptor.getName()),
-                        stateDescriptor.getElementSerializer());
+                stateDescriptor.isSerializerInitialized()
+                        ? new ListStateDescriptor<>(
+                                stateNamePrefix.prefix(stateDescriptor.getName()),
+                                stateDescriptor.getElementSerializer())
+                        : new ListStateDescriptor<>(
+                                stateNamePrefix.prefix(stateDescriptor.getName()),
+                                getElementTypeInfo(stateDescriptor));
+
         return wrappedBackend.getListState(newDescriptor);
     }
 
@@ -71,9 +91,13 @@
     public <S> ListState<S> getUnionListState(ListStateDescriptor<S> stateDescriptor)
             throws Exception {
         ListStateDescriptor<S> newDescriptor =
-                new ListStateDescriptor<S>(
-                        stateNamePrefix.prefix(stateDescriptor.getName()),
-                        stateDescriptor.getElementSerializer());
+                stateDescriptor.isSerializerInitialized()
+                        ? new ListStateDescriptor<>(
+                                stateNamePrefix.prefix(stateDescriptor.getName()),
+                                stateDescriptor.getElementSerializer())
+                        : new ListStateDescriptor<>(
+                                stateNamePrefix.prefix(stateDescriptor.getName()),
+                                getElementTypeInfo(stateDescriptor));
         return wrappedBackend.getUnionListState(newDescriptor);
     }
 
@@ -125,4 +149,16 @@
             throws Exception {
         return wrappedBackend.snapshot(checkpointId, timestamp, streamFactory, checkpointOptions);
     }
+
+    @SuppressWarnings("unchecked,rawtypes")
+    private <S> TypeInformation<S> getElementTypeInfo(ListStateDescriptor<S> stateDescriptor) {
+        return ((ListTypeInfo)
+                        ReflectionUtils.getFieldValue(
+                                stateDescriptor, StateDescriptor.class, "typeInfo"))
+                .getElementTypeInfo();
+    }
+
+    private <K, V> MapTypeInfo<K, V> getMapTypeInfo(MapStateDescriptor<K, V> stateDescriptor) {
+        return ReflectionUtils.getFieldValue(stateDescriptor, StateDescriptor.class, "typeInfo");
+    }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
index 825aa31..a1b5609 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedPerRoundStreamIterationITCase.java
@@ -19,6 +19,10 @@
 package org.apache.flink.test.iteration;
 
 import org.apache.flink.api.common.functions.JoinFunction;
+import org.apache.flink.api.common.state.BroadcastState;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -26,12 +30,14 @@
 import org.apache.flink.iteration.IterationBody;
 import org.apache.flink.iteration.IterationBodyResult;
 import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationConfig.OperatorLifeCycle;
 import org.apache.flink.iteration.Iterations;
 import org.apache.flink.iteration.ReplayableDataStreamList;
 import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
 import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -55,6 +61,8 @@
 import org.junit.Rule;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingQueue;
@@ -73,6 +81,7 @@
 
     private SharedReference<BlockingQueue<OutputRecord<Integer>>> collectedOutputRecord;
     private SharedReference<BlockingQueue<Long>> collectedWatermarks;
+    private SharedReference<BlockingQueue<Long>> collectedOutputs;
 
     @Before
     public void setup() throws Exception {
@@ -81,6 +90,7 @@
 
         collectedOutputRecord = sharedObjects.add(new LinkedBlockingQueue<>());
         collectedWatermarks = sharedObjects.add(new LinkedBlockingQueue<>());
+        collectedOutputs = sharedObjects.add(new LinkedBlockingQueue<>());
     }
 
     @After
@@ -136,6 +146,33 @@
                 .forEachRemaining(x -> assertEquals(Long.MAX_VALUE, (long) x));
     }
 
+    @Test
+    public void testPerRoundIterationWithState() throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(1);
+        DataStream<Long> broadcastStream = env.fromElements(1L);
+        DataStream<Long> inputStream = env.fromElements(1L);
+        DataStreamList outputStream =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(inputStream),
+                        ReplayableDataStreamList.replay(broadcastStream),
+                        IterationConfig.newBuilder()
+                                .setOperatorLifeCycle(OperatorLifeCycle.PER_ROUND)
+                                .build(),
+                        new PerRoundIterationBodyWithState());
+
+        outputStream.<Long>get(0).addSink(new LongSink(collectedOutputs));
+        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+        miniCluster.executeJobBlocking(jobGraph);
+
+        List<Long> result = new ArrayList<>(3);
+        collectedOutputs.get().drainTo(result);
+        assertEquals(3, result.size());
+        for (long value : result) {
+            assertEquals(1L, value);
+        }
+    }
+
     private static JobGraph createPerRoundJobGraph(
             int numSources,
             int numRecordsPerSource,
@@ -229,6 +266,56 @@
         }
     }
 
+    private static class PerRoundIterationBodyWithState implements IterationBody {
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Long> variableStream = variableStreams.get(0);
+
+            DataStream<Long> feedback =
+                    variableStream.transform("mapWithState", Types.LONG, new MapWithState());
+
+            DataStream<Integer> terminationCriteria =
+                    feedback.<Long>flatMap(new TerminateOnMaxIter(2)).returns(Types.INT);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedback), DataStreamList.of(feedback), terminationCriteria);
+        }
+    }
+
+    private static class MapWithState extends AbstractStreamOperator<Long>
+            implements OneInputStreamOperator<Long, Long> {
+        private ListState<Long> listState;
+        private ListState<Long> unionState;
+        private BroadcastState<Long, Long> broadcastState;
+
+        @Override
+        public void processElement(StreamRecord<Long> element) throws Exception {
+            long val = element.getValue();
+            listState.add(val);
+            unionState.add(val);
+            broadcastState.put(val, val);
+            output.collect(element);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            listState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("longState", Types.LONG));
+            unionState =
+                    context.getOperatorStateStore()
+                            .getUnionListState(new ListStateDescriptor<>("unionState", Types.LONG));
+            broadcastState =
+                    context.getOperatorStateStore()
+                            .getBroadcastState(
+                                    new MapStateDescriptor<>(
+                                            "broadcastState", Types.LONG, Types.LONG));
+        }
+    }
+
     private static class LongSink implements SinkFunction<Long> {
         private final SharedReference<BlockingQueue<Long>> collectedLong;