[FLINK-35066] Fix unwrapping of IterationRecord when setting the key context during keyBy
This closes #260.
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
index 509d8a7..25b97f4 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
@@ -127,8 +127,14 @@
@Override
public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record)
throws Exception {
- reusedInput.replace(record.getValue(), record.getTimestamp());
- input.setKeyContextElement(reusedInput);
+ // Only RECORD elements are unwrapped and forwarded; every other iteration record
+ // type returns early and leaves the wrapped input's key context unchanged.
+ IterationRecord<IN> iterationRecord = record.getValue();
+ if (iterationRecord.getType() != IterationRecord.Type.RECORD) {
+ return;
+ }
+ reusedInput.replace(iterationRecord.getValue(), record.getTimestamp());
+ input.setKeyContextElement(reusedInput);
}
}
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
index 5d4f9b4..903d14f 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
@@ -50,6 +50,63 @@
}
@Override
+ public void setKeyContextElement1(StreamRecord<?> record) throws Exception {
+ setKeyContextElement(
+ record,
+ reusedInput1,
+ super::setKeyContextElement1,
+ wrappedOperator::setKeyContextElement1);
+ }
+
+ @Override
+ public void setKeyContextElement2(StreamRecord<?> record) throws Exception {
+ setKeyContextElement(
+ record,
+ reusedInput2,
+ super::setKeyContextElement2,
+ wrappedOperator::setKeyContextElement2);
+ }
+
+ /**
+ * Forwards the key context of one input to the wrapped operator. For {@link IterationRecord}
+ * values, the user value is unwrapped into {@code reusedInput} before it is passed on.
+ *
+ * @param record the incoming record.
+ * @param reusedInput the reusable record of this input that carries the unwrapped value.
+ * @param fallback the parent-class key-context setter of this same input, used when the
+ * value is not an {@link IterationRecord}.
+ * @param processor the wrapped operator's key-context setter of this input.
+ * @throws FlinkRuntimeException if the iteration record type is not supported.
+ */
+ private void setKeyContextElement(
+ StreamRecord<?> record,
+ StreamRecord<?> reusedInput,
+ ThrowingConsumer<StreamRecord<?>, Exception> fallback,
+ ThrowingConsumer<StreamRecord<?>, Exception> processor)
+ throws Exception {
+ if (!(record.getValue() instanceof IterationRecord)) {
+ // Use the parent handler of the same input: records of input 2 must not be routed
+ // to setKeyContextElement1.
+ fallback.accept(record);
+ return;
+ }
+
+ IterationRecord<?> iterationRecord = (IterationRecord<?>) record.getValue();
+ switch (iterationRecord.getType()) {
+ case RECORD:
+ reusedInput.replace(iterationRecord.getValue(), record.getTimestamp());
+ processor.accept(reusedInput);
+ break;
+ case EPOCH_WATERMARK:
+ // Epoch watermarks do not change the key context.
+ break;
+ default:
+ throw new FlinkRuntimeException(
+ "Not supported iteration record type: " + iterationRecord.getType());
+ }
+ }
+
+ @Override
public void processElement1(StreamRecord<IterationRecord<IN1>> element) throws Exception {
processElement(element, 0, reusedInput1, wrappedOperator::processElement1);
}
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
index c7ebbbb..4590db2 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
@@ -177,7 +177,7 @@
if (element.getValue().getType() == IterationRecord.Type.RECORD) {
// Ensures the operators are created.
getWrappedOperator(element.getValue().getEpoch());
- reusedInput.replace(element.getValue(), element.getTimestamp());
+ reusedInput.replace(element.getValue().getValue(), element.getTimestamp());
operatorInputsByEpoch
.get(element.getValue().getEpoch())
.get(inputIndex)
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
index 76b4baa..14450e0 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
@@ -23,6 +23,7 @@
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -75,9 +76,18 @@
new StreamTaskMailboxTestHarnessBuilder<>(
MultipleInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
.setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
.build()) {
harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index 81eb52b..3bdde8e 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -24,6 +24,7 @@
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.operator.OperatorWrapper;
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -79,7 +80,10 @@
new StreamTaskMailboxTestHarnessBuilder<>(
OneInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
.setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
.build()) {
harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
index 7c4d791..ebe2fad 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
@@ -23,6 +23,7 @@
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -74,8 +75,14 @@
new StreamTaskMailboxTestHarnessBuilder<>(
TwoInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
.setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
.build()) {
harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
index 3c61f51..7f36aca 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
@@ -25,6 +25,7 @@
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
import org.apache.flink.iteration.operator.allround.LifeCycle;
import org.apache.flink.iteration.operator.allround.OneInputAllRoundWrapperOperator;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -77,9 +78,18 @@
new StreamTaskMailboxTestHarnessBuilder<>(
MultipleInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
.setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
.build()) {
harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
index cd6b00a..295aa63 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
@@ -28,6 +28,7 @@
import org.apache.flink.iteration.operator.OperatorWrapper;
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
import org.apache.flink.iteration.operator.allround.LifeCycle;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -93,7 +94,10 @@
new StreamTaskMailboxTestHarnessBuilder<>(
OneInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
.setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
.build()) {
harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2));
diff --git a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
index 1d41bbd..70aa34d 100644
--- a/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/flink-ml-iteration-common/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
@@ -24,6 +24,7 @@
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
import org.apache.flink.iteration.operator.allround.LifeCycle;
+import org.apache.flink.iteration.proxy.ProxyKeySelector;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -75,8 +76,14 @@
new StreamTaskMailboxTestHarnessBuilder<>(
TwoInputStreamTask::new,
new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
- .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
+ .addInput(
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO),
+ 1,
+ new ProxyKeySelector<Integer, Integer>(x -> x % 2))
.setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
.build()) {
harness.processElement(new StreamRecord<>(IterationRecord.newRecord(5, 1), 2), 0);