[hotfix] Fix NullPointerException when broadcast variables are cleaned before a snapshot
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
index 170fccc..2df2654 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperator.java
@@ -59,6 +59,9 @@
     @SuppressWarnings("rawtypes")
     private final List[] caches;
 
+    /** whether each broadcast input has finished. */
+    private boolean[] cachesReady;
+
     /** state storage of the broadcast inputs. */
     private ListState<?>[] cacheStates;
 
@@ -78,6 +81,7 @@
             inputList.add(new ProxyInput(this, i + 1));
         }
         this.caches = new List[inTypes.length];
+        this.cachesReady = new boolean[inTypes.length];
         for (int i = 0; i < inTypes.length; i++) {
             caches[i] = new ArrayList<>();
         }
@@ -92,6 +96,7 @@
 
     @Override
     public void endInput(int i) {
+        cachesReady[i - 1] = true;
         BroadcastContext.markCacheFinished(
                 broadcastStreamNames[i - 1] + "-" + getRuntimeContext().getIndexOfThisSubtask());
     }
@@ -104,12 +109,7 @@
             cacheStates[i].clear();
             cacheStates[i].addAll(caches[i]);
             cacheReadyStates[i].clear();
-            boolean isCacheFinished =
-                    BroadcastContext.isCacheFinished(
-                            broadcastStreamNames[i]
-                                    + "-"
-                                    + getRuntimeContext().getIndexOfThisSubtask());
-            cacheReadyStates[i].add(isCacheFinished);
+            cacheReadyStates[i].add(cachesReady[i]);
         }
     }
 
@@ -133,6 +133,8 @@
                     OperatorStateUtils.getUniqueElement(
                                     cacheReadyStates[i], "cache_ready_state_" + i)
                             .orElse(false);
+            // TODO: there may be a memory leak if the BroadcastWrapper finishes before this task
+            // finishes.
             BroadcastContext.putBroadcastVariable(
                     broadcastStreamNames[i] + "-" + getRuntimeContext().getIndexOfThisSubtask(),
                     Tuple2.of(cacheReady, caches[i]));
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
index 2457471..af6ba24 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/broadcast/operator/BroadcastVariableReceiverOperatorTest.java
@@ -20,15 +20,23 @@
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.config.IterationOptions;
 import org.apache.flink.ml.common.broadcast.BroadcastContext;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
 
 import org.junit.Assert;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 
 import java.util.Arrays;
 import java.util.List;
@@ -36,13 +44,15 @@
 /** Tests the {@link BroadcastVariableReceiverOperator}. */
 public class BroadcastVariableReceiverOperatorTest {
 
+    @Rule public TemporaryFolder tempFolder = new TemporaryFolder();
+
     private static final String[] BROADCAST_NAMES = new String[] {"source1", "source2"};
 
     private static final TypeInformation<?>[] TYPE_INFORMATIONS =
             new TypeInformation[] {BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO};
 
     @Test
-    public void testCacheStreamOperator() throws Exception {
+    public void test() throws Exception {
         OperatorID operatorId = new OperatorID();
 
         try (StreamTaskMailboxTestHarness<Integer> harness =
@@ -74,6 +84,47 @@
         }
     }
 
+    @Test
+    public void testVariableCleanedBeforeSnapShot() throws Exception {
+        OperatorID operatorId = new OperatorID();
+
+        try (StreamTaskMailboxTestHarness<Integer> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                MultipleInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setupOutputForSingletonOperatorChain(
+                                new BroadcastVariableReceiverOperatorFactory<>(
+                                        new String[] {BROADCAST_NAMES[0]},
+                                        new TypeInformation[] {TYPE_INFORMATIONS[0]}),
+                                operatorId)
+                        .buildUnrestored()) {
+            harness.getStreamTask()
+                    .getEnvironment()
+                    .getTaskManagerInfo()
+                    .getConfiguration()
+                    .set(
+                            IterationOptions.DATA_CACHE_PATH,
+                            "file://" + tempFolder.newFolder().getAbsolutePath());
+            harness.getStreamTask().restore();
+            harness.processElement(new StreamRecord<>(1, 2), 0);
+            harness.processElement(new StreamRecord<>(2, 3), 0);
+            harness.endInput();
+            // Remove the broadcast variable first so the checkpoint below runs after it is cleaned.
+            BroadcastContext.remove(BROADCAST_NAMES[0] + "-" + 0);
+
+            harness.getStreamTask()
+                    .triggerCheckpointOnBarrier(
+                            new CheckpointMetaData(1, 2),
+                            CheckpointOptions.alignedNoTimeout(
+                                    CheckpointType.CHECKPOINT,
+                                    CheckpointStorageLocationReference.getDefault()),
+                            new CheckpointMetricsBuilder()
+                                    .setAlignmentDurationNanos(0)
+                                    .setBytesProcessedDuringAlignment(0));
+            harness.waitForTaskCompletion();
+        }
+    }
+
     public static void compareLists(List<Integer> expected, List<?> actual) {
         int[] actualInts =
                 actual.stream().map(x -> (Integer) x).mapToInt(Integer::intValue).toArray();