[FLINK-21308][core] Support delayed message cancellation
This closes #241.
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
index 226f418..db35de6 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
@@ -68,6 +68,12 @@
public void sendAfter(Duration duration, Address address, Object o) {}
@Override
+ public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {}
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {}
+
+ @Override
public <M, T> void registerAsyncOperation(M m, CompletableFuture<T> completableFuture) {}
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
index c77adb7..eed001f 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
@@ -17,6 +17,7 @@
*/
package org.apache.flink.statefun.flink.core.functions;
+import java.util.Optional;
import java.util.OptionalLong;
import javax.annotation.Nullable;
import org.apache.flink.core.memory.DataOutputView;
@@ -93,6 +94,11 @@
}
@Override
+ public Optional<String> cancellationToken() {
+ return message.cancellationToken();
+ }
+
+ @Override
public void postApply() {
pendingAsyncOperations.remove(source(), futureId);
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
new file mode 100644
index 0000000..1dfb66b
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.functions;
+
+import java.util.Objects;
+import java.util.function.Consumer;
+import org.apache.flink.statefun.flink.core.di.Inject;
+import org.apache.flink.statefun.flink.core.di.Label;
+import org.apache.flink.statefun.flink.core.di.Lazy;
+import org.apache.flink.statefun.flink.core.message.Message;
+
+/**
+ * Handles any of the delayed message that needs to be fired at a specific timestamp. This handler
+ * dispatches {@linkplain Message}s to either remotely (shuffle) or locally.
+ */
+final class DelayMessageHandler implements Consumer<Message> {
+ private final RemoteSink remoteSink;
+ private final Lazy<Reductions> reductions;
+ private final Partition thisPartition;
+
+ @Inject
+ public DelayMessageHandler(
+ RemoteSink remoteSink,
+ @Label("reductions") Lazy<Reductions> reductions,
+ Partition partition) {
+ this.remoteSink = Objects.requireNonNull(remoteSink);
+ this.reductions = Objects.requireNonNull(reductions);
+ this.thisPartition = Objects.requireNonNull(partition);
+ }
+
+ @Override
+ public void accept(Message message) {
+ if (thisPartition.contains(message.target())) {
+ reductions.get().enqueue(message);
+ } else {
+ remoteSink.accept(message);
+ }
+ }
+
+ public void onStart() {}
+
+ public void onComplete() {
+ reductions.get().processEnvelopes();
+ }
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
index 05cc212..ddf81c9 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
@@ -18,10 +18,10 @@
package org.apache.flink.statefun.flink.core.functions;
import java.util.Objects;
+import java.util.OptionalLong;
import org.apache.flink.runtime.state.VoidNamespace;
import org.apache.flink.statefun.flink.core.di.Inject;
import org.apache.flink.statefun.flink.core.di.Label;
-import org.apache.flink.statefun.flink.core.di.Lazy;
import org.apache.flink.statefun.flink.core.message.Message;
import org.apache.flink.streaming.api.operators.InternalTimer;
import org.apache.flink.streaming.api.operators.InternalTimerService;
@@ -32,25 +32,17 @@
private final InternalTimerService<VoidNamespace> delayedMessagesTimerService;
private final DelayedMessagesBuffer delayedMessagesBuffer;
-
- private final Lazy<Reductions> reductionsSupplier;
- private final Partition thisPartition;
- private final RemoteSink remoteSink;
+ private final DelayMessageHandler delayMessageHandler;
@Inject
DelaySink(
@Label("delayed-messages-buffer") DelayedMessagesBuffer delayedMessagesBuffer,
@Label("delayed-messages-timer-service-factory")
TimerServiceFactory delayedMessagesTimerServiceFactory,
- @Label("reductions") Lazy<Reductions> reductionsSupplier,
- Partition thisPartition,
- RemoteSink remoteSink) {
+ DelayMessageHandler delayMessageHandler) {
this.delayedMessagesBuffer = Objects.requireNonNull(delayedMessagesBuffer);
- this.reductionsSupplier = Objects.requireNonNull(reductionsSupplier);
- this.thisPartition = Objects.requireNonNull(thisPartition);
- this.remoteSink = Objects.requireNonNull(remoteSink);
-
this.delayedMessagesTimerService = delayedMessagesTimerServiceFactory.createTimerService(this);
+ this.delayMessageHandler = Objects.requireNonNull(delayMessageHandler);
}
void accept(Message message, long delayMillis) {
@@ -64,35 +56,25 @@
}
@Override
- public void onProcessingTime(InternalTimer<String, VoidNamespace> timer) throws Exception {
- final long triggerTimestamp = timer.getTimestamp();
- final Reductions reductions = reductionsSupplier.get();
-
- Iterable<Message> delayedMessages = delayedMessagesBuffer.getForTimestamp(triggerTimestamp);
- if (delayedMessages == null) {
- throw new IllegalStateException(
- "A delayed message timer was triggered with timestamp "
- + triggerTimestamp
- + ", but no messages were buffered for it.");
- }
- for (Message delayedMessage : delayedMessages) {
- if (thisPartition.contains(delayedMessage.target())) {
- reductions.enqueue(delayedMessage);
- } else {
- remoteSink.accept(delayedMessage);
- }
- }
- // we clear the delayedMessageBuffer *before* we process the enqueued local reductions, because
- // processing the envelops might actually trigger a delayed message to be sent with the same
- // @triggerTimestamp
- // so it would be re-enqueued into the delayedMessageBuffer.
- delayedMessagesBuffer.clearForTimestamp(triggerTimestamp);
- reductions.processEnvelopes();
+ public void onProcessingTime(InternalTimer<String, VoidNamespace> timer) {
+ delayMessageHandler.onStart();
+ delayedMessagesBuffer.forEachMessageAt(timer.getTimestamp(), delayMessageHandler);
+ delayMessageHandler.onComplete();
}
@Override
- public void onEventTime(InternalTimer<String, VoidNamespace> timer) throws Exception {
+ public void onEventTime(InternalTimer<String, VoidNamespace> timer) {
throw new UnsupportedOperationException(
"Delayed messages with event time semantics is not supported.");
}
+
+ void removeMessageByCancellationToken(String cancellationToken) {
+ Objects.requireNonNull(cancellationToken);
+ OptionalLong timerToClear =
+ delayedMessagesBuffer.removeMessageByCancellationToken(cancellationToken);
+ if (timerToClear.isPresent()) {
+ long timestamp = timerToClear.getAsLong();
+ delayedMessagesTimerService.deleteProcessingTimeTimer(VoidNamespace.INSTANCE, timestamp);
+ }
+ }
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
index 1b68e3f..cf35389 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
@@ -17,13 +17,23 @@
*/
package org.apache.flink.statefun.flink.core.functions;
+import java.util.OptionalLong;
+import java.util.function.Consumer;
import org.apache.flink.statefun.flink.core.message.Message;
interface DelayedMessagesBuffer {
+ /** Add a message to be fired at a specific timestamp */
void add(Message message, long untilTimestamp);
- Iterable<Message> getForTimestamp(long timestamp);
+ /** Apply @fn for each delayed message that is meant to be fired at @timestamp. */
+ void forEachMessageAt(long timestamp, Consumer<Message> fn);
- void clearForTimestamp(long timestamp);
+ /**
+ * @param token a message cancellation token to delete.
+ * @return an optional timestamp that this message was meant to be fired at. The timestamp will be
+ * present only if this message was the last message registered to fire at that timestamp.
+ * (hence: safe to clear any underlying timer)
+ */
+ OptionalLong removeMessageByCancellationToken(String token);
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
index a451fd0..ac7b9a5 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
@@ -17,7 +17,14 @@
*/
package org.apache.flink.statefun.flink.core.functions;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.function.Consumer;
+import javax.annotation.Nullable;
+import org.apache.flink.api.common.state.MapState;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.statefun.flink.core.di.Inject;
import org.apache.flink.statefun.flink.core.di.Label;
@@ -26,41 +33,122 @@
final class FlinkStateDelayedMessagesBuffer implements DelayedMessagesBuffer {
static final String BUFFER_STATE_NAME = "delayed-messages-buffer";
+ static final String INDEX_STATE_NAME = "delayed-message-index";
private final InternalListState<String, Long, Message> bufferState;
+ private final MapState<String, Long> cancellationTokenToTimestamp;
@Inject
FlinkStateDelayedMessagesBuffer(
- @Label("delayed-messages-buffer-state")
- InternalListState<String, Long, Message> bufferState) {
+ @Label("delayed-messages-buffer-state") InternalListState<String, Long, Message> bufferState,
+ @Label("delayed-message-index") MapState<String, Long> cancellationTokenToTimestamp) {
this.bufferState = Objects.requireNonNull(bufferState);
+ this.cancellationTokenToTimestamp = Objects.requireNonNull(cancellationTokenToTimestamp);
+ }
+
+ @Override
+ public void forEachMessageAt(long timestamp, Consumer<Message> fn) {
+ try {
+ forEachMessageThrows(timestamp, fn);
+ } catch (Exception e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public OptionalLong removeMessageByCancellationToken(String token) {
+ try {
+ return remove(token);
+ } catch (Exception e) {
+ throw new IllegalStateException(
+ "Failed clearing a message with a cancellation token " + token, e);
+ }
}
@Override
public void add(Message message, long untilTimestamp) {
- bufferState.setCurrentNamespace(untilTimestamp);
try {
- bufferState.add(message);
+ addThrows(message, untilTimestamp);
} catch (Exception e) {
throw new RuntimeException("Error adding delayed message to state buffer: " + message, e);
}
}
- @Override
- public Iterable<Message> getForTimestamp(long timestamp) {
- bufferState.setCurrentNamespace(timestamp);
+ // -----------------------------------------------------------------------------------------------------
+ // Internal
+ // -----------------------------------------------------------------------------------------------------
- try {
- return bufferState.get();
- } catch (Exception e) {
- throw new RuntimeException(
- "Error accessing delayed message in state buffer for timestamp: " + timestamp, e);
+ private void forEachMessageThrows(long timestamp, Consumer<Message> fn) throws Exception {
+ bufferState.setCurrentNamespace(timestamp);
+ for (Message message : bufferState.get()) {
+ removeMessageIdMapping(message);
+ fn.accept(message);
+ }
+ bufferState.clear();
+ }
+
+ private void addThrows(Message message, long untilTimestamp) throws Exception {
+ bufferState.setCurrentNamespace(untilTimestamp);
+ bufferState.add(message);
+ Optional<String> maybeToken = message.cancellationToken();
+ if (!maybeToken.isPresent()) {
+ return;
+ }
+ String cancellationToken = maybeToken.get();
+ @Nullable Long previousTimestamp = cancellationTokenToTimestamp.get(cancellationToken);
+ if (previousTimestamp != null) {
+ throw new IllegalStateException(
+ "Trying to associate a message with cancellation token "
+ + cancellationToken
+ + " and timestamp "
+ + untilTimestamp
+ + ", but a message with the same cancellation token exists and with a timestamp "
+ + previousTimestamp);
+ }
+ cancellationTokenToTimestamp.put(cancellationToken, untilTimestamp);
+ }
+
+ private OptionalLong remove(String cancellationToken) throws Exception {
+ final @Nullable Long untilTimestamp = cancellationTokenToTimestamp.get(cancellationToken);
+ if (untilTimestamp == null) {
+ // The message associated with @cancellationToken has already been delivered, or previously
+ // removed.
+ return OptionalLong.empty();
+ }
+ cancellationTokenToTimestamp.remove(cancellationToken);
+ bufferState.setCurrentNamespace(untilTimestamp);
+ List<Message> newList = removeMessageByToken(bufferState.get(), cancellationToken);
+ if (!newList.isEmpty()) {
+ // There are more messages to process, so we indicate to the caller that
+ // they should NOT cancel the timer.
+ bufferState.update(newList);
+ return OptionalLong.empty();
+ }
+ // There are no more message to remove, we clear the buffer and indicate
+ // to our caller to remove the timer for @untilTimestamp
+ bufferState.clear();
+ return OptionalLong.of(untilTimestamp);
+ }
+
+ // ---------------------------------------------------------------------------------------------------------
+ // Helpers
+ // ---------------------------------------------------------------------------------------------------------
+
+ private void removeMessageIdMapping(Message message) throws Exception {
+ Optional<String> maybeToken = message.cancellationToken();
+ if (maybeToken.isPresent()) {
+ cancellationTokenToTimestamp.remove(maybeToken.get());
}
}
- @Override
- public void clearForTimestamp(long timestamp) {
- bufferState.setCurrentNamespace(timestamp);
- bufferState.clear();
+ private static List<Message> removeMessageByToken(Iterable<Message> messages, String token) {
+ ArrayList<Message> newList = new ArrayList<>();
+ for (Message message : messages) {
+ Optional<String> thisMessageId = message.cancellationToken();
+ if (!thisMessageId.isPresent() || !Objects.equals(thisMessageId.get(), token)) {
+ newList.add(message);
+ }
+ }
+ return newList;
}
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
index 741575b..8dcd01b 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
@@ -104,6 +104,11 @@
final ListStateDescriptor<Message> delayedMessageStateDescriptor =
new ListStateDescriptor<>(
FlinkStateDelayedMessagesBuffer.BUFFER_STATE_NAME, envelopeSerializer.duplicate());
+ final MapStateDescriptor<String, Long> delayedMessageIndexDescriptor =
+ new MapStateDescriptor<>(
+ FlinkStateDelayedMessagesBuffer.INDEX_STATE_NAME, String.class, Long.class);
+ final MapState<String, Long> delayedMessageIndex =
+ getRuntimeContext().getMapState(delayedMessageIndexDescriptor);
final MapState<Long, Message> asyncOperationState =
getRuntimeContext().getMapState(asyncOperationStateDescriptor);
@@ -130,6 +135,7 @@
new FlinkTimerServiceFactory(
super.getTimeServiceManager().orElseThrow(IllegalStateException::new)),
delayedMessagesBufferState(delayedMessageStateDescriptor),
+ delayedMessageIndex,
sideOutputs,
output,
MessageFactory.forKey(statefulFunctionsUniverse.messageFactoryKey()),
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
index 55b521f..b881a85 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
@@ -62,6 +62,7 @@
KeyedStateBackend<Object> keyedStateBackend,
TimerServiceFactory timerServiceFactory,
InternalListState<String, Long, Message> delayedMessagesBufferState,
+ MapState<String, Long> delayMessageIndex,
Map<EgressIdentifier<?>, OutputTag<Object>> sideOutputs,
Output<StreamRecord<Message>> output,
MessageFactory messageFactory,
@@ -117,6 +118,7 @@
// for delayed messages
container.add(
"delayed-messages-buffer-state", InternalListState.class, delayedMessagesBufferState);
+ container.add("delayed-message-index", MapState.class, delayMessageIndex);
container.add(
"delayed-messages-buffer",
DelayedMessagesBuffer.class,
@@ -124,6 +126,7 @@
container.add(
"delayed-messages-timer-service-factory", TimerServiceFactory.class, timerServiceFactory);
container.add(DelaySink.class);
+ container.add(DelayMessageHandler.class);
// lazy providers for the sinks
container.add("function-group", new Lazy<>(LocalFunctionGroup.class));
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
index 77db7dc..e1a0c87 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
@@ -109,6 +109,23 @@
}
@Override
+ public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
+ Objects.requireNonNull(delay);
+ Objects.requireNonNull(to);
+ Objects.requireNonNull(message);
+ Objects.requireNonNull(cancellationToken);
+
+ Message envelope = messageFactory.from(self(), to, message, cancellationToken);
+ delaySink.accept(envelope, delay.toMillis());
+ }
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {
+ Objects.requireNonNull(cancellationToken);
+ delaySink.removeMessageByCancellationToken(cancellationToken);
+ }
+
+ @Override
public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {
Objects.requireNonNull(metadata);
Objects.requireNonNull(future);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
index 2278fa5..be10e3f 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
@@ -18,6 +18,7 @@
package org.apache.flink.statefun.flink.core.message;
import java.io.IOException;
+import java.util.Optional;
import java.util.OptionalLong;
import org.apache.flink.core.memory.DataOutputView;
@@ -35,6 +36,8 @@
*/
OptionalLong isBarrierMessage();
+ Optional<String> cancellationToken();
+
Message copy(MessageFactory context);
void writeTo(MessageFactory context, DataOutputView target) throws IOException;
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
index e4d2d3a..780415e 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
@@ -54,6 +54,10 @@
return new SdkMessage(from, to, payload);
}
+ public Message from(Address from, Address to, Object payload, String messageId) {
+ return new SdkMessage(from, to, payload, messageId);
+ }
+
// -------------------------------------------------------------------------------------------------------
void copy(DataInputView source, DataOutputView target) throws IOException {
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
index dabda14..500958f 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.OptionalLong;
import javax.annotation.Nullable;
import org.apache.flink.core.memory.DataOutputView;
@@ -82,6 +83,15 @@
}
@Override
+ public Optional<String> cancellationToken() {
+ String token = envelope.getCancellationToken();
+ if (token.isEmpty()) {
+ return Optional.empty();
+ }
+ return Optional.of(token);
+ }
+
+ @Override
public Message copy(MessageFactory unused) {
return new ProtobufMessage(envelope);
}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
index c10f2e9..ca5ee50 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
@@ -19,6 +19,7 @@
import java.io.IOException;
import java.util.Objects;
+import java.util.Optional;
import java.util.OptionalLong;
import javax.annotation.Nullable;
import org.apache.flink.core.memory.DataOutputView;
@@ -29,18 +30,27 @@
final class SdkMessage implements Message {
- @Nullable private final Address source;
-
private final Address target;
+ @Nullable private final Address source;
+ @Nullable private final String cancellationToken;
+ @Nullable private Envelope cachedEnvelope;
+
private Object payload;
- @Nullable private Envelope cachedEnvelope;
-
SdkMessage(@Nullable Address source, Address target, Object payload) {
+ this(source, target, payload, null);
+ }
+
+ SdkMessage(
+ @Nullable Address source,
+ Address target,
+ Object payload,
+ @Nullable String cancellationToken) {
this.source = source;
this.target = Objects.requireNonNull(target);
this.payload = Objects.requireNonNull(payload);
+ this.cancellationToken = cancellationToken;
}
@Override
@@ -68,8 +78,13 @@
}
@Override
+ public Optional<String> cancellationToken() {
+ return Optional.ofNullable(cancellationToken);
+ }
+
+ @Override
public Message copy(MessageFactory factory) {
- return new SdkMessage(source, target, payload);
+ return new SdkMessage(source, target, payload, cancellationToken);
}
@Override
@@ -86,6 +101,9 @@
}
builder.setTarget(sdkAddressToProtobufAddress(target));
builder.setPayload(factory.serializeUserMessagePayload(payload));
+ if (cancellationToken != null) {
+ builder.setCancellationToken(cancellationToken);
+ }
cachedEnvelope = builder.build();
}
return cachedEnvelope;
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
index d0ed196..b9fdc1a 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -233,14 +233,34 @@
private void handleOutgoingDelayedMessages(Context context, InvocationResponse invocationResult) {
for (FromFunction.DelayedInvocation delayedInvokeCommand :
invocationResult.getDelayedInvocationsList()) {
- final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
- final TypedValue message = delayedInvokeCommand.getArgument();
- final long delay = delayedInvokeCommand.getDelayInMs();
- context.sendAfter(Duration.ofMillis(delay), to, message);
+ if (delayedInvokeCommand.getIsCancellationRequest()) {
+ handleDelayedMessageCancellation(context, delayedInvokeCommand);
+ } else {
+ handleDelayedMessageSending(context, delayedInvokeCommand);
+ }
}
}
+ private void handleDelayedMessageSending(
+ Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
+ final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
+ final TypedValue message = delayedInvokeCommand.getArgument();
+ final long delay = delayedInvokeCommand.getDelayInMs();
+
+ context.sendAfter(Duration.ofMillis(delay), to, message);
+ }
+
+ private void handleDelayedMessageCancellation(
+ Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
+ String token = delayedInvokeCommand.getCancellationToken();
+ if (token.isEmpty()) {
+ throw new IllegalArgumentException(
+ "Can not handle a cancellation request without a cancellation token.");
+ }
+ context.cancelDelayedMessage(token);
+ }
+
// --------------------------------------------------------------------------------
// Send Message to Remote Function
// --------------------------------------------------------------------------------
diff --git a/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto b/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
index 1b09239..1c17e0b 100644
--- a/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
+++ b/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
@@ -37,10 +37,14 @@
int64 checkpoint_id = 1;
}
+
message Envelope {
EnvelopeAddress source = 1;
EnvelopeAddress target = 2;
+ // an optional token that can be used track delayed message cancellation.
+ string cancellation_token = 10;
+
oneof body {
Checkpoint checkpoint = 4;
Payload payload = 3;
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
index 9c7ccd0..7cbd4b6 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
@@ -133,6 +133,12 @@
public void sendAfter(Duration duration, Address to, Object message) {}
@Override
+ public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {}
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {}
+
+ @Override
public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
@Override
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
index 3f84bce..cf3b19a 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
@@ -100,12 +100,13 @@
new FakeKeyedStateBackend(),
new FakeTimerServiceFactory(),
new FakeInternalListState(),
+ new FakeMapState<>(),
new HashMap<>(),
new FakeOutput(),
TestUtils.ENVELOPE_FACTORY,
MoreExecutors.directExecutor(),
new FakeMetricGroup(),
- new FakeMapState());
+ new FakeMapState<>());
assertThat(reductions, notNullValue());
}
@@ -517,44 +518,44 @@
}
}
- private static final class FakeMapState implements MapState<Long, Message> {
+ private static final class FakeMapState<K, V> implements MapState<K, V> {
@Override
- public Message get(Long key) throws Exception {
+ public V get(K key) throws Exception {
return null;
}
@Override
- public void put(Long key, Message value) throws Exception {}
+ public void put(K key, V value) throws Exception {}
@Override
- public void putAll(Map<Long, Message> map) throws Exception {}
+ public void putAll(Map<K, V> map) throws Exception {}
@Override
- public void remove(Long key) throws Exception {}
+ public void remove(K key) throws Exception {}
@Override
- public boolean contains(Long key) throws Exception {
+ public boolean contains(K key) throws Exception {
return false;
}
@Override
- public Iterable<Entry<Long, Message>> entries() throws Exception {
+ public Iterable<Entry<K, V>> entries() throws Exception {
return null;
}
@Override
- public Iterable<Long> keys() throws Exception {
+ public Iterable<K> keys() throws Exception {
return null;
}
@Override
- public Iterable<Message> values() throws Exception {
+ public Iterable<V> values() throws Exception {
return null;
}
@Override
- public Iterator<Entry<Long, Message>> iterator() throws Exception {
+ public Iterator<Entry<K, V>> iterator() throws Exception {
return null;
}
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
index f88916f..5b7c536 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -37,6 +37,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics;
import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
@@ -197,7 +198,7 @@
functionUnderTest.invoke(context, successfulAsyncOperation(response));
assertFalse(context.delayed.isEmpty());
- assertEquals(Duration.ofMillis(1), context.delayed.get(0).getKey());
+ assertEquals(Duration.ofMillis(1), context.delayed.get(0).delay());
}
@Test
@@ -376,6 +377,38 @@
}
}
+ private static final class DelayedMessage {
+ final Duration delay;
+ final @Nullable String messageId;
+ final Address target;
+ final Object message;
+
+ public DelayedMessage(
+ Duration delay, @Nullable String messageId, Address target, Object message) {
+ this.delay = delay;
+ this.messageId = messageId;
+ this.target = target;
+ this.message = message;
+ }
+
+ public Duration delay() {
+ return delay;
+ }
+
+ @Nullable
+ public String messageId() {
+ return messageId;
+ }
+
+ public Address target() {
+ return target;
+ }
+
+ public Object message() {
+ return message;
+ }
+ }
+
private static final class FakeContext implements InternalContext {
private final BacklogTrackingMetrics fakeMetrics = new BacklogTrackingMetrics();
@@ -385,7 +418,7 @@
// capture emitted messages
List<Map.Entry<EgressIdentifier<?>, ?>> egresses = new ArrayList<>();
- List<Map.Entry<Duration, ?>> delayed = new ArrayList<>();
+ List<DelayedMessage> delayed = new ArrayList<>();
@Override
public void awaitAsyncOperationComplete() {
@@ -417,10 +450,18 @@
@Override
public void sendAfter(Duration delay, Address to, Object message) {
- delayed.add(new SimpleImmutableEntry<>(delay, message));
+ delayed.add(new DelayedMessage(delay, null, to, message));
}
@Override
+ public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
+ delayed.add(new DelayedMessage(delay, cancellationToken, to, message));
+ }
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {}
+
+ @Override
public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
}
diff --git a/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java b/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
index d3a1231..ea9a32d 100644
--- a/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
+++ b/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
@@ -75,6 +75,33 @@
void sendAfter(Duration delay, Address to, Object message);
/**
+ * Invokes another function with an input (associated with a {@code cancellationToken}),
+ * identified by the target function's {@link Address}, after a given delay.
+ *
+ * <p>Providing an id to a message, allows "unsending" this message later. ({@link
+ * #cancelDelayedMessage(String)}).
+ *
+ * @param delay the amount of delay before invoking the target function. Value needs to be >=
+ * 0.
+ * @param to the target function's address.
+ * @param message the input to provide for the delayed invocation.
+ * @param cancellationToken the non-empty, non-null, unique token to attach to this message, to be
+ * used for message cancellation. (see {@link #cancelDelayedMessage(String)}.)
+ */
+ void sendAfter(Duration delay, Address to, Object message, String cancellationToken);
+
+ /**
+ * Cancel a delayed message (a message that was send via {@link #sendAfter(Duration, Address,
+ * Object, String)}).
+ *
+ * <p>NOTE: this is a best-effort operation, since the message might have been already delivered.
+ * If the message was delivered, this is a no-op operation.
+ *
+ * @param cancellationToken the id of the message to un-send.
+ */
+ void cancelDelayedMessage(String cancellationToken);
+
+ /**
* Invokes another function with an input, identified by the target function's {@link
* FunctionType} and unique id.
*
diff --git a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
index 43ab6b8..8ef48e0 100644
--- a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
+++ b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
@@ -56,6 +56,26 @@
void sendAfter(Duration duration, Message message);
/**
+ * Sends out a {@link Message} to another function, after a specified {@link Duration} delay.
+ *
+ * @param duration the amount of time to delay the message delivery. * @param cancellationToken
+ * @param cancellationToken the non-empty, non-null, unique token to attach to this message, to be
+ * used for message cancellation. (see {@link #cancelDelayedMessage(String)}.)
+ * @param message the message to send.
+ */
+ void sendAfter(Duration duration, String cancellationToken, Message message);
+
+ /**
+ * Cancel a delayed message (a message that was send via {@link #sendAfter(Duration, Message)}).
+ *
+ * <p>NOTE: this is a best-effort operation, since the message might have been already delivered.
+ * If the message was delivered, this is a no-op operation.
+ *
+ * @param cancellationToken the id of the message to un-send.
+ */
+ void cancelDelayedMessage(String cancellationToken);
+
+ /**
* Sends out a {@link EgressMessage} to an egress.
*
* @param message the message to send.
diff --git a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
index 49e9fcc..07e44ca 100644
--- a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
+++ b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
@@ -116,6 +116,46 @@
}
@Override
+ public void sendAfter(Duration duration, String cancellationToken, Message message) {
+ Objects.requireNonNull(duration);
+ if (cancellationToken == null || cancellationToken.isEmpty()) {
+ throw new IllegalArgumentException("message cancellation token can not be empty or null.");
+ }
+ Objects.requireNonNull(message);
+
+ FromFunction.DelayedInvocation outInvocation =
+ FromFunction.DelayedInvocation.newBuilder()
+ .setArgument(getTypedValue(message))
+ .setTarget(protoAddressFromSdk(message.targetAddress()))
+ .setDelayInMs(duration.toMillis())
+ .setCancellationToken(cancellationToken)
+ .build();
+
+ synchronized (responseBuilder) {
+ checkNotDone();
+ responseBuilder.addDelayedInvocations(outInvocation);
+ }
+ }
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {
+ if (cancellationToken == null || cancellationToken.isEmpty()) {
+ throw new IllegalArgumentException("message cancellation token can not be empty or null.");
+ }
+
+ FromFunction.DelayedInvocation cancellation =
+ FromFunction.DelayedInvocation.newBuilder()
+ .setIsCancellationRequest(true)
+ .setCancellationToken(cancellationToken)
+ .build();
+
+ synchronized (responseBuilder) {
+ checkNotDone();
+ responseBuilder.addDelayedInvocations(cancellation);
+ }
+ }
+
+ @Override
public void send(EgressMessage message) {
Objects.requireNonNull(message);
diff --git a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
index 19d9f2a..ac72d7c 100644
--- a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
+++ b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
@@ -115,6 +115,15 @@
// DelayedInvocation represents a delayed remote function call with a target address, an argument
// and a delay in milliseconds, after which this message to be sent.
message DelayedInvocation {
+ // a boolean value (default false) that indicates rather this is a regular delayed message, or (true) a message
+ // cancellation request.
+ // in case of a regular delayed message all other fields are expected to be preset, otherwise only the
+ // cancellation_token is expected
+ bool is_cancellation_request = 10;
+
+ // an optional cancellation token that can be used to request the "unsending" of a delayed message.
+ string cancellation_token = 11;
+
// the amount of milliseconds to wait before sending this message
int64 delay_in_ms = 1;
// the target address to send this message to
diff --git a/statefun-sdk-python/statefun/context.py b/statefun-sdk-python/statefun/context.py
index b1692a5..578e020 100644
--- a/statefun-sdk-python/statefun/context.py
+++ b/statefun-sdk-python/statefun/context.py
@@ -24,7 +24,6 @@
class Context(abc.ABC):
-
__slots__ = ()
@property
@@ -62,13 +61,24 @@
"""
pass
- def send_after(self, duration: timedelta, message: Message):
+ def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
"""
Send a message to a target function after a specified delay.
:param duration: the amount of time to wait before sending this message out.
:param message: the message to send.
+ :param cancellation_token: an optional cancellation token to associate with this message.
"""
+ pass
+
+ def cancel_delayed_message(self, cancellation_token: str):
+ """
+ Cancel a delayed message (message that was sent using send_after) with a given token.
+
+ Please note that this is a best-effort operation, since the message might have been already delivered.
+ If the message was delivered, this is a no-op operation.
+ """
+ pass
def send_egress(self, message: EgressMessage):
"""
diff --git a/statefun-sdk-python/statefun/request_reply_v3.py b/statefun-sdk-python/statefun/request_reply_v3.py
index fa05253..f4db551 100644
--- a/statefun-sdk-python/statefun/request_reply_v3.py
+++ b/statefun-sdk-python/statefun/request_reply_v3.py
@@ -28,6 +28,16 @@
from statefun.request_reply_pb2 import ToFunction, FromFunction, Address, TypedValue
from statefun.storage import resolve, Cell
+from dataclasses import dataclass
+
+
+@dataclass
+class DelayedMessage:
+ is_cancellation: bool = None
+ duration: int = None
+ message: Message = None,
+ cancellation_token: str = None
+
class UserFacingContext(statefun.context.Context):
__slots__ = (
@@ -37,7 +47,7 @@
def __init__(self, address, storage):
self._self_address = address
self._outgoing_messages = []
- self._outgoing_delayed_messages = []
+ self._outgoing_delayed_messages: typing.List[DelayedMessage] = []
self._outgoing_egress_messages = []
self._storage = storage
self._caller = None
@@ -66,15 +76,28 @@
"""
self._outgoing_messages.append(message)
- def send_after(self, duration: timedelta, message: Message):
+ def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
"""
Send a message to a target function after a specified delay.
:param duration: the amount of time to wait before sending this message out.
:param message: the message to send.
+ :param cancellation_token: an optional cancellation token to associate with this message.
"""
ms = int(duration.total_seconds() * 1000.0)
- self._outgoing_delayed_messages.append((ms, message))
+ record = DelayedMessage(is_cancellation=False, duration=ms, message=message,
+ cancellation_token=cancellation_token)
+ self._outgoing_delayed_messages.append(record)
+
+ def cancel_delayed_message(self, cancellation_token: str):
+ """
+ Cancel a delayed message (message that was sent using send_after) with a given token.
+
+ Please note that this is a best-effort operation, since the message might have been already delivered.
+ If the message was delivered, this is a no-op operation.
+ """
+ record = DelayedMessage(is_cancellation=True, cancellation_token=cancellation_token)
+ self._outgoing_delayed_messages.append(record)
def send_egress(self, message: EgressMessage):
"""
@@ -145,17 +168,34 @@
outgoing.argument.CopyFrom(message.typed_value)
-def collect_delayed(delayed_messages: typing.List[typing.Tuple[timedelta, Message]], invocation_result):
+def collect_delayed(delayed_messages: typing.List[DelayedMessage], invocation_result):
delayed_invocations = invocation_result.delayed_invocations
- for delay, message in delayed_messages:
+ for delayed_message in delayed_messages:
outgoing = delayed_invocations.add()
- namespace, type = parse_typename(message.target_typename)
- outgoing.target.namespace = namespace
- outgoing.target.type = type
- outgoing.target.id = message.target_id
- outgoing.delay_in_ms = delay
- outgoing.argument.CopyFrom(message.typed_value)
+ if delayed_message.is_cancellation:
+ # handle cancellation
+ outgoing.cancellation_token = delayed_message.cancellation_token
+ outgoing.is_cancellation_request = True
+ else:
+ message = delayed_message.message
+ namespace, type = parse_typename(message.target_typename)
+
+ outgoing.target.namespace = namespace
+ outgoing.target.type = type
+ outgoing.target.id = message.target_id
+ outgoing.delay_in_ms = delayed_message.duration
+ outgoing.argument.CopyFrom(message.typed_value)
+ if delayed_message.cancellation_token is not None:
+ outgoing.cancellation_token = delayed_message.cancellation_token
+
+
+def collect_cancellations(tokens: typing.List[str], invocation_result):
+ outgoing_cancellations = invocation_result.outgoing_delay_cancellations
+ for token in tokens:
+ if token:
+ delay_cancelltion = outgoing_cancellations.add()
+ delay_cancelltion.cancellation_token = token
def collect_egress(egresses: typing.List[EgressMessage], invocation_result):
diff --git a/statefun-sdk-python/tests/request_reply_test.py b/statefun-sdk-python/tests/request_reply_test.py
index 612750f..5bd783a 100644
--- a/statefun-sdk-python/tests/request_reply_test.py
+++ b/statefun-sdk-python/tests/request_reply_test.py
@@ -91,6 +91,7 @@
NTH_OUTGOING_MESSAGE = lambda n: [key("invocation_result"), key("outgoing_messages"), nth(n)]
NTH_STATE_MUTATION = lambda n: [key("invocation_result"), key("state_mutations"), nth(n)]
NTH_DELAYED_MESSAGE = lambda n: [key("invocation_result"), key("delayed_invocations"), nth(n)]
+NTH_CANCELLATION_MESSAGE = lambda n: [key("invocation_result"), key("outgoing_delay_cancellations"), nth(n)]
NTH_EGRESS = lambda n: [key("invocation_result"), key("outgoing_egresses"), nth(n)]
NTH_MISSING_STATE_SPEC = lambda n: [key("incomplete_invocation_context"), key("missing_values"), nth(n)]
@@ -123,6 +124,16 @@
message_builder(target_typename="night/owl",
target_id="1",
str_value="hoo hoo"))
+
+ # delayed with cancellation
+ context.send_after(timedelta(hours=1),
+ message_builder(target_typename="night/owl",
+ target_id="1",
+ str_value="hoo hoo"),
+ cancellation_token="token-1234")
+
+ context.cancel_delayed_message("token-1234")
+
# kafka egresses
context.send_egress(
kafka_egress_message(typename="e/kafka",
@@ -165,6 +176,15 @@
first_delayed = json_at(result_json, NTH_DELAYED_MESSAGE(0))
self.assertEqual(int(first_delayed['delay_in_ms']), 1000 * 60 * 60)
+ # assert delayed with token
+ second_delayed = json_at(result_json, NTH_DELAYED_MESSAGE(1))
+ self.assertEqual(second_delayed['cancellation_token'], "token-1234")
+
+ # assert cancellation
+ first_cancellation = json_at(result_json, NTH_DELAYED_MESSAGE(2))
+ self.assertTrue(first_cancellation['is_cancellation_request'])
+ self.assertEqual(first_cancellation['cancellation_token'], "token-1234")
+
# assert egresses
first_egress = json_at(result_json, NTH_EGRESS(0))
self.assertEqual(first_egress['egress_namespace'], 'e')
diff --git a/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java b/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
index 1c39775..fdf6be8 100644
--- a/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
+++ b/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
@@ -102,7 +102,21 @@
@Override
public void sendAfter(Duration delay, Address to, Object message) {
pendingMessage.add(
- new PendingMessage(new Envelope(self(), to, message), watermark + delay.toMillis()));
+ new PendingMessage(new Envelope(self(), to, message), watermark + delay.toMillis(), null));
+ }
+
+ @Override
+ public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
+ Objects.requireNonNull(cancellationToken);
+ pendingMessage.add(
+ new PendingMessage(
+ new Envelope(self(), to, message), watermark + delay.toMillis(), cancellationToken));
+ }
+
+ @Override
+ public void cancelDelayedMessage(String cancellationToken) {
+ pendingMessage.removeIf(
+ pendingMessage -> Objects.equals(pendingMessage.cancellationToken, cancellationToken));
}
@Override
@@ -186,12 +200,13 @@
private static class PendingMessage {
Envelope envelope;
-
+ String cancellationToken;
long timer;
- PendingMessage(Envelope envelope, long timer) {
+ PendingMessage(Envelope envelope, long timer, String cancellationToken) {
this.envelope = envelope;
this.timer = timer;
+ this.cancellationToken = cancellationToken;
}
}
}