[FLINK-21308][core] Support delayed message cancellation

This closes #241.
diff --git a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
index 226f418..db35de6 100644
--- a/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
+++ b/statefun-e2e-tests/statefun-smoke-e2e/src/test/java/org/apache/flink/statefun/e2e/smoke/CommandInterpreterTest.java
@@ -68,6 +68,12 @@
     public void sendAfter(Duration duration, Address address, Object o) {}
 
     @Override
+    public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {}
+
+    @Override
+    public void cancelDelayedMessage(String cancellationToken) {}
+
+    @Override
     public <M, T> void registerAsyncOperation(M m, CompletableFuture<T> completableFuture) {}
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
index c77adb7..eed001f 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/AsyncMessageDecorator.java
@@ -17,6 +17,7 @@
  */
 package org.apache.flink.statefun.flink.core.functions;
 
+import java.util.Optional;
 import java.util.OptionalLong;
 import javax.annotation.Nullable;
 import org.apache.flink.core.memory.DataOutputView;
@@ -93,6 +94,11 @@
   }
 
   @Override
+  public Optional<String> cancellationToken() {
+    return message.cancellationToken();
+  }
+
+  @Override
   public void postApply() {
     pendingAsyncOperations.remove(source(), futureId);
   }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
new file mode 100644
index 0000000..1dfb66b
--- /dev/null
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayMessageHandler.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.statefun.flink.core.functions;
+
+import java.util.Objects;
+import java.util.function.Consumer;
+import org.apache.flink.statefun.flink.core.di.Inject;
+import org.apache.flink.statefun.flink.core.di.Label;
+import org.apache.flink.statefun.flink.core.di.Lazy;
+import org.apache.flink.statefun.flink.core.message.Message;
+
+/**
+ * Handles any of the delayed message that needs to be fired at a specific timestamp. This handler
+ * dispatches {@linkplain Message}s to either remotely (shuffle) or locally.
+ */
+final class DelayMessageHandler implements Consumer<Message> {
+  private final RemoteSink remoteSink;
+  private final Lazy<Reductions> reductions;
+  private final Partition thisPartition;
+
+  @Inject
+  public DelayMessageHandler(
+      RemoteSink remoteSink,
+      @Label("reductions") Lazy<Reductions> reductions,
+      Partition partition) {
+    this.remoteSink = Objects.requireNonNull(remoteSink);
+    this.reductions = Objects.requireNonNull(reductions);
+    this.thisPartition = Objects.requireNonNull(partition);
+  }
+
+  @Override
+  public void accept(Message message) {
+    if (thisPartition.contains(message.target())) {
+      reductions.get().enqueue(message);
+    } else {
+      remoteSink.accept(message);
+    }
+  }
+
+  public void onStart() {}
+
+  public void onComplete() {
+    reductions.get().processEnvelopes();
+  }
+}
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
index 05cc212..ddf81c9 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelaySink.java
@@ -18,10 +18,10 @@
 package org.apache.flink.statefun.flink.core.functions;
 
 import java.util.Objects;
+import java.util.OptionalLong;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.statefun.flink.core.di.Inject;
 import org.apache.flink.statefun.flink.core.di.Label;
-import org.apache.flink.statefun.flink.core.di.Lazy;
 import org.apache.flink.statefun.flink.core.message.Message;
 import org.apache.flink.streaming.api.operators.InternalTimer;
 import org.apache.flink.streaming.api.operators.InternalTimerService;
@@ -32,25 +32,17 @@
 
   private final InternalTimerService<VoidNamespace> delayedMessagesTimerService;
   private final DelayedMessagesBuffer delayedMessagesBuffer;
-
-  private final Lazy<Reductions> reductionsSupplier;
-  private final Partition thisPartition;
-  private final RemoteSink remoteSink;
+  private final DelayMessageHandler delayMessageHandler;
 
   @Inject
   DelaySink(
       @Label("delayed-messages-buffer") DelayedMessagesBuffer delayedMessagesBuffer,
       @Label("delayed-messages-timer-service-factory")
           TimerServiceFactory delayedMessagesTimerServiceFactory,
-      @Label("reductions") Lazy<Reductions> reductionsSupplier,
-      Partition thisPartition,
-      RemoteSink remoteSink) {
+      DelayMessageHandler delayMessageHandler) {
     this.delayedMessagesBuffer = Objects.requireNonNull(delayedMessagesBuffer);
-    this.reductionsSupplier = Objects.requireNonNull(reductionsSupplier);
-    this.thisPartition = Objects.requireNonNull(thisPartition);
-    this.remoteSink = Objects.requireNonNull(remoteSink);
-
     this.delayedMessagesTimerService = delayedMessagesTimerServiceFactory.createTimerService(this);
+    this.delayMessageHandler = Objects.requireNonNull(delayMessageHandler);
   }
 
   void accept(Message message, long delayMillis) {
@@ -64,35 +56,25 @@
   }
 
   @Override
-  public void onProcessingTime(InternalTimer<String, VoidNamespace> timer) throws Exception {
-    final long triggerTimestamp = timer.getTimestamp();
-    final Reductions reductions = reductionsSupplier.get();
-
-    Iterable<Message> delayedMessages = delayedMessagesBuffer.getForTimestamp(triggerTimestamp);
-    if (delayedMessages == null) {
-      throw new IllegalStateException(
-          "A delayed message timer was triggered with timestamp "
-              + triggerTimestamp
-              + ", but no messages were buffered for it.");
-    }
-    for (Message delayedMessage : delayedMessages) {
-      if (thisPartition.contains(delayedMessage.target())) {
-        reductions.enqueue(delayedMessage);
-      } else {
-        remoteSink.accept(delayedMessage);
-      }
-    }
-    // we clear the delayedMessageBuffer *before* we process the enqueued local reductions, because
-    // processing the envelops might actually trigger a delayed message to be sent with the same
-    // @triggerTimestamp
-    // so it would be re-enqueued into the delayedMessageBuffer.
-    delayedMessagesBuffer.clearForTimestamp(triggerTimestamp);
-    reductions.processEnvelopes();
+  public void onProcessingTime(InternalTimer<String, VoidNamespace> timer) {
+    delayMessageHandler.onStart();
+    delayedMessagesBuffer.forEachMessageAt(timer.getTimestamp(), delayMessageHandler);
+    delayMessageHandler.onComplete();
   }
 
   @Override
-  public void onEventTime(InternalTimer<String, VoidNamespace> timer) throws Exception {
+  public void onEventTime(InternalTimer<String, VoidNamespace> timer) {
     throw new UnsupportedOperationException(
         "Delayed messages with event time semantics is not supported.");
   }
+
+  void removeMessageByCancellationToken(String cancellationToken) {
+    Objects.requireNonNull(cancellationToken);
+    OptionalLong timerToClear =
+        delayedMessagesBuffer.removeMessageByCancellationToken(cancellationToken);
+    if (timerToClear.isPresent()) {
+      long timestamp = timerToClear.getAsLong();
+      delayedMessagesTimerService.deleteProcessingTimeTimer(VoidNamespace.INSTANCE, timestamp);
+    }
+  }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
index 1b68e3f..cf35389 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/DelayedMessagesBuffer.java
@@ -17,13 +17,23 @@
  */
 package org.apache.flink.statefun.flink.core.functions;
 
+import java.util.OptionalLong;
+import java.util.function.Consumer;
 import org.apache.flink.statefun.flink.core.message.Message;
 
 interface DelayedMessagesBuffer {
 
+  /** Add a message to be fired at a specific timestamp */
   void add(Message message, long untilTimestamp);
 
-  Iterable<Message> getForTimestamp(long timestamp);
+  /** Apply @fn for each delayed message that is meant to be fired at @timestamp. */
+  void forEachMessageAt(long timestamp, Consumer<Message> fn);
 
-  void clearForTimestamp(long timestamp);
+  /**
+   * @param token a message cancellation token to delete.
+   * @return an optional timestamp that this message was meant to be fired at. The timestamp will be
+   *     present only if this message was the last message registered to fire at that timestamp.
+   *     (hence: safe to clear any underlying timer)
+   */
+  OptionalLong removeMessageByCancellationToken(String token);
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
index a451fd0..ac7b9a5 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FlinkStateDelayedMessagesBuffer.java
@@ -17,7 +17,14 @@
  */
 package org.apache.flink.statefun.flink.core.functions;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
+import java.util.OptionalLong;
+import java.util.function.Consumer;
+import javax.annotation.Nullable;
+import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.statefun.flink.core.di.Inject;
 import org.apache.flink.statefun.flink.core.di.Label;
@@ -26,41 +33,122 @@
 final class FlinkStateDelayedMessagesBuffer implements DelayedMessagesBuffer {
 
   static final String BUFFER_STATE_NAME = "delayed-messages-buffer";
+  static final String INDEX_STATE_NAME = "delayed-message-index";
 
   private final InternalListState<String, Long, Message> bufferState;
+  private final MapState<String, Long> cancellationTokenToTimestamp;
 
   @Inject
   FlinkStateDelayedMessagesBuffer(
-      @Label("delayed-messages-buffer-state")
-          InternalListState<String, Long, Message> bufferState) {
+      @Label("delayed-messages-buffer-state") InternalListState<String, Long, Message> bufferState,
+      @Label("delayed-message-index") MapState<String, Long> cancellationTokenToTimestamp) {
     this.bufferState = Objects.requireNonNull(bufferState);
+    this.cancellationTokenToTimestamp = Objects.requireNonNull(cancellationTokenToTimestamp);
+  }
+
+  @Override
+  public void forEachMessageAt(long timestamp, Consumer<Message> fn) {
+    try {
+      forEachMessageThrows(timestamp, fn);
+    } catch (Exception e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  @Override
+  public OptionalLong removeMessageByCancellationToken(String token) {
+    try {
+      return remove(token);
+    } catch (Exception e) {
+      throw new IllegalStateException(
+          "Failed clearing a message with a cancellation token " + token, e);
+    }
   }
 
   @Override
   public void add(Message message, long untilTimestamp) {
-    bufferState.setCurrentNamespace(untilTimestamp);
     try {
-      bufferState.add(message);
+      addThrows(message, untilTimestamp);
     } catch (Exception e) {
       throw new RuntimeException("Error adding delayed message to state buffer: " + message, e);
     }
   }
 
-  @Override
-  public Iterable<Message> getForTimestamp(long timestamp) {
-    bufferState.setCurrentNamespace(timestamp);
+  // -----------------------------------------------------------------------------------------------------
+  // Internal
+  // -----------------------------------------------------------------------------------------------------
 
-    try {
-      return bufferState.get();
-    } catch (Exception e) {
-      throw new RuntimeException(
-          "Error accessing delayed message in state buffer for timestamp: " + timestamp, e);
+  private void forEachMessageThrows(long timestamp, Consumer<Message> fn) throws Exception {
+    bufferState.setCurrentNamespace(timestamp);
+    for (Message message : bufferState.get()) {
+      removeMessageIdMapping(message);
+      fn.accept(message);
+    }
+    bufferState.clear();
+  }
+
+  private void addThrows(Message message, long untilTimestamp) throws Exception {
+    bufferState.setCurrentNamespace(untilTimestamp);
+    bufferState.add(message);
+    Optional<String> maybeToken = message.cancellationToken();
+    if (!maybeToken.isPresent()) {
+      return;
+    }
+    String cancellationToken = maybeToken.get();
+    @Nullable Long previousTimestamp = cancellationTokenToTimestamp.get(cancellationToken);
+    if (previousTimestamp != null) {
+      throw new IllegalStateException(
+          "Trying to associate a message with cancellation token "
+              + cancellationToken
+              + " and timestamp "
+              + untilTimestamp
+              + ", but a message with the same cancellation token exists and with a timestamp "
+              + previousTimestamp);
+    }
+    cancellationTokenToTimestamp.put(cancellationToken, untilTimestamp);
+  }
+
+  private OptionalLong remove(String cancellationToken) throws Exception {
+    final @Nullable Long untilTimestamp = cancellationTokenToTimestamp.get(cancellationToken);
+    if (untilTimestamp == null) {
+      // The message associated with @cancellationToken has already been delivered, or previously
+      // removed.
+      return OptionalLong.empty();
+    }
+    cancellationTokenToTimestamp.remove(cancellationToken);
+    bufferState.setCurrentNamespace(untilTimestamp);
+    List<Message> newList = removeMessageByToken(bufferState.get(), cancellationToken);
+    if (!newList.isEmpty()) {
+      // There are more messages to process, so we indicate to the caller that
+      // they should NOT cancel the timer.
+      bufferState.update(newList);
+      return OptionalLong.empty();
+    }
+    // There are no more message to remove, we clear the buffer and indicate
+    // to our caller to remove the timer for @untilTimestamp
+    bufferState.clear();
+    return OptionalLong.of(untilTimestamp);
+  }
+
+  // ---------------------------------------------------------------------------------------------------------
+  // Helpers
+  // ---------------------------------------------------------------------------------------------------------
+
+  private void removeMessageIdMapping(Message message) throws Exception {
+    Optional<String> maybeToken = message.cancellationToken();
+    if (maybeToken.isPresent()) {
+      cancellationTokenToTimestamp.remove(maybeToken.get());
     }
   }
 
-  @Override
-  public void clearForTimestamp(long timestamp) {
-    bufferState.setCurrentNamespace(timestamp);
-    bufferState.clear();
+  private static List<Message> removeMessageByToken(Iterable<Message> messages, String token) {
+    ArrayList<Message> newList = new ArrayList<>();
+    for (Message message : messages) {
+      Optional<String> thisMessageId = message.cancellationToken();
+      if (!thisMessageId.isPresent() || !Objects.equals(thisMessageId.get(), token)) {
+        newList.add(message);
+      }
+    }
+    return newList;
   }
 }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
index 741575b..8dcd01b 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/FunctionGroupOperator.java
@@ -104,6 +104,11 @@
     final ListStateDescriptor<Message> delayedMessageStateDescriptor =
         new ListStateDescriptor<>(
             FlinkStateDelayedMessagesBuffer.BUFFER_STATE_NAME, envelopeSerializer.duplicate());
+    final MapStateDescriptor<String, Long> delayedMessageIndexDescriptor =
+        new MapStateDescriptor<>(
+            FlinkStateDelayedMessagesBuffer.INDEX_STATE_NAME, String.class, Long.class);
+    final MapState<String, Long> delayedMessageIndex =
+        getRuntimeContext().getMapState(delayedMessageIndexDescriptor);
     final MapState<Long, Message> asyncOperationState =
         getRuntimeContext().getMapState(asyncOperationStateDescriptor);
 
@@ -130,6 +135,7 @@
             new FlinkTimerServiceFactory(
                 super.getTimeServiceManager().orElseThrow(IllegalStateException::new)),
             delayedMessagesBufferState(delayedMessageStateDescriptor),
+            delayedMessageIndex,
             sideOutputs,
             output,
             MessageFactory.forKey(statefulFunctionsUniverse.messageFactoryKey()),
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
index 55b521f..b881a85 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/Reductions.java
@@ -62,6 +62,7 @@
       KeyedStateBackend<Object> keyedStateBackend,
       TimerServiceFactory timerServiceFactory,
       InternalListState<String, Long, Message> delayedMessagesBufferState,
+      MapState<String, Long> delayMessageIndex,
       Map<EgressIdentifier<?>, OutputTag<Object>> sideOutputs,
       Output<StreamRecord<Message>> output,
       MessageFactory messageFactory,
@@ -117,6 +118,7 @@
     // for delayed messages
     container.add(
         "delayed-messages-buffer-state", InternalListState.class, delayedMessagesBufferState);
+    container.add("delayed-message-index", MapState.class, delayMessageIndex);
     container.add(
         "delayed-messages-buffer",
         DelayedMessagesBuffer.class,
@@ -124,6 +126,7 @@
     container.add(
         "delayed-messages-timer-service-factory", TimerServiceFactory.class, timerServiceFactory);
     container.add(DelaySink.class);
+    container.add(DelayMessageHandler.class);
 
     // lazy providers for the sinks
     container.add("function-group", new Lazy<>(LocalFunctionGroup.class));
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
index 77db7dc..e1a0c87 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/functions/ReusableContext.java
@@ -109,6 +109,23 @@
   }
 
   @Override
+  public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
+    Objects.requireNonNull(delay);
+    Objects.requireNonNull(to);
+    Objects.requireNonNull(message);
+    Objects.requireNonNull(cancellationToken);
+
+    Message envelope = messageFactory.from(self(), to, message, cancellationToken);
+    delaySink.accept(envelope, delay.toMillis());
+  }
+
+  @Override
+  public void cancelDelayedMessage(String cancellationToken) {
+    Objects.requireNonNull(cancellationToken);
+    delaySink.removeMessageByCancellationToken(cancellationToken);
+  }
+
+  @Override
   public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {
     Objects.requireNonNull(metadata);
     Objects.requireNonNull(future);
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
index 2278fa5..be10e3f 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/Message.java
@@ -18,6 +18,7 @@
 package org.apache.flink.statefun.flink.core.message;
 
 import java.io.IOException;
+import java.util.Optional;
 import java.util.OptionalLong;
 import org.apache.flink.core.memory.DataOutputView;
 
@@ -35,6 +36,8 @@
    */
   OptionalLong isBarrierMessage();
 
+  Optional<String> cancellationToken();
+
   Message copy(MessageFactory context);
 
   void writeTo(MessageFactory context, DataOutputView target) throws IOException;
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
index e4d2d3a..780415e 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/MessageFactory.java
@@ -54,6 +54,10 @@
     return new SdkMessage(from, to, payload);
   }
 
+  public Message from(Address from, Address to, Object payload, String messageId) {
+    return new SdkMessage(from, to, payload, messageId);
+  }
+
   // -------------------------------------------------------------------------------------------------------
 
   void copy(DataInputView source, DataOutputView target) throws IOException {
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
index dabda14..500958f 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/ProtobufMessage.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.OptionalLong;
 import javax.annotation.Nullable;
 import org.apache.flink.core.memory.DataOutputView;
@@ -82,6 +83,15 @@
   }
 
   @Override
+  public Optional<String> cancellationToken() {
+    String token = envelope.getCancellationToken();
+    if (token.isEmpty()) {
+      return Optional.empty();
+    }
+    return Optional.of(token);
+  }
+
+  @Override
   public Message copy(MessageFactory unused) {
     return new ProtobufMessage(envelope);
   }
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
index c10f2e9..ca5ee50 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/message/SdkMessage.java
@@ -19,6 +19,7 @@
 
 import java.io.IOException;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.OptionalLong;
 import javax.annotation.Nullable;
 import org.apache.flink.core.memory.DataOutputView;
@@ -29,18 +30,27 @@
 
 final class SdkMessage implements Message {
 
-  @Nullable private final Address source;
-
   private final Address target;
 
+  @Nullable private final Address source;
+  @Nullable private final String cancellationToken;
+  @Nullable private Envelope cachedEnvelope;
+
   private Object payload;
 
-  @Nullable private Envelope cachedEnvelope;
-
   SdkMessage(@Nullable Address source, Address target, Object payload) {
+    this(source, target, payload, null);
+  }
+
+  SdkMessage(
+      @Nullable Address source,
+      Address target,
+      Object payload,
+      @Nullable String cancellationToken) {
     this.source = source;
     this.target = Objects.requireNonNull(target);
     this.payload = Objects.requireNonNull(payload);
+    this.cancellationToken = cancellationToken;
   }
 
   @Override
@@ -68,8 +78,13 @@
   }
 
   @Override
+  public Optional<String> cancellationToken() {
+    return Optional.ofNullable(cancellationToken);
+  }
+
+  @Override
   public Message copy(MessageFactory factory) {
-    return new SdkMessage(source, target, payload);
+    return new SdkMessage(source, target, payload, cancellationToken);
   }
 
   @Override
@@ -86,6 +101,9 @@
       }
       builder.setTarget(sdkAddressToProtobufAddress(target));
       builder.setPayload(factory.serializeUserMessagePayload(payload));
+      if (cancellationToken != null) {
+        builder.setCancellationToken(cancellationToken);
+      }
       cachedEnvelope = builder.build();
     }
     return cachedEnvelope;
diff --git a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
index d0ed196..b9fdc1a 100644
--- a/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
+++ b/statefun-flink/statefun-flink-core/src/main/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunction.java
@@ -233,14 +233,34 @@
   private void handleOutgoingDelayedMessages(Context context, InvocationResponse invocationResult) {
     for (FromFunction.DelayedInvocation delayedInvokeCommand :
         invocationResult.getDelayedInvocationsList()) {
-      final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
-      final TypedValue message = delayedInvokeCommand.getArgument();
-      final long delay = delayedInvokeCommand.getDelayInMs();
 
-      context.sendAfter(Duration.ofMillis(delay), to, message);
+      if (delayedInvokeCommand.getIsCancellationRequest()) {
+        handleDelayedMessageCancellation(context, delayedInvokeCommand);
+      } else {
+        handleDelayedMessageSending(context, delayedInvokeCommand);
+      }
     }
   }
 
+  private void handleDelayedMessageSending(
+      Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
+    final Address to = polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
+    final TypedValue message = delayedInvokeCommand.getArgument();
+    final long delay = delayedInvokeCommand.getDelayInMs();
+
+    context.sendAfter(Duration.ofMillis(delay), to, message);
+  }
+
+  private void handleDelayedMessageCancellation(
+      Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
+    String token = delayedInvokeCommand.getCancellationToken();
+    if (token.isEmpty()) {
+      throw new IllegalArgumentException(
+          "Can not handle a cancellation request without a cancellation token.");
+    }
+    context.cancelDelayedMessage(token);
+  }
+
   // --------------------------------------------------------------------------------
   // Send Message to Remote Function
   // --------------------------------------------------------------------------------
diff --git a/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto b/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
index 1b09239..1c17e0b 100644
--- a/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
+++ b/statefun-flink/statefun-flink-core/src/main/protobuf/stateful-functions.proto
@@ -37,10 +37,14 @@
     int64 checkpoint_id = 1;
 }
 
+
 message Envelope {
     EnvelopeAddress source = 1;
     EnvelopeAddress target = 2;
 
+    // an optional token that can be used track delayed message cancellation.
+    string cancellation_token = 10;
+
     oneof body {
         Checkpoint checkpoint = 4;
         Payload payload = 3;
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
index 9c7ccd0..7cbd4b6 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/LocalStatefulFunctionGroupTest.java
@@ -133,6 +133,12 @@
     public void sendAfter(Duration duration, Address to, Object message) {}
 
     @Override
+    public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {}
+
+    @Override
+    public void cancelDelayedMessage(String cancellationToken) {}
+
+    @Override
     public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
 
     @Override
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
index 3f84bce..cf3b19a 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/functions/ReductionsTest.java
@@ -100,12 +100,13 @@
             new FakeKeyedStateBackend(),
             new FakeTimerServiceFactory(),
             new FakeInternalListState(),
+            new FakeMapState<>(),
             new HashMap<>(),
             new FakeOutput(),
             TestUtils.ENVELOPE_FACTORY,
             MoreExecutors.directExecutor(),
             new FakeMetricGroup(),
-            new FakeMapState());
+            new FakeMapState<>());
 
     assertThat(reductions, notNullValue());
   }
@@ -517,44 +518,44 @@
     }
   }
 
-  private static final class FakeMapState implements MapState<Long, Message> {
+  private static final class FakeMapState<K, V> implements MapState<K, V> {
 
     @Override
-    public Message get(Long key) throws Exception {
+    public V get(K key) throws Exception {
       return null;
     }
 
     @Override
-    public void put(Long key, Message value) throws Exception {}
+    public void put(K key, V value) throws Exception {}
 
     @Override
-    public void putAll(Map<Long, Message> map) throws Exception {}
+    public void putAll(Map<K, V> map) throws Exception {}
 
     @Override
-    public void remove(Long key) throws Exception {}
+    public void remove(K key) throws Exception {}
 
     @Override
-    public boolean contains(Long key) throws Exception {
+    public boolean contains(K key) throws Exception {
       return false;
     }
 
     @Override
-    public Iterable<Entry<Long, Message>> entries() throws Exception {
+    public Iterable<Entry<K, V>> entries() throws Exception {
       return null;
     }
 
     @Override
-    public Iterable<Long> keys() throws Exception {
+    public Iterable<K> keys() throws Exception {
       return null;
     }
 
     @Override
-    public Iterable<Message> values() throws Exception {
+    public Iterable<V> values() throws Exception {
       return null;
     }
 
     @Override
-    public Iterator<Entry<Long, Message>> iterator() throws Exception {
+    public Iterator<Entry<K, V>> iterator() throws Exception {
       return null;
     }
 
diff --git a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
index f88916f..5b7c536 100644
--- a/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
+++ b/statefun-flink/statefun-flink-core/src/test/java/org/apache/flink/statefun/flink/core/reqreply/RequestReplyFunctionTest.java
@@ -37,6 +37,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
 import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
 import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics;
 import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
@@ -197,7 +198,7 @@
     functionUnderTest.invoke(context, successfulAsyncOperation(response));
 
     assertFalse(context.delayed.isEmpty());
-    assertEquals(Duration.ofMillis(1), context.delayed.get(0).getKey());
+    assertEquals(Duration.ofMillis(1), context.delayed.get(0).delay());
   }
 
   @Test
@@ -376,6 +377,38 @@
     }
   }
 
+  private static final class DelayedMessage {
+    final Duration delay;
+    final @Nullable String messageId;
+    final Address target;
+    final Object message;
+
+    public DelayedMessage(
+        Duration delay, @Nullable String messageId, Address target, Object message) {
+      this.delay = delay;
+      this.messageId = messageId;
+      this.target = target;
+      this.message = message;
+    }
+
+    public Duration delay() {
+      return delay;
+    }
+
+    @Nullable
+    public String messageId() {
+      return messageId;
+    }
+
+    public Address target() {
+      return target;
+    }
+
+    public Object message() {
+      return message;
+    }
+  }
+
   private static final class FakeContext implements InternalContext {
 
     private final BacklogTrackingMetrics fakeMetrics = new BacklogTrackingMetrics();
@@ -385,7 +418,7 @@
 
     // capture emitted messages
     List<Map.Entry<EgressIdentifier<?>, ?>> egresses = new ArrayList<>();
-    List<Map.Entry<Duration, ?>> delayed = new ArrayList<>();
+    List<DelayedMessage> delayed = new ArrayList<>();
 
     @Override
     public void awaitAsyncOperationComplete() {
@@ -417,10 +450,18 @@
 
     @Override
     public void sendAfter(Duration delay, Address to, Object message) {
-      delayed.add(new SimpleImmutableEntry<>(delay, message));
+      delayed.add(new DelayedMessage(delay, null, to, message));
     }
 
     @Override
+    public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
+      delayed.add(new DelayedMessage(delay, cancellationToken, to, message));
+    }
+
+    @Override
+    public void cancelDelayedMessage(String cancellationToken) {}
+
+    @Override
     public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
   }
 
diff --git a/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java b/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
index d3a1231..ea9a32d 100644
--- a/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
+++ b/statefun-sdk-embedded/src/main/java/org/apache/flink/statefun/sdk/Context.java
@@ -75,6 +75,33 @@
   void sendAfter(Duration delay, Address to, Object message);
 
   /**
+   * Invokes another function with an input (associated with a {@code cancellationToken}),
+   * identified by the target function's {@link Address}, after a given delay.
+   *
+   * <p>Providing an id to a message, allows "unsending" this message later. ({@link
+   * #cancelDelayedMessage(String)}).
+   *
+   * @param delay the amount of delay before invoking the target function. Value needs to be &gt;=
+   *     0.
+   * @param to the target function's address.
+   * @param message the input to provide for the delayed invocation.
+   * @param cancellationToken the non-empty, non-null, unique token to attach to this message, to be
+   *     used for message cancellation. (see {@link #cancelDelayedMessage(String)}.)
+   */
+  void sendAfter(Duration delay, Address to, Object message, String cancellationToken);
+
+  /**
+   * Cancel a delayed message (a message that was send via {@link #sendAfter(Duration, Address,
+   * Object, String)}).
+   *
+   * <p>NOTE: this is a best-effort operation, since the message might have been already delivered.
+   * If the message was delivered, this is a no-op operation.
+   *
+   * @param cancellationToken the id of the message to un-send.
+   */
+  void cancelDelayedMessage(String cancellationToken);
+
+  /**
    * Invokes another function with an input, identified by the target function's {@link
    * FunctionType} and unique id.
    *
diff --git a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
index 43ab6b8..8ef48e0 100644
--- a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
+++ b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/Context.java
@@ -56,6 +56,26 @@
   void sendAfter(Duration duration, Message message);
 
   /**
+   * Sends out a {@link Message} to another function, after a specified {@link Duration} delay.
+   *
+   * @param duration the amount of time to delay the message delivery. * @param cancellationToken
+   * @param cancellationToken the non-empty, non-null, unique token to attach to this message, to be
+   *     used for message cancellation. (see {@link #cancelDelayedMessage(String)}.)
+   * @param message the message to send.
+   */
+  void sendAfter(Duration duration, String cancellationToken, Message message);
+
+  /**
+   * Cancel a delayed message (a message that was send via {@link #sendAfter(Duration, Message)}).
+   *
+   * <p>NOTE: this is a best-effort operation, since the message might have been already delivered.
+   * If the message was delivered, this is a no-op operation.
+   *
+   * @param cancellationToken the id of the message to un-send.
+   */
+  void cancelDelayedMessage(String cancellationToken);
+
+  /**
    * Sends out a {@link EgressMessage} to an egress.
    *
    * @param message the message to send.
diff --git a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
index 49e9fcc..07e44ca 100644
--- a/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
+++ b/statefun-sdk-java/src/main/java/org/apache/flink/statefun/sdk/java/handler/ConcurrentContext.java
@@ -116,6 +116,46 @@
   }
 
   @Override
+  public void sendAfter(Duration duration, String cancellationToken, Message message) {
+    Objects.requireNonNull(duration);
+    if (cancellationToken == null || cancellationToken.isEmpty()) {
+      throw new IllegalArgumentException("message cancellation token can not be empty or null.");
+    }
+    Objects.requireNonNull(message);
+
+    FromFunction.DelayedInvocation outInvocation =
+        FromFunction.DelayedInvocation.newBuilder()
+            .setArgument(getTypedValue(message))
+            .setTarget(protoAddressFromSdk(message.targetAddress()))
+            .setDelayInMs(duration.toMillis())
+            .setCancellationToken(cancellationToken)
+            .build();
+
+    synchronized (responseBuilder) {
+      checkNotDone();
+      responseBuilder.addDelayedInvocations(outInvocation);
+    }
+  }
+
+  @Override
+  public void cancelDelayedMessage(String cancellationToken) {
+    if (cancellationToken == null || cancellationToken.isEmpty()) {
+      throw new IllegalArgumentException("message cancellation token can not be empty or null.");
+    }
+
+    FromFunction.DelayedInvocation cancellation =
+        FromFunction.DelayedInvocation.newBuilder()
+            .setIsCancellationRequest(true)
+            .setCancellationToken(cancellationToken)
+            .build();
+
+    synchronized (responseBuilder) {
+      checkNotDone();
+      responseBuilder.addDelayedInvocations(cancellation);
+    }
+  }
+
+  @Override
   public void send(EgressMessage message) {
     Objects.requireNonNull(message);
 
diff --git a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
index 19d9f2a..ac72d7c 100644
--- a/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
+++ b/statefun-sdk-protos/src/main/protobuf/sdk/request-reply.proto
@@ -115,6 +115,15 @@
     // DelayedInvocation represents a delayed remote function call with a target address, an argument
     // and a delay in milliseconds, after which this message to be sent.
     message DelayedInvocation {
+        // a boolean value (default false) that indicates rather this is a regular delayed message, or (true) a message
+        // cancellation request.
+        // in case of a regular delayed message all other fields are expected to be preset, otherwise only the
+        // cancellation_token is expected
+        bool is_cancellation_request = 10;
+
+        // an optional cancellation token that can be used to request the "unsending" of a delayed message.
+        string cancellation_token = 11;
+
         // the amount of milliseconds to wait before sending this message
         int64 delay_in_ms = 1;
         // the target address to send this message to
diff --git a/statefun-sdk-python/statefun/context.py b/statefun-sdk-python/statefun/context.py
index b1692a5..578e020 100644
--- a/statefun-sdk-python/statefun/context.py
+++ b/statefun-sdk-python/statefun/context.py
@@ -24,7 +24,6 @@
 
 
 class Context(abc.ABC):
-
     __slots__ = ()
 
     @property
@@ -62,13 +61,24 @@
         """
         pass
 
-    def send_after(self, duration: timedelta, message: Message):
+    def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
         """
         Send a message to a target function after a specified delay.
 
         :param duration: the amount of time to wait before sending this message out.
         :param message: the message to send.
+        :param cancellation_token: an optional cancellation token to associate with this message.
         """
+        pass
+
+    def cancel_delayed_message(self, cancellation_token: str):
+        """
+        Cancel a delayed message (message that was sent using send_after) with a given token.
+
+        Please note that this is a best-effort operation, since the message might have been already delivered.
+        If the message was delivered, this is a no-op operation.
+        """
+        pass
 
     def send_egress(self, message: EgressMessage):
         """
diff --git a/statefun-sdk-python/statefun/request_reply_v3.py b/statefun-sdk-python/statefun/request_reply_v3.py
index fa05253..f4db551 100644
--- a/statefun-sdk-python/statefun/request_reply_v3.py
+++ b/statefun-sdk-python/statefun/request_reply_v3.py
@@ -28,6 +28,16 @@
 from statefun.request_reply_pb2 import ToFunction, FromFunction, Address, TypedValue
 from statefun.storage import resolve, Cell
 
+from dataclasses import dataclass
+
+
+@dataclass
+class DelayedMessage:
+    is_cancellation: bool = None
+    duration: int = None
+    message: Message = None,
+    cancellation_token: str = None
+
 
 class UserFacingContext(statefun.context.Context):
     __slots__ = (
@@ -37,7 +47,7 @@
     def __init__(self, address, storage):
         self._self_address = address
         self._outgoing_messages = []
-        self._outgoing_delayed_messages = []
+        self._outgoing_delayed_messages: typing.List[DelayedMessage] = []
         self._outgoing_egress_messages = []
         self._storage = storage
         self._caller = None
@@ -66,15 +76,28 @@
         """
         self._outgoing_messages.append(message)
 
-    def send_after(self, duration: timedelta, message: Message):
+    def send_after(self, duration: timedelta, message: Message, cancellation_token: str = ""):
         """
         Send a message to a target function after a specified delay.
 
         :param duration: the amount of time to wait before sending this message out.
         :param message: the message to send.
+        :param cancellation_token: an optional cancellation token to associate with this message.
         """
         ms = int(duration.total_seconds() * 1000.0)
-        self._outgoing_delayed_messages.append((ms, message))
+        record = DelayedMessage(is_cancellation=False, duration=ms, message=message,
+                                cancellation_token=cancellation_token)
+        self._outgoing_delayed_messages.append(record)
+
+    def cancel_delayed_message(self, cancellation_token: str):
+        """
+        Cancel a delayed message (message that was sent using send_after) with a given token.
+
+        Please note that this is a best-effort operation, since the message might have been already delivered.
+        If the message was delivered, this is a no-op operation.
+        """
+        record = DelayedMessage(is_cancellation=True, cancellation_token=cancellation_token)
+        self._outgoing_delayed_messages.append(record)
 
     def send_egress(self, message: EgressMessage):
         """
@@ -145,17 +168,34 @@
         outgoing.argument.CopyFrom(message.typed_value)
 
 
-def collect_delayed(delayed_messages: typing.List[typing.Tuple[timedelta, Message]], invocation_result):
+def collect_delayed(delayed_messages: typing.List[DelayedMessage], invocation_result):
     delayed_invocations = invocation_result.delayed_invocations
-    for delay, message in delayed_messages:
+    for delayed_message in delayed_messages:
         outgoing = delayed_invocations.add()
 
-        namespace, type = parse_typename(message.target_typename)
-        outgoing.target.namespace = namespace
-        outgoing.target.type = type
-        outgoing.target.id = message.target_id
-        outgoing.delay_in_ms = delay
-        outgoing.argument.CopyFrom(message.typed_value)
+        if delayed_message.is_cancellation:
+            # handle cancellation
+            outgoing.cancellation_token = delayed_message.cancellation_token
+            outgoing.is_cancellation_request = True
+        else:
+            message = delayed_message.message
+            namespace, type = parse_typename(message.target_typename)
+
+            outgoing.target.namespace = namespace
+            outgoing.target.type = type
+            outgoing.target.id = message.target_id
+            outgoing.delay_in_ms = delayed_message.duration
+            outgoing.argument.CopyFrom(message.typed_value)
+            if delayed_message.cancellation_token is not None:
+                outgoing.cancellation_token = delayed_message.cancellation_token
+
+
+def collect_cancellations(tokens: typing.List[str], invocation_result):
+    outgoing_cancellations = invocation_result.outgoing_delay_cancellations
+    for token in tokens:
+        if token:
+            delay_cancelltion = outgoing_cancellations.add()
+            delay_cancelltion.cancellation_token = token
 
 
 def collect_egress(egresses: typing.List[EgressMessage], invocation_result):
diff --git a/statefun-sdk-python/tests/request_reply_test.py b/statefun-sdk-python/tests/request_reply_test.py
index 612750f..5bd783a 100644
--- a/statefun-sdk-python/tests/request_reply_test.py
+++ b/statefun-sdk-python/tests/request_reply_test.py
@@ -91,6 +91,7 @@
 NTH_OUTGOING_MESSAGE = lambda n: [key("invocation_result"), key("outgoing_messages"), nth(n)]
 NTH_STATE_MUTATION = lambda n: [key("invocation_result"), key("state_mutations"), nth(n)]
 NTH_DELAYED_MESSAGE = lambda n: [key("invocation_result"), key("delayed_invocations"), nth(n)]
+NTH_CANCELLATION_MESSAGE = lambda n: [key("invocation_result"), key("outgoing_delay_cancellations"), nth(n)]
 NTH_EGRESS = lambda n: [key("invocation_result"), key("outgoing_egresses"), nth(n)]
 NTH_MISSING_STATE_SPEC = lambda n: [key("incomplete_invocation_context"), key("missing_values"), nth(n)]
 
@@ -123,6 +124,16 @@
                                message_builder(target_typename="night/owl",
                                                target_id="1",
                                                str_value="hoo hoo"))
+
+            # delayed with cancellation
+            context.send_after(timedelta(hours=1),
+                               message_builder(target_typename="night/owl",
+                                               target_id="1",
+                                               str_value="hoo hoo"),
+                               cancellation_token="token-1234")
+
+            context.cancel_delayed_message("token-1234")
+
             # kafka egresses
             context.send_egress(
                 kafka_egress_message(typename="e/kafka",
@@ -165,6 +176,15 @@
         first_delayed = json_at(result_json, NTH_DELAYED_MESSAGE(0))
         self.assertEqual(int(first_delayed['delay_in_ms']), 1000 * 60 * 60)
 
+        # assert delayed with token
+        second_delayed = json_at(result_json, NTH_DELAYED_MESSAGE(1))
+        self.assertEqual(second_delayed['cancellation_token'], "token-1234")
+
+        # assert cancellation
+        first_cancellation = json_at(result_json, NTH_DELAYED_MESSAGE(2))
+        self.assertTrue(first_cancellation['is_cancellation_request'])
+        self.assertEqual(first_cancellation['cancellation_token'], "token-1234")
+
         # assert egresses
         first_egress = json_at(result_json, NTH_EGRESS(0))
         self.assertEqual(first_egress['egress_namespace'], 'e')
diff --git a/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java b/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
index 1c39775..fdf6be8 100644
--- a/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
+++ b/statefun-testutil/src/main/java/org/apache/flink/statefun/testutils/function/TestContext.java
@@ -102,7 +102,21 @@
   @Override
   public void sendAfter(Duration delay, Address to, Object message) {
     pendingMessage.add(
-        new PendingMessage(new Envelope(self(), to, message), watermark + delay.toMillis()));
+        new PendingMessage(new Envelope(self(), to, message), watermark + delay.toMillis(), null));
+  }
+
+  @Override
+  public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
+    Objects.requireNonNull(cancellationToken);
+    pendingMessage.add(
+        new PendingMessage(
+            new Envelope(self(), to, message), watermark + delay.toMillis(), cancellationToken));
+  }
+
+  @Override
+  public void cancelDelayedMessage(String cancellationToken) {
+    pendingMessage.removeIf(
+        pendingMessage -> Objects.equals(pendingMessage.cancellationToken, cancellationToken));
   }
 
   @Override
@@ -186,12 +200,13 @@
 
   private static class PendingMessage {
     Envelope envelope;
-
+    String cancellationToken;
     long timer;
 
-    PendingMessage(Envelope envelope, long timer) {
+    PendingMessage(Envelope envelope, long timer, String cancellationToken) {
       this.envelope = envelope;
       this.timer = timer;
+      this.cancellationToken = cancellationToken;
     }
   }
 }