[FLINK-35066] Fix unwrapping of IterationRecord when setting the key context for keyBy

This closes #260.
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
index 509d8a7..25b97f4 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
@@ -127,8 +127,13 @@
         @Override
         public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record)
                 throws Exception {
-            reusedInput.replace(record.getValue(), record.getTimestamp());
-            input.setKeyContextElement(reusedInput);
+            IterationRecord<IN> iterationRecord = record.getValue();
+            // Only RECORD elements are unwrapped to set the wrapped input's key context.
+            if (iterationRecord.getType() != IterationRecord.Type.RECORD) {
+                return;
+            }
+            reusedInput.replace(iterationRecord.getValue(), record.getTimestamp());
+            input.setKeyContextElement(reusedInput);
         }
     }
 }
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
index 5d4f9b4..903d14f 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
@@ -50,6 +50,67 @@
     }
 
     @Override
+    public void setKeyContextElement1(StreamRecord<?> record) throws Exception {
+        setKeyContextElement(
+                record,
+                reusedInput1,
+                wrappedOperator::setKeyContextElement1,
+                super::setKeyContextElement1);
+    }
+
+    @Override
+    public void setKeyContextElement2(StreamRecord<?> record) throws Exception {
+        setKeyContextElement(
+                record,
+                reusedInput2,
+                wrappedOperator::setKeyContextElement2,
+                super::setKeyContextElement2);
+    }
+
+    /**
+     * Sets the key context for an element arriving on one of the two inputs.
+     *
+     * <p>For {@link IterationRecord.Type#RECORD} elements, the inner value is unwrapped into
+     * {@code reusedInput} and handed to {@code wrappedSetKeyContext}, so the wrapped operator's
+     * key context is set from the inner value rather than from the {@link IterationRecord}.
+     * Epoch watermarks leave the key context unchanged.
+     *
+     * @param record the element as received by this wrapper operator.
+     * @param reusedInput the reusable record of this input that receives the unwrapped value.
+     * @param wrappedSetKeyContext sets the key context of the wrapped operator for this input.
+     * @param fallback handles elements whose value is not an {@link IterationRecord}; it must
+     *     target the same input as {@code wrappedSetKeyContext}.
+     * @throws FlinkRuntimeException if the iteration record type is neither {@code RECORD} nor
+     *     {@code EPOCH_WATERMARK}.
+     */
+    private void setKeyContextElement(
+            StreamRecord<?> record,
+            StreamRecord<?> reusedInput,
+            ThrowingConsumer<StreamRecord<?>, Exception> wrappedSetKeyContext,
+            ThrowingConsumer<StreamRecord<?>, Exception> fallback)
+            throws Exception {
+        if (!(record.getValue() instanceof IterationRecord)) {
+            // Delegate to the parent implementation bound to this same input.
+            fallback.accept(record);
+            return;
+        }
+
+        IterationRecord<?> iterationRecord = (IterationRecord<?>) record.getValue();
+        switch (iterationRecord.getType()) {
+            case RECORD:
+                reusedInput.replace(iterationRecord.getValue(), record.getTimestamp());
+                wrappedSetKeyContext.accept(reusedInput);
+                break;
+            case EPOCH_WATERMARK:
+                // Epoch watermarks are not forwarded, so the key context stays unchanged.
+                break;
+            default:
+                throw new FlinkRuntimeException(
+                        "Not supported iteration record type: " + iterationRecord.getType());
+        }
+    }
+
+    @Override
     public void processElement1(StreamRecord<IterationRecord<IN1>> element) throws Exception {
         processElement(element, 0, reusedInput1, wrappedOperator::processElement1);
     }
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
index c7ebbbb..4590db2 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
@@ -177,7 +177,7 @@
             if (element.getValue().getType() == IterationRecord.Type.RECORD) {
                 // Ensures the operators are created.
                 getWrappedOperator(element.getValue().getEpoch());
-                reusedInput.replace(element.getValue(), element.getTimestamp());
+                reusedInput.replace(element.getValue().getValue(), element.getTimestamp());
                 operatorInputsByEpoch
                         .get(element.getValue().getEpoch())
                         .get(inputIndex)
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
index 76b4baa..14450e0 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
@@ -23,6 +23,7 @@
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -75,9 +76,18 @@
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 MultipleInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
                         .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
                         .build()) {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index 81eb52b..3bdde8e 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -24,6 +24,7 @@
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.OperatorWrapper;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -79,7 +80,10 @@
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 OneInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
                         .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
                         .build()) {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
index 7c4d791..ebe2fad 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
@@ -23,6 +23,7 @@
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -74,8 +75,14 @@
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 TwoInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
                         .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
                         .build()) {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
index 3c61f51..7f36aca 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
@@ -25,6 +25,7 @@
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
 import org.apache.flink.iteration.operator.allround.LifeCycle;
 import org.apache.flink.iteration.operator.allround.OneInputAllRoundWrapperOperator;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -77,9 +78,18 @@
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 MultipleInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
                         .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
                         .build()) {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
index cd6b00a..295aa63 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
@@ -28,6 +28,7 @@
 import org.apache.flink.iteration.operator.OperatorWrapper;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
 import org.apache.flink.iteration.operator.allround.LifeCycle;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -93,7 +94,10 @@
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 OneInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
                         .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
                         .build()) {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
index 1d41bbd..70aa34d 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
@@ -24,6 +24,7 @@
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
 import org.apache.flink.iteration.operator.allround.LifeCycle;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -75,8 +76,14 @@
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 TwoInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+                        .addInput(
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+                                1,
+                                new ProxyKeySelector<Integer, Integer>(x -> x % 2))
                         .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
                         .build()) {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);