blob: 93a90bb12296b64580482b0d3f7c59409b177360 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.fnexecution.control;
import static org.hamcrest.Matchers.contains;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.util.Collections;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionRequest;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.InstructionResponse;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.runners.core.construction.Environments;
import org.apache.beam.runners.core.construction.JavaReadViaImpulse;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.runners.fnexecution.GrpcFnServer;
import org.apache.beam.runners.fnexecution.InProcessServerFactory;
import org.apache.beam.runners.fnexecution.data.GrpcDataService;
import org.apache.beam.runners.fnexecution.environment.EnvironmentFactory;
import org.apache.beam.runners.fnexecution.environment.RemoteEnvironment;
import org.apache.beam.runners.fnexecution.state.GrpcStateService;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.fn.IdGenerators;
import org.apache.beam.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
/** Tests for {@link SingleEnvironmentInstanceJobBundleFactory}. */
@RunWith(JUnit4.class)
public class SingleEnvironmentInstanceJobBundleFactoryTest {
@Mock private EnvironmentFactory environmentFactory;
@Mock private InstructionRequestHandler instructionRequestHandler;
private ExecutorService executor = Executors.newCachedThreadPool();
private GrpcFnServer<GrpcDataService> dataServer;
private GrpcFnServer<GrpcStateService> stateServer;
private JobBundleFactory factory;
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
when(instructionRequestHandler.handle(any(InstructionRequest.class)))
.thenReturn(CompletableFuture.completedFuture(InstructionResponse.getDefaultInstance()));
InProcessServerFactory serverFactory = InProcessServerFactory.create();
dataServer =
GrpcFnServer.allocatePortAndCreateFor(
GrpcDataService.create(executor, OutboundObserverFactory.serverDirect()),
serverFactory);
stateServer = GrpcFnServer.allocatePortAndCreateFor(GrpcStateService.create(), serverFactory);
factory =
SingleEnvironmentInstanceJobBundleFactory.create(
environmentFactory, dataServer, stateServer, IdGenerators.incrementingLongs());
}
@After
public void teardown() throws Exception {
try (AutoCloseable data = dataServer;
AutoCloseable state = stateServer) {
executor.shutdownNow();
}
}
@Test
public void closeShutsDownEnvironments() throws Exception {
Pipeline p = Pipeline.create();
p.apply("Create", Create.of(1, 2, 3));
p.replaceAll(Collections.singletonList(JavaReadViaImpulse.boundedOverride()));
ExecutableStage stage =
GreedyPipelineFuser.fuse(PipelineTranslation.toProto(p)).getFusedStages().stream()
.findFirst()
.get();
RemoteEnvironment remoteEnv = mock(RemoteEnvironment.class);
when(remoteEnv.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);
when(environmentFactory.createEnvironment(stage.getEnvironment())).thenReturn(remoteEnv);
factory.forStage(stage);
factory.close();
verify(remoteEnv).close();
}
@Test
public void closeShutsDownEnvironmentsWhenSomeFail() throws Exception {
Pipeline p = Pipeline.create();
p.apply("Create", Create.of(1, 2, 3));
p.replaceAll(Collections.singletonList(JavaReadViaImpulse.boundedOverride()));
ExecutableStage firstEnvStage =
GreedyPipelineFuser.fuse(PipelineTranslation.toProto(p)).getFusedStages().stream()
.findFirst()
.get();
ExecutableStagePayload basePayload =
ExecutableStagePayload.parseFrom(firstEnvStage.toPTransform("foo").getSpec().getPayload());
Environment secondEnv = Environments.createDockerEnvironment("second_env");
ExecutableStage secondEnvStage =
ExecutableStage.fromPayload(basePayload.toBuilder().setEnvironment(secondEnv).build());
Environment thirdEnv = Environments.createDockerEnvironment("third_env");
ExecutableStage thirdEnvStage =
ExecutableStage.fromPayload(basePayload.toBuilder().setEnvironment(thirdEnv).build());
RemoteEnvironment firstRemoteEnv = mock(RemoteEnvironment.class, "First Remote Env");
RemoteEnvironment secondRemoteEnv = mock(RemoteEnvironment.class, "Second Remote Env");
RemoteEnvironment thirdRemoteEnv = mock(RemoteEnvironment.class, "Third Remote Env");
when(environmentFactory.createEnvironment(firstEnvStage.getEnvironment()))
.thenReturn(firstRemoteEnv);
when(environmentFactory.createEnvironment(secondEnvStage.getEnvironment()))
.thenReturn(secondRemoteEnv);
when(environmentFactory.createEnvironment(thirdEnvStage.getEnvironment()))
.thenReturn(thirdRemoteEnv);
when(firstRemoteEnv.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);
when(secondRemoteEnv.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);
when(thirdRemoteEnv.getInstructionRequestHandler()).thenReturn(instructionRequestHandler);
factory.forStage(firstEnvStage);
factory.forStage(secondEnvStage);
factory.forStage(thirdEnvStage);
IllegalStateException firstException = new IllegalStateException("first stage");
doThrow(firstException).when(firstRemoteEnv).close();
IllegalStateException thirdException = new IllegalStateException("third stage");
doThrow(thirdException).when(thirdRemoteEnv).close();
try {
factory.close();
fail("Factory close should have thrown");
} catch (IllegalStateException expected) {
if (expected.equals(firstException)) {
assertThat(ImmutableList.copyOf(expected.getSuppressed()), contains(thirdException));
} else if (expected.equals(thirdException)) {
assertThat(ImmutableList.copyOf(expected.getSuppressed()), contains(firstException));
} else {
throw expected;
}
verify(firstRemoteEnv).close();
verify(secondRemoteEnv).close();
verify(thirdRemoteEnv).close();
}
}
}