blob: 54c1d1e3a17f38215c691b1a218e2282e732c45c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.fn.harness.control;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.apache.beam.fn.harness.PTransformRunnerFactory;
import org.apache.beam.fn.harness.data.BeamFnDataClient;
import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry;
import org.apache.beam.fn.harness.data.PTransformFunctionRegistry;
import org.apache.beam.fn.harness.state.BeamFnStateClient;
import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse;
import org.apache.beam.model.pipeline.v1.Endpoints.ApiServiceDescriptor;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Coder;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
import org.apache.beam.sdk.function.ThrowingConsumer;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.Message;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
/** Tests for {@link ProcessBundleHandler}. */
@RunWith(JUnit4.class)
public class ProcessBundleHandlerTest {
  private static final String DATA_INPUT_URN = "beam:source:runner:0.1";
  private static final String DATA_OUTPUT_URN = "beam:sink:runner:0.1";

  @Rule public ExpectedException thrown = ExpectedException.none();

  @Mock private BeamFnDataClient beamFnDataClient;
  @Captor private ArgumentCaptor<ThrowingConsumer<Exception, WindowedValue<String>>> consumerCaptor;

  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
  }

  /**
   * Verifies that the handler creates runners for every transform of the descriptor (in reverse
   * order), passes the request's instruction id to each factory, and invokes the registered start
   * functions in reverse order and the finish functions in forward order.
   */
  @Test
  public void testOrderOfStartAndFinishCalls() throws Exception {
    BeamFnApi.ProcessBundleDescriptor processBundleDescriptor =
        BeamFnApi.ProcessBundleDescriptor.newBuilder()
            .putTransforms(
                "2L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                    .putOutputs("2L-output", "2L-output-pc")
                    .build())
            .putTransforms(
                "3L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_OUTPUT_URN).build())
                    .putInputs("3L-input", "2L-output-pc")
                    .build())
            .putPcollections("2L-output-pc", RunnerApi.PCollection.getDefaultInstance())
            .build();
    Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
    List<RunnerApi.PTransform> transformsProcessed = new ArrayList<>();
    List<String> orderOfOperations = new ArrayList<>();

    // Records every transform it is asked to build and logs each start/finish invocation.
    PTransformRunnerFactory<Object> startFinishRecorder =
        (pipelineOptions,
            beamFnDataClient,
            beamFnStateClient,
            pTransformId,
            pTransform,
            processBundleInstructionId,
            pCollections,
            coders,
            windowingStrategies,
            pCollectionConsumerRegistry,
            startFunctionRegistry,
            finishFunctionRegistry,
            splitListener) -> {
          assertThat(processBundleInstructionId.get(), equalTo("999L"));
          transformsProcessed.add(pTransform);
          startFunctionRegistry.register(
              pTransformId, () -> orderOfOperations.add("Start" + pTransformId));
          finishFunctionRegistry.register(
              pTransformId, () -> orderOfOperations.add("Finish" + pTransformId));
          return null;
        };

    ProcessBundleHandler handler =
        new ProcessBundleHandler(
            PipelineOptionsFactory.create(),
            fnApiRegistry::get,
            beamFnDataClient,
            null /* beamFnStateClient */,
            ImmutableMap.of(
                DATA_INPUT_URN, startFinishRecorder,
                DATA_OUTPUT_URN, startFinishRecorder));

    handler.processBundle(
        BeamFnApi.InstructionRequest.newBuilder()
            .setInstructionId("999L")
            .setProcessBundle(
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
            .build());

    // Processing of transforms is performed in reverse order.
    assertThat(
        transformsProcessed,
        contains(
            processBundleDescriptor.getTransformsMap().get("3L"),
            processBundleDescriptor.getTransformsMap().get("2L")));
    // Start should occur in reverse order while finish calls should occur in forward order
    assertThat(orderOfOperations, contains("Start3L", "Start2L", "Finish2L", "Finish3L"));
  }

  /**
   * Verifies that an exception thrown by a runner factory while creating its runner escapes from
   * {@link ProcessBundleHandler#processBundle}.
   */
  @Test
  public void testCreatingPTransformExceptionsArePropagated() throws Exception {
    BeamFnApi.ProcessBundleDescriptor processBundleDescriptor =
        BeamFnApi.ProcessBundleDescriptor.newBuilder()
            .putTransforms(
                "2L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                    .build())
            .build();
    Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);

    ProcessBundleHandler handler =
        new ProcessBundleHandler(
            PipelineOptionsFactory.create(),
            fnApiRegistry::get,
            beamFnDataClient,
            null /* beamFnStateGrpcClientCache */,
            ImmutableMap.of(
                DATA_INPUT_URN,
                (PTransformRunnerFactory<Object>)
                    (pipelineOptions,
                        beamFnDataClient,
                        beamFnStateClient,
                        pTransformId,
                        pTransform,
                        processBundleInstructionId,
                        pCollections,
                        coders,
                        windowingStrategies,
                        pCollectionConsumerRegistry,
                        startFunctionRegistry,
                        finishFunctionRegistry,
                        splitListener) -> {
                      throw new IllegalStateException("TestException");
                    }));

    // Declare the expectation before acting: registering it inside the factory would let the
    // test pass vacuously if the factory were never invoked.
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("TestException");
    handler.processBundle(
        BeamFnApi.InstructionRequest.newBuilder()
            .setProcessBundle(
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
            .build());
  }

  /**
   * Verifies that an exception thrown by a registered start function escapes from {@link
   * ProcessBundleHandler#processBundle}.
   */
  @Test
  public void testPTransformStartExceptionsArePropagated() throws Exception {
    BeamFnApi.ProcessBundleDescriptor processBundleDescriptor =
        BeamFnApi.ProcessBundleDescriptor.newBuilder()
            .putTransforms(
                "2L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                    .build())
            .build();
    Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);

    ProcessBundleHandler handler =
        new ProcessBundleHandler(
            PipelineOptionsFactory.create(),
            fnApiRegistry::get,
            beamFnDataClient,
            null /* beamFnStateGrpcClientCache */,
            ImmutableMap.of(
                DATA_INPUT_URN,
                (PTransformRunnerFactory<Object>)
                    (pipelineOptions,
                        beamFnDataClient,
                        beamFnStateClient,
                        pTransformId,
                        pTransform,
                        processBundleInstructionId,
                        pCollections,
                        coders,
                        windowingStrategies,
                        pCollectionConsumerRegistry,
                        startFunctionRegistry,
                        finishFunctionRegistry,
                        splitListener) -> {
                      startFunctionRegistry.register(
                          pTransformId, ProcessBundleHandlerTest::throwException);
                      return null;
                    }));

    // Declare the expectation before acting so it is in force regardless of whether the
    // factory above is ever invoked.
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("TestException");
    handler.processBundle(
        BeamFnApi.InstructionRequest.newBuilder()
            .setProcessBundle(
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
            .build());
  }

  /**
   * Verifies that an exception thrown by a registered finish function escapes from {@link
   * ProcessBundleHandler#processBundle}.
   */
  @Test
  public void testPTransformFinishExceptionsArePropagated() throws Exception {
    BeamFnApi.ProcessBundleDescriptor processBundleDescriptor =
        BeamFnApi.ProcessBundleDescriptor.newBuilder()
            .putTransforms(
                "2L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                    .build())
            .build();
    Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);

    ProcessBundleHandler handler =
        new ProcessBundleHandler(
            PipelineOptionsFactory.create(),
            fnApiRegistry::get,
            beamFnDataClient,
            null /* beamFnStateGrpcClientCache */,
            ImmutableMap.of(
                DATA_INPUT_URN,
                (PTransformRunnerFactory<Object>)
                    (pipelineOptions,
                        beamFnDataClient,
                        beamFnStateClient,
                        pTransformId,
                        pTransform,
                        processBundleInstructionId,
                        pCollections,
                        coders,
                        windowingStrategies,
                        pCollectionConsumerRegistry,
                        startFunctionRegistry,
                        finishFunctionRegistry,
                        splitListener) -> {
                      finishFunctionRegistry.register(
                          pTransformId, ProcessBundleHandlerTest::throwException);
                      return null;
                    }));

    // Declare the expectation before acting so it is in force regardless of whether the
    // factory above is ever invoked.
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("TestException");
    handler.processBundle(
        BeamFnApi.InstructionRequest.newBuilder()
            .setProcessBundle(
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
            .build());
  }

  /**
   * Verifies that {@link ProcessBundleHandler#processBundle} does not return until the state
   * requests issued by a start function have completed, whether they succeed or fail.
   */
  @Test
  public void testPendingStateCallsBlockTillCompletion() throws Exception {
    BeamFnApi.ProcessBundleDescriptor processBundleDescriptor =
        BeamFnApi.ProcessBundleDescriptor.newBuilder()
            .putTransforms(
                "2L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                    .build())
            .setStateApiServiceDescriptor(ApiServiceDescriptor.getDefaultInstance())
            .build();
    Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);
    CompletableFuture<StateResponse> successfulResponse = new CompletableFuture<>();
    CompletableFuture<StateResponse> unsuccessfulResponse = new CompletableFuture<>();

    BeamFnStateGrpcClientCache mockBeamFnStateGrpcClient =
        Mockito.mock(BeamFnStateGrpcClientCache.class);
    BeamFnStateClient mockBeamFnStateClient = Mockito.mock(BeamFnStateClient.class);
    when(mockBeamFnStateGrpcClient.forApiServiceDescriptor(any()))
        .thenReturn(mockBeamFnStateClient);

    // Fake state client: completes each request's future asynchronously, after a delay, with an
    // outcome chosen by the request's instruction id.
    doAnswer(
            invocation -> {
              StateRequest.Builder stateRequestBuilder =
                  (StateRequest.Builder) invocation.getArguments()[0];
              @SuppressWarnings("unchecked")
              CompletableFuture<StateResponse> completableFuture =
                  (CompletableFuture<StateResponse>) invocation.getArguments()[1];
              new Thread(
                      () -> {
                        // Simulate sleeping which introduces a race which most of the time
                        // requires the ProcessBundleHandler to block.
                        Uninterruptibles.sleepUninterruptibly(500, TimeUnit.MILLISECONDS);
                        switch (stateRequestBuilder.getInstructionId()) {
                          case "SUCCESS":
                            completableFuture.complete(StateResponse.getDefaultInstance());
                            break;
                          case "FAIL":
                            completableFuture.completeExceptionally(
                                new RuntimeException("TEST ERROR"));
                            break;
                          default:
                            // Never leave a future pending: the handler is expected to block
                            // until every state call completes, so an untouched future would
                            // hang the test instead of failing it.
                            completableFuture.completeExceptionally(
                                new IllegalStateException(
                                    "Unexpected instruction id: "
                                        + stateRequestBuilder.getInstructionId()));
                        }
                      })
                  .start();
              return null;
            })
        .when(mockBeamFnStateClient)
        .handle(any(), any());

    ProcessBundleHandler handler =
        new ProcessBundleHandler(
            PipelineOptionsFactory.create(),
            fnApiRegistry::get,
            beamFnDataClient,
            mockBeamFnStateGrpcClient,
            ImmutableMap.of(
                DATA_INPUT_URN,
                new PTransformRunnerFactory<Object>() {
                  @Override
                  public Object createRunnerForPTransform(
                      PipelineOptions pipelineOptions,
                      BeamFnDataClient beamFnDataClient,
                      BeamFnStateClient beamFnStateClient,
                      String pTransformId,
                      PTransform pTransform,
                      Supplier<String> processBundleInstructionId,
                      Map<String, PCollection> pCollections,
                      Map<String, Coder> coders,
                      Map<String, WindowingStrategy> windowingStrategies,
                      PCollectionConsumerRegistry pCollectionConsumerRegistry,
                      PTransformFunctionRegistry startFunctionRegistry,
                      PTransformFunctionRegistry finishFunctionRegistry,
                      BundleSplitListener splitListener)
                      throws IOException {
                    startFunctionRegistry.register(
                        pTransformId, () -> doStateCalls(beamFnStateClient));
                    return null;
                  }

                  /** Issues one request the fake completes normally and one it fails. */
                  private void doStateCalls(BeamFnStateClient beamFnStateClient) {
                    beamFnStateClient.handle(
                        StateRequest.newBuilder().setInstructionId("SUCCESS"),
                        successfulResponse);
                    beamFnStateClient.handle(
                        StateRequest.newBuilder().setInstructionId("FAIL"), unsuccessfulResponse);
                  }
                }));

    handler.processBundle(
        BeamFnApi.InstructionRequest.newBuilder()
            .setProcessBundle(
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
            .build());

    assertTrue(successfulResponse.isDone());
    assertTrue(unsuccessfulResponse.isDone());
    // Each future must carry the outcome the fake state client chose for its instruction id.
    assertThat(successfulResponse.isCompletedExceptionally(), equalTo(false));
    assertTrue(unsuccessfulResponse.isCompletedExceptionally());
  }

  /**
   * Verifies that, when the descriptor has no state API service descriptor, a state call made
   * from a start function fails with an {@link IllegalStateException} that escapes from {@link
   * ProcessBundleHandler#processBundle}.
   */
  @Test
  public void testStateCallsFailIfNoStateApiServiceDescriptorSpecified() throws Exception {
    BeamFnApi.ProcessBundleDescriptor processBundleDescriptor =
        BeamFnApi.ProcessBundleDescriptor.newBuilder()
            .putTransforms(
                "2L",
                RunnerApi.PTransform.newBuilder()
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build())
                    .build())
            .build();
    Map<String, Message> fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor);

    ProcessBundleHandler handler =
        new ProcessBundleHandler(
            PipelineOptionsFactory.create(),
            fnApiRegistry::get,
            beamFnDataClient,
            null /* beamFnStateGrpcClientCache */,
            ImmutableMap.of(
                DATA_INPUT_URN,
                new PTransformRunnerFactory<Object>() {
                  @Override
                  public Object createRunnerForPTransform(
                      PipelineOptions pipelineOptions,
                      BeamFnDataClient beamFnDataClient,
                      BeamFnStateClient beamFnStateClient,
                      String pTransformId,
                      PTransform pTransform,
                      Supplier<String> processBundleInstructionId,
                      Map<String, PCollection> pCollections,
                      Map<String, Coder> coders,
                      Map<String, WindowingStrategy> windowingStrategies,
                      PCollectionConsumerRegistry pCollectionConsumerRegistry,
                      PTransformFunctionRegistry startFunctionRegistry,
                      PTransformFunctionRegistry finishFunctionRegistry,
                      BundleSplitListener splitListener)
                      throws IOException {
                    startFunctionRegistry.register(
                        pTransformId, () -> doStateCalls(beamFnStateClient));
                    return null;
                  }

                  /** Issues a single state request through the handler-supplied client. */
                  private void doStateCalls(BeamFnStateClient beamFnStateClient) {
                    beamFnStateClient.handle(
                        StateRequest.newBuilder().setInstructionId("SUCCESS"),
                        new CompletableFuture<>());
                  }
                }));

    // Declare the expectation before acting so the test fails if no state call is ever made.
    thrown.expect(IllegalStateException.class);
    thrown.expectMessage("State API calls are unsupported");
    handler.processBundle(
        BeamFnApi.InstructionRequest.newBuilder()
            .setProcessBundle(
                BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L"))
            .build());
  }

  /** Always throws an {@link IllegalStateException} with message {@code "TestException"}. */
  private static void throwException() {
    throw new IllegalStateException("TestException");
  }
}