[hotfix] Fix NullPointerException when broadcast variables are cleaned
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
index 170fccc..2df2654 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
@@ -59,6 +59,9 @@
@SuppressWarnings("rawtypes")
private final List[] caches;
+ /** whether each broadcast input has finished. */
+ private boolean[] cachesReady;
+
/** state storage of the broadcast inputs. */
private ListState<?>[] cacheStates;
@@ -78,6 +81,7 @@
inputList.add(new ProxyInput(this, i + 1));
}
this.caches = new List[inTypes.length];
+ this.cachesReady = new boolean[inTypes.length];
for (int i = 0; i < inTypes.length; i++) {
caches[i] = new ArrayList<>();
}
@@ -92,6 +96,7 @@
@Override
public void endInput(int i) {
+ cachesReady[i - 1] = true;
BroadcastContext.markCacheFinished(
broadcastStreamNames[i - 1] + "-" + getRuntimeContext().getIndexOfThisSubtask());
}
@@ -104,12 +109,7 @@
cacheStates[i].clear();
cacheStates[i].addAll(caches[i]);
cacheReadyStates[i].clear();
- boolean isCacheFinished =
- BroadcastContext.isCacheFinished(
- broadcastStreamNames[i]
- + "-"
- + getRuntimeContext().getIndexOfThisSubtask());
- cacheReadyStates[i].add(isCacheFinished);
+ cacheReadyStates[i].add(cachesReady[i]);
}
}
@@ -133,6 +133,8 @@
OperatorStateUtils.getUniqueElement(
cacheReadyStates[i], "cache_ready_state_" + i)
.orElse(false);
+ // TODO: there may be a memory leak if the BroadcastWrapper finishes before this task
+ // finishes.
BroadcastContext.putBroadcastVariable(
broadcastStreamNames[i] + "-" + getRuntimeContext().getIndexOfThisSubtask(),
Tuple2.of(cacheReady, caches[i]));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
index 2457471..af6ba24 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
@@ -20,15 +20,23 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.config.IterationOptions;
import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
import org.junit.Assert;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
import java.util.Arrays;
import java.util.List;
@@ -36,13 +44,15 @@
/** Tests the {@link BroadcastVariableReceiverOperator}. */
public class BroadcastVariableReceiverOperatorTest {
+ @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
private static final TypeInformation<?>[] TYPE_INFORMATIONS =
new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
@Test
- public void testCacheStreamOperator() throws Exception {
+ public void test() throws Exception {
OperatorID operatorId = new OperatorID();
try (StreamTaskMailboxTestHarness<Integer> harness =
@@ -74,6 +84,47 @@
}
}
+ @Test
+ public void testVariableCleanedBeforeSnapShot() throws Exception {
+ OperatorID operatorId = new OperatorID();
+
+ try (StreamTaskMailboxTestHarness<Integer> harness =
+ new StreamTaskMailboxTestHarnessBuilder<>(
+ MultipleInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+ .addInput(BasicTypeInfo.INT_TYPE_INFO)
+ .setupOutputForSingletonOperatorChain(
+ new BroadcastVariableReceiverOperatorFactory<>(
+ new String[] {BROADCAST_NAMES[0]},
+ new TypeInformation[] {TYPE_INFORMATIONS[0]}),
+ operatorId)
+ .buildUnrestored()) {
+ harness.getStreamTask()
+ .getEnvironment()
+ .getTaskManagerInfo()
+ .getConfiguration()
+ .set(
+ IterationOptions.DATA_CACHE_PATH,
+ "file://" + tempFolder.newFolder().getAbsolutePath());
+ harness.getStreamTask().restore();
+ harness.processElement(new StreamRecord<>(1, 2), 0);
+ harness.processElement(new StreamRecord<>(2, 3), 0);
+ harness.endInput();
+ // Remove the broadcast variable before the checkpoint, simulating an early cleanup.
+ BroadcastContext.remove(BROADCAST_NAMES[0] + "-" + 0);
+
+ harness.getStreamTask()
+ .triggerCheckpointOnBarrier(
+ new CheckpointMetaData(1, 2),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference.getDefault()),
+ new CheckpointMetricsBuilder()
+ .setAlignmentDurationNanos(0)
+ .setBytesProcessedDuringAlignment(0));
+ harness.waitForTaskCompletion();
+ }
+ }
+
public static void compareLists(List<Integer> expected, List<?> actual) {
int[] actualInts =
actual.stream().map(x -> (Integer) x).mapToInt(Integer::intValue).toArray();