blob: 01ddeea07082c9105e57994f80198203ac026a4c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.fn.data;
import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray;
import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.empty;
import static org.junit.Assert.assertThat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements;
import org.apache.beam.model.fnexecution.v1.BeamFnDataGrpc;
import org.apache.beam.model.pipeline.v1.Endpoints;
import org.apache.beam.runners.dataflow.harness.test.TestStreams;
import org.apache.beam.runners.dataflow.worker.fn.stream.ServerStreamObserverFactory;
import org.apache.beam.runners.fnexecution.GrpcContextHeaderAccessorProvider;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.LengthPrefixCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
import org.apache.beam.sdk.fn.data.InboundDataClient;
import org.apache.beam.sdk.fn.data.LogicalEndpoint;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.BindableService;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.CallOptions;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Channel;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ClientCall;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ClientInterceptor;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Metadata;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Metadata.Key;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.MethodDescriptor;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ServerInterceptors;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.stub.StreamObserver;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link BeamFnDataGrpcService}. */
@RunWith(JUnit4.class)
@SuppressWarnings("FutureReturnValueIgnored")
public class BeamFnDataGrpcServiceTest {
private static final BeamFnApi.Target TARGET =
BeamFnApi.Target.newBuilder().setPrimitiveTransformReference("888").setName("test").build();
private static final Coder<WindowedValue<String>> CODER =
LengthPrefixCoder.of(WindowedValue.getValueOnlyCoder(StringUtf8Coder.of()));
private static final String DEFAULT_CLIENT = "";
private Server server;
private BeamFnDataGrpcService service;
@Before
public void setUp() throws Exception {
Endpoints.ApiServiceDescriptor descriptor =
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(UUID.randomUUID().toString()).build();
PipelineOptions options = PipelineOptionsFactory.create();
service =
new BeamFnDataGrpcService(
options,
descriptor,
ServerStreamObserverFactory.fromOptions(options)::from,
GrpcContextHeaderAccessorProvider.getHeaderAccessor());
server = createServer(service, descriptor);
}
@After
public void tearDown() {
server.shutdownNow();
}
@Test
public void testMessageReceivedBySingleClientWhenThereAreMultipleClients() throws Exception {
BlockingQueue<Elements> clientInboundElements = new LinkedBlockingQueue<>();
ExecutorService executorService = Executors.newCachedThreadPool();
CountDownLatch waitForInboundElements = new CountDownLatch(1);
int numberOfClients = 3;
for (int client = 0; client < numberOfClients; ++client) {
executorService.submit(
() -> {
ManagedChannel channel =
InProcessChannelBuilder.forName(service.getApiServiceDescriptor().getUrl()).build();
StreamObserver<BeamFnApi.Elements> outboundObserver =
BeamFnDataGrpc.newStub(channel)
.data(TestStreams.withOnNext(clientInboundElements::add).build());
waitForInboundElements.await();
outboundObserver.onCompleted();
return null;
});
}
for (int i = 0; i < 3; ++i) {
CloseableFnDataReceiver<WindowedValue<String>> consumer =
service
.getDataService(DEFAULT_CLIENT)
.send(LogicalEndpoint.of(Integer.toString(i), TARGET), CODER);
consumer.accept(valueInGlobalWindow("A" + i));
consumer.accept(valueInGlobalWindow("B" + i));
consumer.accept(valueInGlobalWindow("C" + i));
consumer.close();
}
// Specifically copy the elements to a new list so we perform blocking calls on the queue
// to ensure the elements arrive.
List<Elements> copy = new ArrayList<>();
for (int i = 0; i < numberOfClients; ++i) {
copy.add(clientInboundElements.take());
}
assertThat(
copy,
containsInAnyOrder(elementsWithData("0"), elementsWithData("1"), elementsWithData("2")));
waitForInboundElements.countDown();
}
@Test
public void testMessageReceivedByProperClientWhenThereAreMultipleClients() throws Exception {
ConcurrentHashMap<String, LinkedBlockingQueue<Elements>> clientInboundElements =
new ConcurrentHashMap<>();
ExecutorService executorService = Executors.newCachedThreadPool();
CountDownLatch waitForInboundElements = new CountDownLatch(1);
int numberOfClients = 3;
int numberOfMessages = 3;
for (int client = 0; client < numberOfClients; ++client) {
String clientId = Integer.toString(client);
clientInboundElements.put(clientId, new LinkedBlockingQueue<>());
executorService.submit(
() -> {
ManagedChannel channel =
InProcessChannelBuilder.forName(service.getApiServiceDescriptor().getUrl())
.intercept(
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method,
CallOptions callOptions,
Channel next) {
return new SimpleForwardingClientCall<ReqT, RespT>(
next.newCall(method, callOptions)) {
@Override
public void start(
Listener<RespT> responseListener, Metadata headers) {
headers.put(
Key.of("worker_id", Metadata.ASCII_STRING_MARSHALLER),
clientId);
super.start(responseListener, headers);
}
};
}
})
.build();
StreamObserver<BeamFnApi.Elements> outboundObserver =
BeamFnDataGrpc.newStub(channel)
.data(TestStreams.withOnNext(clientInboundElements.get(clientId)::add).build());
waitForInboundElements.await();
outboundObserver.onCompleted();
return null;
});
}
for (int client = 0; client < numberOfClients; ++client) {
for (int i = 0; i < 3; ++i) {
String instructionId = client + "-" + i;
CloseableFnDataReceiver<WindowedValue<String>> consumer =
service
.getDataService(Integer.toString(client))
.send(LogicalEndpoint.of(instructionId, TARGET), CODER);
consumer.accept(valueInGlobalWindow("A" + instructionId));
consumer.accept(valueInGlobalWindow("B" + instructionId));
consumer.accept(valueInGlobalWindow("C" + instructionId));
consumer.close();
}
}
for (int client = 0; client < numberOfClients; ++client) {
// Specifically copy the elements to a new list so we perform blocking calls on the queue
// to ensure the elements arrive.
ArrayList<BeamFnApi.Elements> copy = new ArrayList<>();
for (int i = 0; i < numberOfMessages; ++i) {
copy.add(clientInboundElements.get(Integer.toString(client)).take());
}
assertThat(
copy,
containsInAnyOrder(
elementsWithData(client + "-" + 0),
elementsWithData(client + "-" + 1),
elementsWithData(client + "-" + 2)));
}
waitForInboundElements.countDown();
}
@Test
public void testMultipleClientsSendMessagesAreDirectedToProperConsumers() throws Exception {
LinkedBlockingQueue<BeamFnApi.Elements> clientInboundElements = new LinkedBlockingQueue<>();
ExecutorService executorService = Executors.newCachedThreadPool();
CountDownLatch waitForInboundElements = new CountDownLatch(1);
for (int i = 0; i < 3; ++i) {
String instructionReference = Integer.toString(i);
executorService.submit(
() -> {
ManagedChannel channel =
InProcessChannelBuilder.forName(service.getApiServiceDescriptor().getUrl()).build();
StreamObserver<BeamFnApi.Elements> outboundObserver =
BeamFnDataGrpc.newStub(channel)
.data(TestStreams.withOnNext(clientInboundElements::add).build());
outboundObserver.onNext(elementsWithData(instructionReference));
waitForInboundElements.await();
outboundObserver.onCompleted();
return null;
});
}
List<Collection<WindowedValue<String>>> serverInboundValues = new ArrayList<>();
Collection<InboundDataClient> inboundDataClients = new ArrayList<>();
for (int i = 0; i < 3; ++i) {
BlockingQueue<WindowedValue<String>> serverInboundValue = new LinkedBlockingQueue<>();
serverInboundValues.add(serverInboundValue);
inboundDataClients.add(
service
.getDataService(DEFAULT_CLIENT)
.receive(
LogicalEndpoint.of(Integer.toString(i), TARGET), CODER, serverInboundValue::add));
}
// Waiting for the client provides the necessary synchronization for the elements to arrive.
for (InboundDataClient inboundDataClient : inboundDataClients) {
inboundDataClient.awaitCompletion();
}
waitForInboundElements.countDown();
for (int i = 0; i < 3; ++i) {
assertThat(
serverInboundValues.get(i),
contains(
valueInGlobalWindow("A" + i),
valueInGlobalWindow("B" + i),
valueInGlobalWindow("C" + i)));
}
assertThat(clientInboundElements, empty());
}
private BeamFnApi.Elements elementsWithData(String id) throws CoderException {
return BeamFnApi.Elements.newBuilder()
.addData(
BeamFnApi.Elements.Data.newBuilder()
.setInstructionReference(id)
.setTarget(TARGET)
.setData(
ByteString.copyFrom(encodeToByteArray(CODER, valueInGlobalWindow("A" + id)))
.concat(
ByteString.copyFrom(
encodeToByteArray(CODER, valueInGlobalWindow("B" + id))))
.concat(
ByteString.copyFrom(
encodeToByteArray(CODER, valueInGlobalWindow("C" + id))))))
.addData(BeamFnApi.Elements.Data.newBuilder().setInstructionReference(id).setTarget(TARGET))
.build();
}
private Server createServer(BindableService service, Endpoints.ApiServiceDescriptor descriptor)
throws Exception {
String serverName = descriptor.getUrl();
Server server =
InProcessServerBuilder.forName(serverName)
.addService(
ServerInterceptors.intercept(
service, GrpcContextHeaderAccessorProvider.interceptor()))
.build();
server.start();
return server;
}
}