[FLINK-34549][API] Implement applyToAllPartitions for non-partitioned context
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java
index 34e6df8..0bbba01 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContext.java
@@ -18,23 +18,61 @@
 
 package org.apache.flink.datastream.impl.context;
 
+import org.apache.flink.datastream.api.common.Collector;
 import org.apache.flink.datastream.api.context.JobInfo;
 import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.TaskInfo;
 import org.apache.flink.datastream.api.function.ApplyPartitionFunction;
 import org.apache.flink.metrics.MetricGroup;
 
+import java.util.Set;
+
 /** The default implementation of {@link NonPartitionedContext}. */
 public class DefaultNonPartitionedContext<OUT> implements NonPartitionedContext<OUT> {
     private final DefaultRuntimeContext context;
 
-    public DefaultNonPartitionedContext(DefaultRuntimeContext context) {
+    private final DefaultPartitionedContext partitionedContext;
+
+    private final Collector<OUT> collector;
+
+    private final boolean isKeyed;
+
+    private final Set<Object> keySet;
+
+    public DefaultNonPartitionedContext(
+            DefaultRuntimeContext context,
+            DefaultPartitionedContext partitionedContext,
+            Collector<OUT> collector,
+            boolean isKeyed,
+            Set<Object> keySet) {
         this.context = context;
+        this.partitionedContext = partitionedContext;
+        this.collector = collector;
+        this.isKeyed = isKeyed;
+        this.keySet = keySet;
     }
 
     @Override
-    public void applyToAllPartitions(ApplyPartitionFunction<OUT> applyPartitionFunction) {
-        // TODO implements this method.
+    public void applyToAllPartitions(ApplyPartitionFunction<OUT> applyPartitionFunction)
+            throws Exception {
+        if (isKeyed) {
+            for (Object key : keySet) {
+                partitionedContext
+                        .getStateManager()
+                        .executeInKeyContext(
+                                () -> {
+                                    try {
+                                        applyPartitionFunction.apply(collector, partitionedContext);
+                                    } catch (Exception e) {
+                                        throw new RuntimeException(e);
+                                    }
+                                },
+                                key);
+            }
+        } else {
+            // A non-keyed operator has only one partition.
+            applyPartitionFunction.apply(collector, partitionedContext);
+        }
     }
 
     @Override
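
Note on the keyed branch above: it relies on `StateManager#executeInKeyContext` to install the given key before running the callback and to restore the previous key afterwards, so any keyed state touched inside `applyPartitionFunction.apply(...)` is scoped to that partition. A minimal sketch of that assumed contract, built from a key supplier and setter like the ones the new tests pass to `DefaultPartitionedContext` (names here are illustrative, not the actual `DefaultStateManager` code):

```java
import java.util.function.Consumer;
import java.util.function.Supplier;

/** Illustrative stand-in for the key-context behavior assumed above. */
final class KeyContextSketch {
    private final Supplier<Object> currentKeySupplier;
    private final Consumer<Object> currentKeySetter;

    KeyContextSketch(Supplier<Object> currentKeySupplier, Consumer<Object> currentKeySetter) {
        this.currentKeySupplier = currentKeySupplier;
        this.currentKeySetter = currentKeySetter;
    }

    void executeInKeyContext(Runnable runnable, Object key) {
        final Object oldKey = currentKeySupplier.get(); // remember the current key context
        currentKeySetter.accept(key);                   // switch to the target partition
        try {
            runnable.run();                             // run the callback under that key
        } finally {
            currentKeySetter.accept(oldKey);            // restore the previous key context
        }
    }
}
```
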
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java
index 9b60437..1a72476 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContext.java
@@ -18,25 +18,69 @@
 
 package org.apache.flink.datastream.impl.context;
 
+import org.apache.flink.datastream.api.common.Collector;
 import org.apache.flink.datastream.api.context.JobInfo;
 import org.apache.flink.datastream.api.context.TaskInfo;
 import org.apache.flink.datastream.api.context.TwoOutputNonPartitionedContext;
 import org.apache.flink.datastream.api.function.TwoOutputApplyPartitionFunction;
 import org.apache.flink.metrics.MetricGroup;
 
+import java.util.Set;
+
 /** The default implementation of {@link TwoOutputNonPartitionedContext}. */
 public class DefaultTwoOutputNonPartitionedContext<OUT1, OUT2>
         implements TwoOutputNonPartitionedContext<OUT1, OUT2> {
-    private final DefaultRuntimeContext context;
+    protected final DefaultRuntimeContext context;
 
-    public DefaultTwoOutputNonPartitionedContext(DefaultRuntimeContext context) {
+    private final DefaultPartitionedContext partitionedContext;
+
+    protected final Collector<OUT1> firstCollector;
+
+    protected final Collector<OUT2> secondCollector;
+
+    private final boolean isKeyed;
+
+    private final Set<Object> keySet;
+
+    public DefaultTwoOutputNonPartitionedContext(
+            DefaultRuntimeContext context,
+            DefaultPartitionedContext partitionedContext,
+            Collector<OUT1> firstCollector,
+            Collector<OUT2> secondCollector,
+            boolean isKeyed,
+            Set<Object> keySet) {
         this.context = context;
+        this.partitionedContext = partitionedContext;
+        this.firstCollector = firstCollector;
+        this.secondCollector = secondCollector;
+        this.isKeyed = isKeyed;
+        this.keySet = keySet;
     }
 
     @Override
     public void applyToAllPartitions(
-            TwoOutputApplyPartitionFunction<OUT1, OUT2> applyPartitionFunction) {
-        // TODO implements this method.
+            TwoOutputApplyPartitionFunction<OUT1, OUT2> applyPartitionFunction) throws Exception {
+        if (isKeyed) {
+            for (Object key : keySet) {
+                partitionedContext
+                        .getStateManager()
+                        .executeInKeyContext(
+                                () -> {
+                                    try {
+                                        applyPartitionFunction.apply(
+                                                firstCollector,
+                                                secondCollector,
+                                                partitionedContext);
+                                    } catch (Exception e) {
+                                        throw new RuntimeException(e);
+                                    }
+                                },
+                                key);
+            }
+        } else {
+            // A non-keyed operator has only one partition.
+            applyPartitionFunction.apply(firstCollector, secondCollector, partitionedContext);
+        }
     }
 
     @Override
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java
index 1b729f8..9f9c0bb 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperator.java
@@ -19,26 +19,36 @@
 package org.apache.flink.datastream.impl.operators;
 
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.function.OneInputStreamProcessFunction;
 import org.apache.flink.datastream.api.stream.KeyedPartitionStream;
 import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector;
 import org.apache.flink.datastream.impl.common.OutputCollector;
 import org.apache.flink.datastream.impl.common.TimestampCollector;
+import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext;
 import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 import javax.annotation.Nullable;
 
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
 /** Operator for {@link OneInputStreamProcessFunction} in {@link KeyedPartitionStream}. */
 public class KeyedProcessOperator<KEY, IN, OUT> extends ProcessOperator<IN, OUT>
         implements Triggerable<KEY, VoidNamespace> {
     private transient InternalTimerService<VoidNamespace> timerService;
 
+    private transient Set<Object> keySet;
+
     @Nullable private final KeySelector<OUT, KEY> outKeySelector;
 
     public KeyedProcessOperator(OneInputStreamProcessFunction<IN, OUT> userFunction) {
@@ -56,6 +66,7 @@
     public void open() throws Exception {
         this.timerService =
                 getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this);
+        this.keySet = new HashSet<>();
         super.open();
     }
 
@@ -95,4 +106,24 @@
     protected ProcessingTimeManager getProcessingTimeManager() {
         return new DefaultProcessingTimeManager(timerService);
     }
+
+    @Override
+    protected NonPartitionedContext<OUT> getNonPartitionedContext() {
+        return new DefaultNonPartitionedContext<>(
+                context, partitionedContext, outputCollector, true, keySet);
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void setKeyContextElement1(StreamRecord record) throws Exception {
+        setKeyContextElement(record, getStateKeySelector1());
+    }
+
+    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
+            throws Exception {
+        checkNotNull(selector);
+        Object key = selector.getKey(record.getValue());
+        setCurrentKey(key);
+        keySet.add(key);
+    }
 }
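
Why the `keySet` bookkeeping above is complete: the stream task calls `setKeyContextElement1` for every incoming record before invoking `processElement`, so every key this subtask ever sees is recorded before the user function runs. A simplified sketch of that driving sequence (the actual call site lives in the runtime's record-emission path; this is an assumption-level illustration, not Flink runtime code):

```java
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;

final class RecordEmissionSketch {
    // Sketch: how the runtime is assumed to drive a one-input operator per record.
    static <IN, OUT> void emit(
            StreamRecord<IN> record, OneInputStreamOperator<IN, OUT> operator) throws Exception {
        operator.setKeyContextElement1(record); // extracts the key and adds it to keySet (above)
        operator.processElement(record);        // user processing runs with the key context set
    }
}
```
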
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java
index d303d0c..d46da49 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperator.java
@@ -19,27 +19,35 @@
 package org.apache.flink.datastream.impl.operators;
 
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.function.TwoInputBroadcastStreamProcessFunction;
 import org.apache.flink.datastream.api.stream.KeyedPartitionStream;
 import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector;
 import org.apache.flink.datastream.impl.common.OutputCollector;
 import org.apache.flink.datastream.impl.common.TimestampCollector;
+import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext;
 import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 import javax.annotation.Nullable;
 
+import java.util.HashSet;
+import java.util.Set;
+
 /** Operator for {@link TwoInputBroadcastStreamProcessFunction} in {@link KeyedPartitionStream}. */
 public class KeyedTwoInputBroadcastProcessOperator<KEY, IN1, IN2, OUT>
         extends TwoInputBroadcastProcessOperator<IN1, IN2, OUT>
         implements Triggerable<KEY, VoidNamespace> {
     private transient InternalTimerService<VoidNamespace> timerService;
 
+    private transient Set<Object> keySet;
+
     @Nullable private final KeySelector<OUT, KEY> outKeySelector;
 
     public KeyedTwoInputBroadcastProcessOperator(
@@ -58,6 +66,7 @@
     public void open() throws Exception {
         this.timerService =
                 getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this);
+        this.keySet = new HashSet<>();
         super.open();
     }
 
@@ -96,4 +105,27 @@
                                         partitionedContext),
                         timer.getKey());
     }
+
+    @Override
+    protected NonPartitionedContext<OUT> getNonPartitionedContext() {
+        return new DefaultNonPartitionedContext<>(
+                context, partitionedContext, collector, true, keySet);
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    // Only elements from input1 update the key context, as the other input is the broadcast side.
+    public void setKeyContextElement1(StreamRecord record) throws Exception {
+        setKeyContextElement(record, getStateKeySelector1());
+    }
+
+    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
+            throws Exception {
+        if (selector == null) {
+            return;
+        }
+        Object key = selector.getKey(record.getValue());
+        setCurrentKey(key);
+        keySet.add(key);
+    }
 }
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java
index 36ef958..d646c2b 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperator.java
@@ -19,21 +19,27 @@
 package org.apache.flink.datastream.impl.operators;
 
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.function.TwoInputNonBroadcastStreamProcessFunction;
 import org.apache.flink.datastream.api.stream.KeyedPartitionStream;
 import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector;
 import org.apache.flink.datastream.impl.common.OutputCollector;
 import org.apache.flink.datastream.impl.common.TimestampCollector;
+import org.apache.flink.datastream.impl.context.DefaultNonPartitionedContext;
 import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 import javax.annotation.Nullable;
 
+import java.util.HashSet;
+import java.util.Set;
+
 /**
  * Operator for {@link TwoInputNonBroadcastStreamProcessFunction} in {@link KeyedPartitionStream}.
  */
@@ -42,6 +48,8 @@
         implements Triggerable<KEY, VoidNamespace> {
     private transient InternalTimerService<VoidNamespace> timerService;
 
+    private transient Set<Object> keySet;
+
     @Nullable private final KeySelector<OUT, KEY> outKeySelector;
 
     public KeyedTwoInputNonBroadcastProcessOperator(
@@ -60,6 +68,7 @@
     public void open() throws Exception {
         this.timerService =
                 getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this);
+        this.keySet = new HashSet<>();
         super.open();
     }
 
@@ -98,4 +107,32 @@
                                         partitionedContext),
                         timer.getKey());
     }
+
+    @Override
+    protected NonPartitionedContext<OUT> getNonPartitionedContext() {
+        return new DefaultNonPartitionedContext<>(
+                context, partitionedContext, collector, true, keySet);
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void setKeyContextElement1(StreamRecord record) throws Exception {
+        setKeyContextElement(record, getStateKeySelector1());
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void setKeyContextElement2(StreamRecord record) throws Exception {
+        setKeyContextElement(record, getStateKeySelector2());
+    }
+
+    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
+            throws Exception {
+        if (selector == null) {
+            return;
+        }
+        Object key = selector.getKey(record.getValue());
+        setCurrentKey(key);
+        keySet.add(key);
+    }
 }
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java
index aa7de64..61035d1 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperator.java
@@ -20,27 +20,37 @@
 
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
+import org.apache.flink.datastream.api.context.TwoOutputNonPartitionedContext;
 import org.apache.flink.datastream.api.function.TwoOutputStreamProcessFunction;
 import org.apache.flink.datastream.impl.common.KeyCheckedOutputCollector;
 import org.apache.flink.datastream.impl.common.OutputCollector;
 import org.apache.flink.datastream.impl.common.TimestampCollector;
 import org.apache.flink.datastream.impl.context.DefaultProcessingTimeManager;
+import org.apache.flink.datastream.impl.context.DefaultTwoOutputNonPartitionedContext;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.operators.InternalTimerService;
 import org.apache.flink.streaming.api.operators.Triggerable;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nullable;
 
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
-/** */
+/** Operator for {@link TwoOutputStreamProcessFunction} in {@code KeyedPartitionStream}. */
 public class KeyedTwoOutputProcessOperator<KEY, IN, OUT_MAIN, OUT_SIDE>
         extends TwoOutputProcessOperator<IN, OUT_MAIN, OUT_SIDE>
         implements Triggerable<KEY, VoidNamespace> {
     private transient InternalTimerService<VoidNamespace> timerService;
 
+    private transient Set<Object> keySet;
+
     @Nullable private final KeySelector<OUT_MAIN, KEY> mainOutKeySelector;
 
     @Nullable private final KeySelector<OUT_SIDE, KEY> sideOutKeySelector;
@@ -69,6 +79,7 @@
     public void open() throws Exception {
         this.timerService =
                 getInternalTimerService("processing timer", VoidNamespaceSerializer.INSTANCE, this);
+        this.keySet = new HashSet<>();
         super.open();
     }
 
@@ -120,4 +131,24 @@
                                         partitionedContext),
                         timer.getKey());
     }
+
+    @Override
+    protected TwoOutputNonPartitionedContext<OUT_MAIN, OUT_SIDE> getNonPartitionedContext() {
+        return new DefaultTwoOutputNonPartitionedContext<>(
+                context, partitionedContext, mainCollector, sideCollector, true, keySet);
+    }
+
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public void setKeyContextElement1(StreamRecord record) throws Exception {
+        setKeyContextElement(record, getStateKeySelector1());
+    }
+
+    private <T> void setKeyContextElement(StreamRecord<T> record, KeySelector<T, ?> selector)
+            throws Exception {
+        checkNotNull(selector);
+        Object key = selector.getKey(record.getValue());
+        setCurrentKey(key);
+        keySet.add(key);
+    }
 }
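
At the API surface, this lets a keyed two-output function flush both outputs once per key when its input ends. A condensed sketch of the `endInput` hook, mirroring the updated `KeyedTwoOutputProcessOperatorTest` below (surrounding function definition elided):

```java
@Override
public void endInput(TwoOutputNonPartitionedContext<Integer, Long> ctx) {
    try {
        // Invoked once per key recorded by the operator; the key context is
        // installed before each invocation, so keyed state is accessible.
        ctx.applyToAllPartitions(
                (mainOut, sideOut, context) -> {
                    Integer currentKey = context.getStateManager().getCurrentKey();
                    mainOut.collect(currentKey);
                    sideOut.collect(Long.valueOf(currentKey));
                });
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
```
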
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java
index 9ea6996..f46488d 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/ProcessOperator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.datastream.impl.operators;
 
 import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.function.OneInputStreamProcessFunction;
 import org.apache.flink.datastream.impl.common.OutputCollector;
@@ -43,7 +44,7 @@
 
     protected transient DefaultPartitionedContext partitionedContext;
 
-    protected transient DefaultNonPartitionedContext<OUT> nonPartitionedContext;
+    protected transient NonPartitionedContext<OUT> nonPartitionedContext;
 
     protected transient TimestampCollector<OUT> outputCollector;
 
@@ -67,8 +68,8 @@
         partitionedContext =
                 new DefaultPartitionedContext(
                         context, this::currentKey, this::setCurrentKey, getProcessingTimeManager());
-        nonPartitionedContext = new DefaultNonPartitionedContext<>(context);
         outputCollector = getOutputCollector();
+        nonPartitionedContext = getNonPartitionedContext();
     }
 
     @Override
@@ -93,4 +94,9 @@
     protected ProcessingTimeManager getProcessingTimeManager() {
         return UnsupportedProcessingTimeManager.INSTANCE;
     }
+
+    protected NonPartitionedContext<OUT> getNonPartitionedContext() {
+        return new DefaultNonPartitionedContext<>(
+                context, partitionedContext, outputCollector, false, null);
+    }
 }
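
For the non-keyed default above, `applyToAllPartitions` simply applies the function once, since the whole subtask is a single partition. Condensed from the updated `ProcessOperatorTest` below, this is all an `endInput` implementation needs in order to emit an end-of-stream marker:

```java
@Override
public void endInput(NonPartitionedContext<String> ctx) {
    try {
        // Non-keyed operator: exactly one partition, so the callback runs once.
        ctx.applyToAllPartitions((out, context) -> out.collect("end"));
    } catch (Exception e) {
        throw new RuntimeException(e);
    }
}
```
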
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java
index a11b2b7..4030002 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.datastream.impl.operators;
 
 import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.function.TwoInputBroadcastStreamProcessFunction;
 import org.apache.flink.datastream.impl.common.OutputCollector;
@@ -48,7 +49,7 @@
 
     protected transient DefaultPartitionedContext partitionedContext;
 
-    protected transient DefaultNonPartitionedContext<OUT> nonPartitionedContext;
+    protected transient NonPartitionedContext<OUT> nonPartitionedContext;
 
     public TwoInputBroadcastProcessOperator(
             TwoInputBroadcastStreamProcessFunction<IN1, IN2, OUT> userFunction) {
@@ -71,7 +72,7 @@
         this.partitionedContext =
                 new DefaultPartitionedContext(
                         context, this::currentKey, this::setCurrentKey, getProcessingTimeManager());
-        this.nonPartitionedContext = new DefaultNonPartitionedContext<>(context);
+        this.nonPartitionedContext = getNonPartitionedContext();
     }
 
     @Override
@@ -91,6 +92,11 @@
         return new OutputCollector<>(output);
     }
 
+    protected NonPartitionedContext<OUT> getNonPartitionedContext() {
+        return new DefaultNonPartitionedContext<>(
+                context, partitionedContext, collector, false, null);
+    }
+
     @Override
     public void endInput(int inputId) throws Exception {
         // sanity check.
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java
index 56488ef..7918203 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.datastream.impl.operators;
 
 import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.datastream.api.context.NonPartitionedContext;
 import org.apache.flink.datastream.api.context.ProcessingTimeManager;
 import org.apache.flink.datastream.api.function.TwoInputNonBroadcastStreamProcessFunction;
 import org.apache.flink.datastream.impl.common.OutputCollector;
@@ -48,7 +49,7 @@
 
     protected transient DefaultPartitionedContext partitionedContext;
 
-    protected transient DefaultNonPartitionedContext<OUT> nonPartitionedContext;
+    protected transient NonPartitionedContext<OUT> nonPartitionedContext;
 
     public TwoInputNonBroadcastProcessOperator(
             TwoInputNonBroadcastStreamProcessFunction<IN1, IN2, OUT> userFunction) {
@@ -71,7 +72,7 @@
         this.partitionedContext =
                 new DefaultPartitionedContext(
                         context, this::currentKey, this::setCurrentKey, getProcessingTimeManager());
-        this.nonPartitionedContext = new DefaultNonPartitionedContext<>(context);
+        this.nonPartitionedContext = getNonPartitionedContext();
     }
 
     @Override
@@ -91,6 +92,11 @@
         return new OutputCollector<>(output);
     }
 
+    protected NonPartitionedContext<OUT> getNonPartitionedContext() {
+        return new DefaultNonPartitionedContext<>(
+                context, partitionedContext, collector, false, null);
+    }
+
     @Override
     public void endInput(int inputId) throws Exception {
         // sanity check.
diff --git a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java
index 50b2ad6..e800001 100644
--- a/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java
+++ b/flink-datastream/src/main/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperator.java
@@ -82,7 +82,7 @@
         this.partitionedContext =
                 new DefaultPartitionedContext(
                         context, this::currentKey, this::setCurrentKey, getProcessingTimeManager());
-        this.nonPartitionedContext = new DefaultTwoOutputNonPartitionedContext<>(context);
+        this.nonPartitionedContext = getNonPartitionedContext();
     }
 
     @Override
@@ -110,6 +110,11 @@
         throw new UnsupportedOperationException("The key is only defined for keyed operator");
     }
 
+    protected TwoOutputNonPartitionedContext<OUT_MAIN, OUT_SIDE> getNonPartitionedContext() {
+        return new DefaultTwoOutputNonPartitionedContext<>(
+                context, partitionedContext, mainCollector, sideCollector, false, null);
+    }
+
     protected ProcessingTimeManager getProcessingTimeManager() {
         return UnsupportedProcessingTimeManager.INSTANCE;
     }
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java
new file mode 100644
index 0000000..5c8311e
--- /dev/null
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/ContextTestUtils.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.impl.context;
+
+import org.apache.flink.runtime.jobgraph.JobType;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
+import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
+import org.apache.flink.streaming.util.MockStreamingRuntimeContext;
+
+/** Test utilities for the context implementations. */
+public final class ContextTestUtils {
+    public static StreamingRuntimeContext createStreamingRuntimeContext() {
+        return new MockStreamingRuntimeContext(
+                false,
+                2,
+                1,
+                new MockEnvironmentBuilder()
+                        .setTaskName("mockTask")
+                        .setManagedMemorySize(4 * MemoryManager.DEFAULT_PAGE_SIZE)
+                        .setParallelism(2)
+                        .setMaxParallelism(2)
+                        .setSubtaskIndex(1)
+                        .setJobType(JobType.STREAMING)
+                        .setJobName("mockJob")
+                        .build());
+    }
+}
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java
new file mode 100644
index 0000000..1376277
--- /dev/null
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultNonPartitionedContextTest.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.impl.context;
+
+import org.apache.flink.datastream.impl.common.TestingTimestampCollector;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link DefaultNonPartitionedContext}. */
+class DefaultNonPartitionedContextTest {
+    @Test
+    void testApplyToAllPartitions() throws Exception {
+        AtomicInteger counter = new AtomicInteger(0);
+        List<Integer> collectedData = new ArrayList<>();
+
+        TestingTimestampCollector<Integer> collector =
+                TestingTimestampCollector.<Integer>builder()
+                        .setCollectConsumer(collectedData::add)
+                        .build();
+        CompletableFuture<Void> cf = new CompletableFuture<>();
+        DefaultRuntimeContext runtimeContext =
+                new DefaultRuntimeContext(
+                        ContextTestUtils.createStreamingRuntimeContext(), 1, 2, "mock-task");
+        DefaultNonPartitionedContext<Integer> nonPartitionedContext =
+                new DefaultNonPartitionedContext<>(
+                        runtimeContext,
+                        new DefaultPartitionedContext(
+                                runtimeContext,
+                                Optional::empty,
+                                (key) -> cf.complete(null),
+                                UnsupportedProcessingTimeManager.INSTANCE),
+                        collector,
+                        false,
+                        null);
+        nonPartitionedContext.applyToAllPartitions(
+                (out, ctx) -> {
+                    counter.incrementAndGet();
+                    out.collect(10);
+                });
+        assertThat(counter.get()).isEqualTo(1);
+        assertThat(cf).isNotCompleted();
+        assertThat(collectedData).containsExactly(10);
+    }
+
+    @Test
+    void testKeyedApplyToAllPartitions() throws Exception {
+        AtomicInteger counter = new AtomicInteger(0);
+        List<Integer> collectedData = new ArrayList<>();
+
+        TestingTimestampCollector<Integer> collector =
+                TestingTimestampCollector.<Integer>builder()
+                        .setCollectConsumer(collectedData::add)
+                        .build();
+        // register all keys that applyToAllPartitions should visit
+        Set<Object> allKeys = new HashSet<>();
+        allKeys.add(1);
+        allKeys.add(2);
+        allKeys.add(3);
+
+        AtomicInteger currentKey = new AtomicInteger(-1);
+        DefaultRuntimeContext runtimeContext =
+                new DefaultRuntimeContext(
+                        ContextTestUtils.createStreamingRuntimeContext(), 1, 2, "mock-task");
+        DefaultNonPartitionedContext<Integer> nonPartitionedContext =
+                new DefaultNonPartitionedContext<>(
+                        runtimeContext,
+                        new DefaultPartitionedContext(
+                                runtimeContext,
+                                currentKey::get,
+                                (key) -> currentKey.set((Integer) key),
+                                UnsupportedProcessingTimeManager.INSTANCE),
+                        collector,
+                        true,
+                        allKeys);
+        nonPartitionedContext.applyToAllPartitions(
+                (out, ctx) -> {
+                    counter.incrementAndGet();
+                    Integer key = ctx.getStateManager().getCurrentKey();
+                    assertThat(key).isIn(allKeys);
+                    out.collect(key);
+                });
+        assertThat(counter.get()).isEqualTo(allKeys.size());
+        assertThat(collectedData).containsExactlyInAnyOrder(1, 2, 3);
+    }
+}
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java
new file mode 100644
index 0000000..790ca7b
--- /dev/null
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/context/DefaultTwoOutputNonPartitionedContextTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.datastream.impl.context;
+
+import org.apache.flink.datastream.impl.common.TestingTimestampCollector;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link DefaultTwoOutputNonPartitionedContext}. */
+class DefaultTwoOutputNonPartitionedContextTest {
+    @Test
+    void testApplyToAllPartitions() throws Exception {
+        AtomicInteger counter = new AtomicInteger(0);
+        List<Integer> collectedFromFirstOutput = new ArrayList<>();
+        List<Long> collectedFromSecondOutput = new ArrayList<>();
+
+        TestingTimestampCollector<Integer> firstCollector =
+                TestingTimestampCollector.<Integer>builder()
+                        .setCollectConsumer(collectedFromFirstOutput::add)
+                        .build();
+        TestingTimestampCollector<Long> secondCollector =
+                TestingTimestampCollector.<Long>builder()
+                        .setCollectConsumer(collectedFromSecondOutput::add)
+                        .build();
+        CompletableFuture<Void> cf = new CompletableFuture<>();
+        DefaultRuntimeContext runtimeContext =
+                new DefaultRuntimeContext(
+                        ContextTestUtils.createStreamingRuntimeContext(), 1, 2, "mock-task");
+        DefaultTwoOutputNonPartitionedContext<Integer, Long> nonPartitionedContext =
+                new DefaultTwoOutputNonPartitionedContext<>(
+                        runtimeContext,
+                        new DefaultPartitionedContext(
+                                runtimeContext,
+                                Optional::empty,
+                                (key) -> cf.complete(null),
+                                UnsupportedProcessingTimeManager.INSTANCE),
+                        firstCollector,
+                        secondCollector,
+                        false,
+                        null);
+        nonPartitionedContext.applyToAllPartitions(
+                (firstOutput, secondOutput, ctx) -> {
+                    counter.incrementAndGet();
+                    firstOutput.collect(10);
+                    secondOutput.collect(20L);
+                });
+        assertThat(counter.get()).isEqualTo(1);
+        assertThat(cf).isNotCompleted();
+        assertThat(collectedFromFirstOutput).containsExactly(10);
+        assertThat(collectedFromSecondOutput).containsExactly(20L);
+    }
+
+    @Test
+    void testKeyedApplyToAllPartitions() throws Exception {
+        AtomicInteger counter = new AtomicInteger(0);
+        List<Integer> collectedFromFirstOutput = new ArrayList<>();
+        List<Long> collectedFromSecondOutput = new ArrayList<>();
+
+        TestingTimestampCollector<Integer> firstCollector =
+                TestingTimestampCollector.<Integer>builder()
+                        .setCollectConsumer(collectedFromFirstOutput::add)
+                        .build();
+        TestingTimestampCollector<Long> secondCollector =
+                TestingTimestampCollector.<Long>builder()
+                        .setCollectConsumer(collectedFromSecondOutput::add)
+                        .build();
+        // register all keys that applyToAllPartitions should visit
+        Set<Object> allKeys = new HashSet<>();
+        allKeys.add(1);
+        allKeys.add(2);
+        allKeys.add(3);
+
+        AtomicInteger currentKey = new AtomicInteger(-1);
+        DefaultRuntimeContext runtimeContext =
+                new DefaultRuntimeContext(
+                        ContextTestUtils.createStreamingRuntimeContext(), 1, 2, "mock-task");
+        DefaultTwoOutputNonPartitionedContext<Integer, Long> nonPartitionedContext =
+                new DefaultTwoOutputNonPartitionedContext<>(
+                        runtimeContext,
+                        new DefaultPartitionedContext(
+                                runtimeContext,
+                                currentKey::get,
+                                (key) -> currentKey.set((Integer) key),
+                                UnsupportedProcessingTimeManager.INSTANCE),
+                        firstCollector,
+                        secondCollector,
+                        true,
+                        allKeys);
+        nonPartitionedContext.applyToAllPartitions(
+                (firstOut, secondOut, ctx) -> {
+                    counter.incrementAndGet();
+                    Integer key = ctx.getStateManager().getCurrentKey();
+                    assertThat(key).isIn(allKeys);
+                    firstOut.collect(key);
+                    secondOut.collect(Long.valueOf(key));
+                });
+        assertThat(counter.get()).isEqualTo(allKeys.size());
+        assertThat(collectedFromFirstOutput).containsExactlyInAnyOrder(1, 2, 3);
+        assertThat(collectedFromSecondOutput).containsExactlyInAnyOrder(1L, 2L, 3L);
+    }
+}
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java
index 163ba10..d744ba0 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedProcessOperatorTest.java
@@ -30,7 +30,7 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -70,7 +70,7 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> future = new CompletableFuture<>();
+        AtomicInteger counter = new AtomicInteger();
         KeyedProcessOperator<Integer, Integer, Integer> processOperator =
                 new KeyedProcessOperator<>(
                         new OneInputStreamProcessFunction<Integer, Integer>() {
@@ -84,7 +84,17 @@
 
                             @Override
                             public void endInput(NonPartitionedContext<Integer> ctx) {
-                                future.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                counter.incrementAndGet();
+                                                Integer currentKey =
+                                                        context.getStateManager().getCurrentKey();
+                                                out.collect(currentKey);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         });
 
@@ -94,8 +104,15 @@
                         (KeySelector<Integer, Integer>) value -> value,
                         Types.INT)) {
             testHarness.open();
+            testHarness.processElement(new StreamRecord<>(1)); // key is 1
+            testHarness.processElement(new StreamRecord<>(2)); // key is 2
+            testHarness.processElement(new StreamRecord<>(3)); // key is 3
             testHarness.endInput();
-            assertThat(future).isCompleted();
+            assertThat(counter).hasValue(3);
+            Collection<StreamRecord<Integer>> recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput)
+                    .containsExactly(
+                            new StreamRecord<>(1), new StreamRecord<>(2), new StreamRecord<>(3));
         }
     }
 
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java
index c56a71c..7ad6070 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputBroadcastProcessOperatorTest.java
@@ -30,8 +30,9 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -79,8 +80,8 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> nonBroadcastInputEnd = new CompletableFuture<>();
-        CompletableFuture<Void> broadcastInputEnd = new CompletableFuture<>();
+        AtomicInteger nonBroadcastInputCounter = new AtomicInteger();
+        AtomicInteger broadcastInputCounter = new AtomicInteger();
         KeyedTwoInputBroadcastProcessOperator<Long, Integer, Long, Long> processOperator =
                 new KeyedTwoInputBroadcastProcessOperator<>(
                         new TwoInputBroadcastStreamProcessFunction<Integer, Long, Long>() {
@@ -100,12 +101,32 @@
 
                             @Override
                             public void endNonBroadcastInput(NonPartitionedContext<Long> ctx) {
-                                nonBroadcastInputEnd.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                nonBroadcastInputCounter.incrementAndGet();
+                                                Long currentKey =
+                                                        context.getStateManager().getCurrentKey();
+                                                out.collect(currentKey);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
 
                             @Override
                             public void endBroadcastInput(NonPartitionedContext<Long> ctx) {
-                                broadcastInputEnd.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                broadcastInputCounter.incrementAndGet();
+                                                Long currentKey =
+                                                        context.getStateManager().getCurrentKey();
+                                                out.collect(currentKey);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         });
 
@@ -116,10 +137,18 @@
                         (KeySelector<Long, Long>) value -> value,
                         Types.LONG)) {
             testHarness.open();
+            testHarness.processElement1(new StreamRecord<>(1)); // key is 1L
+            testHarness.processElement2(new StreamRecord<>(2L)); // broadcast input is not keyed
             testHarness.endInput1();
-            assertThat(nonBroadcastInputEnd).isCompleted();
+            assertThat(nonBroadcastInputCounter).hasValue(1);
+            Collection<StreamRecord<Long>> recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput).containsExactly(new StreamRecord<>(1L));
+            testHarness.processElement2(new StreamRecord<>(3L)); // broadcast input is not keyed
             testHarness.endInput2();
-            assertThat(broadcastInputEnd).isCompleted();
+            assertThat(broadcastInputCounter).hasValue(1);
+            recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput)
+                    .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(1L));
         }
     }
 
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java
index 324e91e..75d9b61 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoInputNonBroadcastProcessOperatorTest.java
@@ -30,7 +30,7 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -80,8 +80,8 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> firstFuture = new CompletableFuture<>();
-        CompletableFuture<Void> secondFuture = new CompletableFuture<>();
+        AtomicInteger firstInputCounter = new AtomicInteger();
+        AtomicInteger secondInputCounter = new AtomicInteger();
         KeyedTwoInputNonBroadcastProcessOperator<Long, Integer, Long, Long> processOperator =
                 new KeyedTwoInputNonBroadcastProcessOperator<>(
                         new TwoInputNonBroadcastStreamProcessFunction<Integer, Long, Long>() {
@@ -101,12 +101,32 @@
 
                             @Override
                             public void endFirstInput(NonPartitionedContext<Long> ctx) {
-                                firstFuture.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                firstInputCounter.incrementAndGet();
+                                                Long currentKey =
+                                                        context.getStateManager().getCurrentKey();
+                                                out.collect(currentKey);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
 
                             @Override
                             public void endSecondInput(NonPartitionedContext<Long> ctx) {
-                                secondFuture.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                secondInputCounter.incrementAndGet();
+                                                Long currentKey =
+                                                        context.getStateManager().getCurrentKey();
+                                                out.collect(currentKey);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         });
 
@@ -117,10 +137,21 @@
                         (KeySelector<Long, Long>) value -> value,
                         Types.LONG)) {
             testHarness.open();
+            testHarness.processElement1(new StreamRecord<>(1)); // key is 1L
+            testHarness.processElement2(new StreamRecord<>(2L)); // key is 2L
             testHarness.endInput1();
-            assertThat(firstFuture).isCompleted();
+            assertThat(firstInputCounter).hasValue(2);
+            Collection<StreamRecord<Long>> recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput)
+                    .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L));
+            testHarness.processElement2(new StreamRecord<>(3L)); // key is 3L
+            testHarness.getOutput().clear();
             testHarness.endInput2();
-            assertThat(secondFuture).isCompleted();
+            assertThat(secondInputCounter).hasValue(3);
+            recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput)
+                    .containsExactly(
+                            new StreamRecord<>(1L), new StreamRecord<>(2L), new StreamRecord<>(3L));
         }
     }
 
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java
index b40a0e1..875898d 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/KeyedTwoOutputProcessOperatorTest.java
@@ -31,9 +31,9 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -82,7 +82,7 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> future = new CompletableFuture<>();
+        AtomicInteger counter = new AtomicInteger();
         OutputTag<Long> sideOutputTag = new OutputTag<Long>("side-output") {};
 
         KeyedTwoOutputProcessOperator<Integer, Integer, Integer, Long> processOperator =
@@ -100,7 +100,18 @@
                             @Override
                             public void endInput(
                                     TwoOutputNonPartitionedContext<Integer, Long> ctx) {
-                                future.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (firstOutput, secondOutput, context) -> {
+                                                counter.incrementAndGet();
+                                                Integer currentKey =
+                                                        context.getStateManager().getCurrentKey();
+                                                firstOutput.collect(currentKey);
+                                                secondOutput.collect(Long.valueOf(currentKey));
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         },
                         sideOutputTag);
@@ -111,8 +122,16 @@
                         (KeySelector<Integer, Integer>) value -> value,
                         Types.INT)) {
             testHarness.open();
+            testHarness.processElement(new StreamRecord<>(1)); // key is 1
+            testHarness.processElement(new StreamRecord<>(2)); // key is 2
             testHarness.endInput();
-            assertThat(future).isCompleted();
+            assertThat(counter).hasValue(2);
+            Collection<StreamRecord<Integer>> firstOutput = testHarness.getRecordOutput();
+            ConcurrentLinkedQueue<StreamRecord<Long>> secondOutput =
+                    testHarness.getSideOutput(sideOutputTag);
+            assertThat(firstOutput).containsExactly(new StreamRecord<>(1), new StreamRecord<>(2));
+            assertThat(secondOutput)
+                    .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L));
         }
     }
 
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java
index e33b92b..fe4419b 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/ProcessOperatorTest.java
@@ -28,7 +28,7 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -57,7 +57,7 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> future = new CompletableFuture<>();
+        AtomicInteger counter = new AtomicInteger();
         ProcessOperator<Integer, String> processOperator =
                 new ProcessOperator<>(
                         new OneInputStreamProcessFunction<Integer, String>() {
@@ -71,7 +71,17 @@
 
                             @Override
                             public void endInput(NonPartitionedContext<String> ctx) {
-                                future.complete(null);
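+                                // Non-keyed operator: there is a single partition, so the
+                                // function is invoked exactly once.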
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                counter.incrementAndGet();
+                                                out.collect("end");
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         });
 
@@ -79,7 +88,9 @@
                 new OneInputStreamOperatorTestHarness<>(processOperator)) {
             testHarness.open();
             testHarness.endInput();
-            assertThat(future).isCompleted();
+            Collection<StreamRecord<String>> recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput).containsExactly(new StreamRecord<>("end"));
+            assertThat(counter).hasValue(1);
         }
     }
 }
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java
index 6bbcd33..1e50d7f 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputBroadcastProcessOperatorTest.java
@@ -28,8 +28,9 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.List;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -73,8 +74,8 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> nonBroadcastInputEnd = new CompletableFuture<>();
-        CompletableFuture<Void> broadcastInputEnd = new CompletableFuture<>();
+        AtomicInteger nonBroadcastInputCounter = new AtomicInteger();
+        AtomicInteger broadcastInputCounter = new AtomicInteger();
         TwoInputBroadcastProcessOperator<Integer, Long, Long> processOperator =
                 new TwoInputBroadcastProcessOperator<>(
                         new TwoInputBroadcastStreamProcessFunction<Integer, Long, Long>() {
@@ -95,12 +96,30 @@
 
                             @Override
                             public void endNonBroadcastInput(NonPartitionedContext<Long> ctx) {
-                                nonBroadcastInputEnd.complete(null);
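+                                // Non-keyed operator: emit one marker record for the single
+                                // partition when the non-broadcast input ends.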
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                nonBroadcastInputCounter.incrementAndGet();
+                                                out.collect(1L);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
 
                             @Override
                             public void endBroadcastInput(NonPartitionedContext<Long> ctx) {
-                                broadcastInputEnd.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                broadcastInputCounter.incrementAndGet();
+                                                out.collect(2L);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         });
 
@@ -108,9 +125,12 @@
                 new TwoInputStreamOperatorTestHarness<>(processOperator)) {
             testHarness.open();
             testHarness.endInput1();
-            assertThat(nonBroadcastInputEnd).isCompleted();
+            assertThat(nonBroadcastInputCounter).hasValue(1);
             testHarness.endInput2();
-            assertThat(broadcastInputEnd).isCompleted();
+            assertThat(broadcastInputCounter).hasValue(1);
+            Collection<StreamRecord<Long>> recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput)
+                    .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L));
         }
     }
 }
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java
index a6c674f..e4774f3 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoInputNonBroadcastProcessOperatorTest.java
@@ -28,7 +28,7 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -75,8 +75,8 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> firstFuture = new CompletableFuture<>();
-        CompletableFuture<Void> secondFuture = new CompletableFuture<>();
+        AtomicInteger firstInputCounter = new AtomicInteger();
+        AtomicInteger secondInputCounter = new AtomicInteger();
         TwoInputNonBroadcastProcessOperator<Integer, Long, Long> processOperator =
                 new TwoInputNonBroadcastProcessOperator<>(
                         new TwoInputNonBroadcastStreamProcessFunction<Integer, Long, Long>() {
@@ -96,12 +96,30 @@
 
                             @Override
                             public void endFirstInput(NonPartitionedContext<Long> ctx) {
-                                firstFuture.complete(null);
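+                                // Non-keyed operator: emit one marker record when the first
+                                // input ends; endSecondInput below does the same.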
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                firstInputCounter.incrementAndGet();
+                                                out.collect(1L);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
 
                             @Override
                             public void endSecondInput(NonPartitionedContext<Long> ctx) {
-                                secondFuture.complete(null);
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (out, context) -> {
+                                                secondInputCounter.incrementAndGet();
+                                                out.collect(2L);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         });
 
@@ -109,9 +125,12 @@
                 new TwoInputStreamOperatorTestHarness<>(processOperator)) {
             testHarness.open();
             testHarness.endInput1();
-            assertThat(firstFuture).isCompleted();
+            assertThat(firstInputCounter).hasValue(1);
             testHarness.endInput2();
-            assertThat(secondFuture).isCompleted();
+            assertThat(secondInputCounter).hasValue(1);
+            Collection<StreamRecord<Long>> recordOutput = testHarness.getRecordOutput();
+            assertThat(recordOutput)
+                    .containsExactly(new StreamRecord<>(1L), new StreamRecord<>(2L));
         }
     }
 }
diff --git a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java
index e3c6394..273259a 100644
--- a/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java
+++ b/flink-datastream/src/test/java/org/apache/flink/datastream/impl/operators/TwoOutputProcessOperatorTest.java
@@ -29,8 +29,8 @@
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -75,7 +75,7 @@
 
     @Test
     void testEndInput() throws Exception {
-        CompletableFuture<Void> future = new CompletableFuture<>();
+        AtomicInteger counter = new AtomicInteger();
         OutputTag<Long> sideOutputTag = new OutputTag<Long>("side-output") {};
 
         TwoOutputProcessOperator<Integer, Integer, Long> processOperator =
@@ -93,7 +93,18 @@
                             @Override
                             public void endInput(
                                     TwoOutputNonPartitionedContext<Integer, Long> ctx) {
-                                future.complete(null);
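+                                // Single partition: each of the two collectors emits exactly
+                                // one record (main output and side output, respectively).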
+                                try {
+                                    ctx.applyToAllPartitions(
+                                            (firstOutput, secondOutput, context) -> {
+                                                counter.incrementAndGet();
+                                                firstOutput.collect(1);
+                                                secondOutput.collect(2L);
+                                            });
+                                } catch (Exception e) {
+                                    throw new RuntimeException(e);
+                                }
                             }
                         },
                         sideOutputTag);
@@ -102,7 +111,12 @@
                 new OneInputStreamOperatorTestHarness<>(processOperator)) {
             testHarness.open();
             testHarness.endInput();
-            assertThat(future).isCompleted();
+            assertThat(counter).hasValue(1);
+            Collection<StreamRecord<Integer>> firstOutput = testHarness.getRecordOutput();
+            ConcurrentLinkedQueue<StreamRecord<Long>> secondOutput =
+                    testHarness.getSideOutput(sideOutputTag);
+            assertThat(firstOutput).containsExactly(new StreamRecord<>(1));
+            assertThat(secondOutput).containsExactly(new StreamRecord<>(2L));
         }
     }
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
index eda7c254..ca82984 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/AbstractStreamOperator.java
@@ -531,6 +531,16 @@
         return stateHandler.getKeyedStateStore().orElse(null);
     }
 
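+    /** Returns the key selector of the first input, or null if that input is not keyed. */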
+    protected KeySelector<?, ?> getStateKeySelector1() {
+        return stateKeySelector1;
+    }
+
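+    /** Returns the key selector of the second input, or null if that input is not keyed. */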
+    protected KeySelector<?, ?> getStateKeySelector2() {
+        return stateKeySelector2;
+    }
+
     // ------------------------------------------------------------------------
     //  Context and chaining properties
     // ------------------------------------------------------------------------