blob: 7db1bd940712da715aa03e42ea8420c81a034e4f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.statefun.flink.core.reqreply;
import static org.apache.flink.statefun.flink.core.TestUtils.FUNCTION_1_ADDR;
import static org.apache.flink.statefun.flink.core.common.PolyglotUtil.polyglotAddressToSdkAddress;
import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import com.google.protobuf.ByteString;
import java.time.Duration;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics;
import org.apache.flink.statefun.flink.core.metrics.RemoteInvocationMetrics;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.AsyncOperationResult;
import org.apache.flink.statefun.sdk.AsyncOperationResult.Status;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.DelayedInvocation;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.EgressMessage;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.ExpirationSpec;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.IncompleteInvocationContext;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.InvocationResponse;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueMutation;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueMutation.MutationType;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction.PersistedValueSpec;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction.Invocation;
import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.junit.Test;
/**
 * Tests for {@link RequestReplyFunction}, driven entirely through in-memory fakes: {@link
 * FakeClient} stands in for the remote endpoint, {@link FakeContext} for the invocation context,
 * and {@link BacklogTrackingMetrics} for the function type metrics.
 */
public class RequestReplyFunctionTest {
  private static final FunctionType FN_TYPE = new FunctionType("foo", "bar");

  private final FakeClient client = new FakeClient();
  private final FakeContext context = new FakeContext();

  // Default function under test, with a single pre-registered "session" state.
  private final RequestReplyFunction functionUnderTest =
      new RequestReplyFunction(
          FN_TYPE, testInitialRegisteredState("session", "com.foo.bar/myType"), 10, client, true);

  @Test
  public void example() {
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    assertTrue(client.wasSentToFunction.hasInvocation());
    assertThat(client.capturedInvocationBatchSize(), is(1));
  }

  @Test
  public void callerIsSet() {
    context.caller = FUNCTION_1_ADDR;
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    Invocation anInvocation = client.capturedInvocation(0);
    Address caller = polyglotAddressToSdkAddress(anInvocation.getCaller());

    assertThat(caller, is(FUNCTION_1_ADDR));
  }

  @Test
  public void messageIsSet() {
    TypedValue argument =
        TypedValue.newBuilder()
            .setTypename("io.statefun.foo/bar")
            .setHasValue(true)
            .setValue(ByteString.copyFromUtf8("Hello!"))
            .build();

    functionUnderTest.invoke(context, argument);

    assertThat(client.capturedInvocation(0).getArgument(), is(argument));
  }

  @Test
  public void batchIsAccumulatedWhileARequestIsInFlight() {
    // send one message
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
    // the following invocations should be queued and sent as a batch
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
    // simulate a successful completion of the first operation
    functionUnderTest.invoke(context, successfulAsyncOperation());

    assertThat(client.capturedInvocationBatchSize(), is(2));
  }

  @Test
  public void reachingABatchLimitTriggersBackpressure() {
    // Named differently from the field so it cannot be confused with the default instance.
    RequestReplyFunction batchLimitedFunction = new RequestReplyFunction(FN_TYPE, 2, client);

    // send one message
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    // the following invocations should be queued
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    // the following invocations should request backpressure
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());

    assertThat(context.needsWaiting, is(true));
  }

  @Test
  public void returnedMessageReleaseBackpressure() {
    // Named differently from the field so it cannot be confused with the default instance.
    RequestReplyFunction batchLimitedFunction = new RequestReplyFunction(FN_TYPE, 2, client);

    // the following invocations should cause backpressure
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    assertThat(context.needsWaiting, is(true));

    // complete one message, should send a batch of size 3
    context.needsWaiting = false;
    batchLimitedFunction.invoke(context, successfulAsyncOperation());
    assertThat(client.capturedInvocationBatchSize(), is(3));

    // the next message should not cause backpressure.
    batchLimitedFunction.invoke(context, TypedValue.getDefaultInstance());
    assertThat(context.needsWaiting, is(false));
  }

  @Test
  public void stateIsModified() {
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    // A message returned from the function
    // that asks to put "hello" into the session state.
    FromFunction response =
        FromFunction.newBuilder()
            .setInvocationResult(
                InvocationResponse.newBuilder()
                    .addStateMutations(
                        PersistedValueMutation.newBuilder()
                            .setStateValue(
                                TypedValue.newBuilder()
                                    .setTypename("com.foo.bar/myType")
                                    .setHasValue(true)
                                    .setValue(ByteString.copyFromUtf8("hello")))
                            .setMutationType(MutationType.MODIFY)
                            .setStateName("session")))
            .build();

    functionUnderTest.invoke(context, successfulAsyncOperation(response));
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    assertThat(client.capturedState(0).getValue(), is(ByteString.copyFromUtf8("hello")));
  }

  @Test
  public void delayedMessages() {
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    FromFunction response =
        FromFunction.newBuilder()
            .setInvocationResult(
                InvocationResponse.newBuilder()
                    .addDelayedInvocations(
                        DelayedInvocation.newBuilder()
                            .setArgument(TypedValue.getDefaultInstance())
                            .setDelayInMs(1)
                            .build()))
            .build();

    functionUnderTest.invoke(context, successfulAsyncOperation(response));

    assertFalse(context.delayed.isEmpty());
    assertEquals(Duration.ofMillis(1), context.delayed.get(0).delay());
  }

  @Test
  public void egressIsSent() {
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    FromFunction response =
        FromFunction.newBuilder()
            .setInvocationResult(
                InvocationResponse.newBuilder()
                    .addOutgoingEgresses(
                        EgressMessage.newBuilder()
                            .setArgument(TypedValue.getDefaultInstance())
                            .setEgressNamespace("org.foo")
                            .setEgressType("bar")))
            .build();

    functionUnderTest.invoke(context, successfulAsyncOperation(response));

    assertFalse(context.egresses.isEmpty());
    assertEquals(
        new EgressIdentifier<>("org.foo", "bar", TypedValue.class),
        context.egresses.get(0).getKey());
  }

  @Test
  public void retryBatchOnIncompleteInvocationContextResponse() {
    TypedValue argument =
        TypedValue.newBuilder()
            .setTypename("io.statefun.foo/bar")
            .setValue(ByteString.copyFromUtf8("Hello!"))
            .build();
    functionUnderTest.invoke(context, argument);

    FromFunction response =
        FromFunction.newBuilder()
            .setIncompleteInvocationContext(
                IncompleteInvocationContext.newBuilder()
                    .addMissingValues(
                        PersistedValueSpec.newBuilder()
                            .setStateName("new-state")
                            .setExpirationSpec(
                                ExpirationSpec.newBuilder()
                                    .setMode(ExpirationSpec.ExpireMode.AFTER_INVOKE)
                                    .setExpireAfterMillis(5000)
                                    .build())))
            .build();

    // The original request is passed as metadata so the function can replay it.
    functionUnderTest.invoke(context, successfulAsyncOperation(client.wasSentToFunction, response));

    // re-sent batch should have identical invocation input messages
    assertTrue(client.wasSentToFunction.hasInvocation());
    assertThat(client.capturedInvocationBatchSize(), is(1));
    assertThat(client.capturedInvocation(0).getArgument(), is(argument));

    // re-sent batch should have new state as well as originally registered state
    assertThat(client.capturedStateNames().size(), is(2));
    assertThat(client.capturedStateNames(), hasItems("session", "new-state"));
  }

  @Test
  public void backlogMetricsIncreasedOnInvoke() {
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    // following should be accounted into backlog metrics
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    assertThat(context.functionTypeMetrics().numBacklog(), is(2));
  }

  @Test
  public void backlogMetricsDecreasedOnNextSuccess() {
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    // following should be accounted into backlog metrics
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());
    functionUnderTest.invoke(context, TypedValue.getDefaultInstance());

    // complete one message, should fully consume backlog
    functionUnderTest.invoke(context, successfulAsyncOperation());

    assertThat(context.functionTypeMetrics().numBacklog(), is(0));
  }

  @Test
  public void retryBatchOnUnknownAsyncResponseAfterRestore() {
    TypedValue argument =
        TypedValue.newBuilder()
            .setTypename("io.statefun.foo/bar")
            .setValue(ByteString.copyFromUtf8("Hello!"))
            .build();
    functionUnderTest.invoke(context, argument);
    ToFunction originalRequest = client.wasSentToFunction;

    // Simulates a restore: a fresh function instance with no registered state.
    RequestReplyFunction restoredFunction =
        new RequestReplyFunction(FN_TYPE, new PersistedRemoteFunctionValues(), 2, client, true);
    restoredFunction.invoke(context, unknownAsyncOperation(originalRequest));

    // retry batch after a restore on an unknown async operation should start with empty state specs
    assertTrue(client.wasSentToFunction.hasInvocation());
    assertThat(client.capturedInvocationBatchSize(), is(1));
    assertThat(client.capturedInvocation(0).getArgument(), is(argument));
    assertThat(client.capturedStateNames().size(), is(0));
  }

  /**
   * Builds a {@link PersistedRemoteFunctionValues} with exactly one registered state spec.
   *
   * @param existingStateName name of the state to register
   * @param typename type typename of the registered state
   * @return the state values holding the single registered spec
   */
  private static PersistedRemoteFunctionValues testInitialRegisteredState(
      String existingStateName, String typename) {
    final PersistedRemoteFunctionValues states = new PersistedRemoteFunctionValues();
    states.registerStates(
        Collections.singletonList(
            PersistedValueSpec.newBuilder()
                .setTypeTypename(typename)
                .setStateName(existingStateName)
                .build()));
    return states;
  }

  /**
   * A {@code SUCCESS} result carrying a default {@link FromFunction}. The metadata is a dummy
   * {@link Object}; tests that need the original request replayed use the {@link ToFunction}
   * overload instead.
   */
  private static AsyncOperationResult<Object, FromFunction> successfulAsyncOperation() {
    return new AsyncOperationResult<>(
        new Object(), Status.SUCCESS, FromFunction.getDefaultInstance(), null);
  }

  /** A {@code SUCCESS} result carrying {@code fromFunction}, with dummy {@link Object} metadata. */
  private static AsyncOperationResult<Object, FromFunction> successfulAsyncOperation(
      FromFunction fromFunction) {
    return new AsyncOperationResult<>(new Object(), Status.SUCCESS, fromFunction, null);
  }

  /** A {@code SUCCESS} result carrying {@code fromFunction}, with {@code toFunction} as metadata. */
  private static AsyncOperationResult<ToFunction, FromFunction> successfulAsyncOperation(
      ToFunction toFunction, FromFunction fromFunction) {
    return new AsyncOperationResult<>(toFunction, Status.SUCCESS, fromFunction, null);
  }

  /** An {@code UNKNOWN}-status result with {@code toFunction} as metadata. */
  private static AsyncOperationResult<ToFunction, FromFunction> unknownAsyncOperation(
      ToFunction toFunction) {
    return new AsyncOperationResult<>(
        toFunction, Status.UNKNOWN, FromFunction.getDefaultInstance(), null);
  }

  /**
   * Fake remote client. It remembers the last {@link ToFunction} it was asked to send and answers
   * with an already-completed future.
   */
  private static final class FakeClient implements RequestReplyClient {
    /** The most recent request passed to {@link #call}; {@code null} until the first call. */
    ToFunction wasSentToFunction;

    /** Produces the response for every call. */
    final Supplier<FromFunction> fromFunction = FromFunction::getDefaultInstance;

    @Override
    public CompletableFuture<FromFunction> call(
        ToFunctionRequestSummary requestSummary,
        RemoteInvocationMetrics metrics,
        ToFunction toFunction) {
      this.wasSentToFunction = toFunction;
      try {
        return CompletableFuture.completedFuture(this.fromFunction.get());
      } catch (Throwable t) {
        // Surface supplier failures as a failed future rather than a thrown exception.
        CompletableFuture<FromFunction> failed = new CompletableFuture<>();
        failed.completeExceptionally(t);
        return failed;
      }
    }

    /** return the n-th invocation sent as part of the current batch. */
    Invocation capturedInvocation(int n) {
      return wasSentToFunction.getInvocation().getInvocations(n);
    }

    /** Returns the value of the n-th state sent alongside the current batch. */
    TypedValue capturedState(int n) {
      return wasSentToFunction.getInvocation().getState(n).getStateValue();
    }

    /** Returns the names of all states sent alongside the current batch. */
    Set<String> capturedStateNames() {
      return wasSentToFunction.getInvocation().getStateList().stream()
          .map(ToFunction.PersistedValue::getStateName)
          .collect(Collectors.toSet());
    }

    /** Returns the number of invocations in the current batch. */
    public int capturedInvocationBatchSize() {
      return wasSentToFunction.getInvocation().getInvocationsCount();
    }
  }

  /** Value holder for one call to {@code sendAfter} captured by {@link FakeContext}. */
  private static final class DelayedMessage {
    final Duration delay;
    final @Nullable String messageId;
    final Address target;
    final Object message;

    public DelayedMessage(
        Duration delay, @Nullable String messageId, Address target, Object message) {
      this.delay = delay;
      this.messageId = messageId;
      this.target = target;
      this.message = message;
    }

    public Duration delay() {
      return delay;
    }

    /** The cancellation token, or {@code null} if none was supplied. */
    @Nullable
    public String messageId() {
      return messageId;
    }

    public Address target() {
      return target;
    }

    public Object message() {
      return message;
    }
  }

  /**
   * Fake invocation context. It records the following so tests can assert on them:
   *
   * <ul>
   *   <li>the configured caller
   *   <li>whether backpressure was requested
   *   <li>egress messages
   *   <li>delayed messages
   * </ul>
   *
   * Direct sends, cancellations and async-operation registrations are ignored.
   */
  private static final class FakeContext implements InternalContext {
    private final BacklogTrackingMetrics fakeMetrics = new BacklogTrackingMetrics();

    Address caller;

    /** Set to {@code true} by {@link #awaitAsyncOperationComplete()}; tests reset it manually. */
    boolean needsWaiting;

    // capture emitted messages
    final List<Map.Entry<EgressIdentifier<?>, ?>> egresses = new ArrayList<>();
    final List<DelayedMessage> delayed = new ArrayList<>();

    @Override
    public void awaitAsyncOperationComplete() {
      needsWaiting = true;
    }

    @Override
    public BacklogTrackingMetrics functionTypeMetrics() {
      return fakeMetrics;
    }

    @Override
    public Address self() {
      return new Address(FN_TYPE, "0");
    }

    @Override
    public Address caller() {
      return caller;
    }

    @Override
    public void send(Address to, Object message) {}

    @Override
    public <T> void send(EgressIdentifier<T> egress, T message) {
      egresses.add(new SimpleImmutableEntry<>(egress, message));
    }

    @Override
    public void sendAfter(Duration delay, Address to, Object message) {
      delayed.add(new DelayedMessage(delay, null, to, message));
    }

    @Override
    public void sendAfter(Duration delay, Address to, Object message, String cancellationToken) {
      delayed.add(new DelayedMessage(delay, cancellationToken, to, message));
    }

    @Override
    public void cancelDelayedMessage(String cancellationToken) {}

    @Override
    public <M, T> void registerAsyncOperation(M metadata, CompletableFuture<T> future) {}
  }

  /**
   * Metrics fake that keeps a running backlog count, incremented by {@link
   * #appendBacklogMessages} and decremented by {@link #consumeBacklogMessages}. All other
   * callbacks are no-ops.
   */
  private static final class BacklogTrackingMetrics implements FunctionTypeMetrics {
    private int numBacklog = 0;

    /** Current backlog count (appended minus consumed). */
    public int numBacklog() {
      return numBacklog;
    }

    @Override
    public void appendBacklogMessages(int count) {
      numBacklog += count;
    }

    @Override
    public void consumeBacklogMessages(int count) {
      numBacklog -= count;
    }

    @Override
    public void remoteInvocationFailures() {}

    @Override
    public void remoteInvocationLatency(long elapsed) {}

    @Override
    public void asyncOperationRegistered() {}

    @Override
    public void asyncOperationCompleted() {}

    @Override
    public void incomingMessage() {}

    @Override
    public void outgoingRemoteMessage() {}

    @Override
    public void outgoingEgressMessage() {}

    @Override
    public void outgoingLocalMessage() {}

    @Override
    public void blockedAddress() {}

    @Override
    public void unblockedAddress() {}
  }
}