blob: 8bcacfad645188f228b476c84b8b20d49c0645cd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.fn.harness.data;
import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.InboundDataClient;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.fn.test.TestExecutors;
import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
import org.apache.beam.sdk.fn.test.TestStreams;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.CallStreamObserver;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** Tests for {@link QueueingBeamFnDataClient}. */
@RunWith(JUnit4.class)
public class QueueingBeamFnDataClientTest {
private static final Logger LOG = LoggerFactory.getLogger(QueueingBeamFnDataClientTest.class);
@Rule public TestExecutorService executor = TestExecutors.from(Executors::newCachedThreadPool);
private static final Coder<WindowedValue<String>> CODER =
LengthPrefixCoder.of(
WindowedValue.getFullCoder(StringUtf8Coder.of(), GlobalWindow.Coder.INSTANCE));
private static final LogicalEndpoint ENDPOINT_A = LogicalEndpoint.of("12L", "34L");
private static final LogicalEndpoint ENDPOINT_B = LogicalEndpoint.of("56L", "78L");
private static final BeamFnApi.Elements ELEMENTS_A_1;
private static final BeamFnApi.Elements ELEMENTS_A_2;
private static final BeamFnApi.Elements ELEMENTS_B_1;
static {
try {
ELEMENTS_A_1 =
BeamFnApi.Elements.newBuilder()
.addData(
BeamFnApi.Elements.Data.newBuilder()
.setInstructionId(ENDPOINT_A.getInstructionId())
.setTransformId(ENDPOINT_A.getTransformId())
.setData(
ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("ABC")))
.concat(
ByteString.copyFrom(
encodeToByteArray(CODER, valueInGlobalWindow("DEF"))))))
.build();
ELEMENTS_A_2 =
BeamFnApi.Elements.newBuilder()
.addData(
BeamFnApi.Elements.Data.newBuilder()
.setInstructionId(ENDPOINT_A.getInstructionId())
.setTransformId(ENDPOINT_A.getTransformId())
.setData(
ByteString.copyFrom(
encodeToByteArray(CODER, valueInGlobalWindow("GHI")))))
.addData(
BeamFnApi.Elements.Data.newBuilder()
.setInstructionId(ENDPOINT_A.getInstructionId())
.setTransformId(ENDPOINT_A.getTransformId()))
.build();
ELEMENTS_B_1 =
BeamFnApi.Elements.newBuilder()
.addData(
BeamFnApi.Elements.Data.newBuilder()
.setInstructionId(ENDPOINT_B.getInstructionId())
.setTransformId(ENDPOINT_B.getTransformId())
.setData(
ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("JKL")))
.concat(
ByteString.copyFrom(
encodeToByteArray(CODER, valueInGlobalWindow("MNO"))))))
.addData(
BeamFnApi.Elements.Data.newBuilder()
.setInstructionId(ENDPOINT_B.getInstructionId())
.setTransformId(ENDPOINT_B.getTransformId()))
.build();
} catch (Exception e) {
throw new ExceptionInInitializerError(e);
}
}
@Test(timeout = 10000)
public void testBasicInboundConsumerBehaviour() throws Exception {
CountDownLatch waitForClientToConnect = new CountDownLatch(1);
CountDownLatch receiveAllValuesA = new CountDownLatch(3);
CountDownLatch receiveAllValuesB = new CountDownLatch(2);
Collection<WindowedValue<String>> inboundValuesA = new ConcurrentLinkedQueue<>();
Collection<WindowedValue<String>> inboundValuesB = new ConcurrentLinkedQueue<>();
Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver =
new AtomicReference<>();
CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
TestStreams.withOnNext(inboundServerValues::add).build();
Endpoints.ApiServiceDescriptor apiServiceDescriptor =
Endpoints.ApiServiceDescriptor.newBuilder()
.setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
.build();
Server server =
InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
.addService(
new BeamFnDataGrpc.BeamFnDataImplBase() {
@Override
public StreamObserver<BeamFnApi.Elements> data(
StreamObserver<BeamFnApi.Elements> outboundObserver) {
outboundServerObserver.set(outboundObserver);
waitForClientToConnect.countDown();
return inboundServerObserver;
}
})
.build();
server.start();
try {
ManagedChannel channel =
InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
BeamFnDataGrpcClient clientFactory =
new BeamFnDataGrpcClient(
PipelineOptionsFactory.create(),
(Endpoints.ApiServiceDescriptor descriptor) -> channel,
OutboundObserverFactory.trivial());
QueueingBeamFnDataClient queueingClient = new QueueingBeamFnDataClient(clientFactory);
InboundDataClient readFutureA =
queueingClient.receive(
apiServiceDescriptor,
ENDPOINT_A,
CODER,
(WindowedValue<String> wv) -> {
inboundValuesA.add(wv);
receiveAllValuesA.countDown();
});
waitForClientToConnect.await();
Future<?> sendElementsFuture =
executor.submit(
() -> {
outboundServerObserver.get().onNext(ELEMENTS_A_1);
// Purposefully transmit some data before the consumer for B is bound showing that
// data is not lost
outboundServerObserver.get().onNext(ELEMENTS_B_1);
});
// This can be compeleted before we get values?
InboundDataClient readFutureB =
queueingClient.receive(
apiServiceDescriptor,
ENDPOINT_B,
CODER,
(WindowedValue<String> wv) -> {
inboundValuesB.add(wv);
receiveAllValuesB.countDown();
});
Future<?> drainElementsFuture =
executor.submit(
() -> {
try {
queueingClient.drainAndBlock();
} catch (Exception e) {
LOG.error("Failed ", e);
fail();
}
});
receiveAllValuesB.await();
assertThat(inboundValuesB, contains(valueInGlobalWindow("JKL"), valueInGlobalWindow("MNO")));
outboundServerObserver.get().onNext(ELEMENTS_A_2);
receiveAllValuesA.await(); // Wait for A's values to be available
assertThat(
inboundValuesA,
contains(
valueInGlobalWindow("ABC"), valueInGlobalWindow("DEF"), valueInGlobalWindow("GHI")));
// Wait for these threads to terminate
sendElementsFuture.get();
drainElementsFuture.get();
} finally {
server.shutdownNow();
}
}
@Test(timeout = 100000)
public void testBundleProcessorThrowsExecutionExceptionWhenUserCodeThrows() throws Exception {
CountDownLatch waitForClientToConnect = new CountDownLatch(1);
// Collection<WindowedValue<String>> inboundValuesA = new ConcurrentLinkedQueue<>();
Collection<WindowedValue<String>> inboundValuesB = new ConcurrentLinkedQueue<>();
Collection<BeamFnApi.Elements> inboundServerValues = new ConcurrentLinkedQueue<>();
AtomicReference<StreamObserver<BeamFnApi.Elements>> outboundServerObserver =
new AtomicReference<>();
CallStreamObserver<BeamFnApi.Elements> inboundServerObserver =
TestStreams.withOnNext(inboundServerValues::add).build();
Endpoints.ApiServiceDescriptor apiServiceDescriptor =
Endpoints.ApiServiceDescriptor.newBuilder()
.setUrl(this.getClass().getName() + "-" + UUID.randomUUID().toString())
.build();
Server server =
InProcessServerBuilder.forName(apiServiceDescriptor.getUrl())
.addService(
new BeamFnDataGrpc.BeamFnDataImplBase() {
@Override
public StreamObserver<BeamFnApi.Elements> data(
StreamObserver<BeamFnApi.Elements> outboundObserver) {
outboundServerObserver.set(outboundObserver);
waitForClientToConnect.countDown();
return inboundServerObserver;
}
})
.build();
server.start();
try {
ManagedChannel channel =
InProcessChannelBuilder.forName(apiServiceDescriptor.getUrl()).build();
BeamFnDataGrpcClient clientFactory =
new BeamFnDataGrpcClient(
PipelineOptionsFactory.create(),
(Endpoints.ApiServiceDescriptor descriptor) -> channel,
OutboundObserverFactory.trivial());
QueueingBeamFnDataClient queueingClient = new QueueingBeamFnDataClient(clientFactory);
InboundDataClient readFutureA =
queueingClient.receive(
apiServiceDescriptor,
ENDPOINT_A,
CODER,
(WindowedValue<String> wv) -> {
throw new RuntimeException("Intentionally fail!"); // Error injected here.
});
waitForClientToConnect.await();
Future<?> sendElementsFuture =
executor.submit(
() -> {
outboundServerObserver.get().onNext(ELEMENTS_A_1);
// Purposefully transmit some data before the consumer for B is bound showing that
// data is not lost
outboundServerObserver.get().onNext(ELEMENTS_B_1);
});
InboundDataClient readFutureB =
queueingClient.receive(
apiServiceDescriptor,
ENDPOINT_B,
CODER,
(WindowedValue<String> wv) -> {
inboundValuesB.add(wv);
});
Future<?> drainElementsFuture =
executor.submit(
() -> {
boolean intentionallyFailed = false;
try {
queueingClient.drainAndBlock();
} catch (RuntimeException e) {
intentionallyFailed = true;
} catch (Exception e) {
LOG.error("Unintentional failure", e);
fail();
}
assertTrue(intentionallyFailed);
});
// Fail all InboundObservers if any of the downstream consumers fail.
// This allows the ProcessBundlerHandler to unblock everything and fail properly.
// Wait for these threads to terminate
sendElementsFuture.get();
drainElementsFuture.get();
boolean intentionallyFailedA = false;
try {
readFutureA.awaitCompletion();
} catch (ExecutionException e) {
if (e.getCause() instanceof RuntimeException) {
intentionallyFailedA = true;
}
}
assertTrue(intentionallyFailedA);
boolean intentionallyFailedB = false;
try {
readFutureB.awaitCompletion();
} catch (ExecutionException e) {
if (e.getCause() instanceof RuntimeException) {
intentionallyFailedB = true;
}
}
assertTrue(intentionallyFailedB);
} finally {
server.shutdownNow();
}
}
}