blob: 09e46b0dc41bd09ae8ae7249cac74b4aecf0dd66 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.direct.portable.artifact;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import java.io.File;
import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessServerBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.Uninterruptibles;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for {@link LocalFileSystemArtifactStagerService}. */
@RunWith(JUnit4.class)
public class LocalFileSystemArtifactStagerServiceTest {
/** JUnit rule that provides the temporary folder used as the stager's root directory. */
@Rule public TemporaryFolder temporaryFolder = new TemporaryFolder();
/** Client stub bound to the in-process server; assigned in {@link #setup()}. */
private ArtifactStagingServiceGrpc.ArtifactStagingServiceStub stub;
/** The service under test; assigned in {@link #setup()}. */
private LocalFileSystemArtifactStagerService stager;
/** In-process gRPC server hosting {@link #stager}; started in setup, stopped in teardown. */
private Server server;
/**
 * Creates a stager rooted in a new temporary folder, serves it from an in-process gRPC server
 * configured with a direct executor, and connects {@link #stub} to that server by name.
 */
@Before
public void setup() throws Exception {
  // Server and channel must agree on this name to find each other in-process.
  String serverName = "fs_stager";

  File stagingRoot = temporaryFolder.newFolder();
  stager = LocalFileSystemArtifactStagerService.forRootDirectory(stagingRoot);

  server =
      InProcessServerBuilder.forName(serverName).directExecutor().addService(stager).build().start();

  stub =
      ArtifactStagingServiceGrpc.newStub(
          InProcessChannelBuilder.forName(serverName).usePlaintext().build());
}
/**
 * Stops the in-process server started by {@link #setup()}.
 *
 * <p>JUnit still runs {@code @After} methods when a {@code @Before} method throws, so the server
 * may never have been assigned; the guard keeps a setup failure from being buried under a
 * secondary {@link NullPointerException}.
 */
@After
public void teardown() {
  if (server != null) {
    server.shutdownNow();
  }
}
/**
 * Stages an artifact as a metadata request followed by a single data chunk, and verifies that the
 * staged file exists and starts with exactly the bytes that were sent.
 */
@Test
public void singleDataPutArtifactSucceeds() throws Exception {
  byte[] data = "foo-bar-baz".getBytes(UTF_8);
  RecordingStreamObserver<ArtifactApi.PutArtifactResponse> responseObserver =
      new RecordingStreamObserver<>();
  StreamObserver<ArtifactApi.PutArtifactRequest> requestObserver =
      stub.putArtifact(responseObserver);
  String name = "my-artifact";
  // The first request carries only metadata (artifact name + staging token).
  requestObserver.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setMetadata(
              ArtifactApi.PutArtifactMetadata.newBuilder()
                  .setMetadata(ArtifactApi.ArtifactMetadata.newBuilder().setName(name).build())
                  .setStagingSessionToken("token")
                  .build())
          .build());
  // The second request carries the whole payload in one chunk.
  requestObserver.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setData(
              ArtifactApi.ArtifactChunk.newBuilder().setData(ByteString.copyFrom(data)).build())
          .build());
  requestObserver.onCompleted();
  responseObserver.awaitTerminalState();

  File staged = stager.getLocation().getArtifactFile(name);
  assertThat(staged.exists(), is(true));

  ByteBuffer buf = ByteBuffer.allocate(data.length);
  // Close the stream (and its channel) so the file handle is not leaked into later tests.
  try (FileInputStream in = new FileInputStream(staged)) {
    // A single read is allowed to return fewer bytes than requested; read until full or EOF.
    while (buf.hasRemaining() && in.getChannel().read(buf) >= 0) {
      // Nothing to do: the read itself advances the buffer.
    }
  }
  Assert.assertArrayEquals(data, buf.array());
}
/**
 * Stages an artifact as a metadata request followed by three data chunks, and verifies that the
 * staged file exists and starts with the chunks concatenated in the order they were sent.
 */
@Test
public void multiPartPutArtifactSucceeds() throws Exception {
  byte[] partOne = "foo-".getBytes(UTF_8);
  byte[] partTwo = "bar-".getBytes(UTF_8);
  byte[] partThree = "baz".getBytes(UTF_8);
  // Expected on-disk content: the three parts back to back.
  byte[] expected = "foo-bar-baz".getBytes(UTF_8);
  RecordingStreamObserver<ArtifactApi.PutArtifactResponse> responseObserver =
      new RecordingStreamObserver<>();
  StreamObserver<ArtifactApi.PutArtifactRequest> requestObserver =
      stub.putArtifact(responseObserver);
  String name = "my-artifact";
  // The first request carries only metadata (artifact name + staging token).
  requestObserver.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setMetadata(
              ArtifactApi.PutArtifactMetadata.newBuilder()
                  .setMetadata(ArtifactApi.ArtifactMetadata.newBuilder().setName(name).build())
                  .setStagingSessionToken("token")
                  .build())
          .build());
  requestObserver.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setData(
              ArtifactApi.ArtifactChunk.newBuilder()
                  .setData(ByteString.copyFrom(partOne))
                  .build())
          .build());
  requestObserver.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setData(
              ArtifactApi.ArtifactChunk.newBuilder()
                  .setData(ByteString.copyFrom(partTwo))
                  .build())
          .build());
  requestObserver.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setData(
              ArtifactApi.ArtifactChunk.newBuilder()
                  .setData(ByteString.copyFrom(partThree))
                  .build())
          .build());
  requestObserver.onCompleted();
  responseObserver.awaitTerminalState();

  File staged = stager.getLocation().getArtifactFile(name);
  assertThat(staged.exists(), is(true));

  // Size the buffer from the encoded bytes rather than the String's char count.
  ByteBuffer buf = ByteBuffer.allocate(expected.length);
  // Close the stream (and its channel) so the file handle is not leaked into later tests.
  try (FileInputStream in = new FileInputStream(staged)) {
    // A single read is allowed to return fewer bytes than requested; read until full or EOF.
    while (buf.hasRemaining() && in.getChannel().read(buf) >= 0) {
      // Nothing to do: the read itself advances the buffer.
    }
  }
  Assert.assertArrayEquals(expected, buf.array());
}
/**
 * Sends a data chunk as the very first request (no metadata before it) and expects the call to
 * reach a terminal state with an error recorded.
 */
@Test
public void putArtifactBeforeNameFails() {
  RecordingStreamObserver<ArtifactApi.PutArtifactResponse> responses =
      new RecordingStreamObserver<>();
  StreamObserver<ArtifactApi.PutArtifactRequest> requests = stub.putArtifact(responses);

  ByteString payload = ByteString.copyFrom("foo-".getBytes(UTF_8));
  ArtifactApi.ArtifactChunk chunk =
      ArtifactApi.ArtifactChunk.newBuilder().setData(payload).build();
  ArtifactApi.PutArtifactRequest dataOnlyRequest =
      ArtifactApi.PutArtifactRequest.newBuilder().setData(chunk).build();
  requests.onNext(dataOnlyRequest);

  responses.awaitTerminalState();
  assertThat(responses.error, Matchers.notNullValue());
}
/**
 * Sends a request whose data is the default (empty) chunk instance and expects the call to reach
 * a terminal state with an error recorded.
 *
 * <p>NOTE(review): no metadata request precedes the chunk here, so this may be failing for the
 * same reason as {@code putArtifactBeforeNameFails} rather than because the content is empty —
 * confirm against the service implementation.
 */
@Test
public void putArtifactWithNoContentFails() {
  RecordingStreamObserver<ArtifactApi.PutArtifactResponse> responses =
      new RecordingStreamObserver<>();
  StreamObserver<ArtifactApi.PutArtifactRequest> requests = stub.putArtifact(responses);

  ArtifactApi.ArtifactChunk emptyChunk = ArtifactApi.ArtifactChunk.getDefaultInstance();
  ArtifactApi.PutArtifactRequest emptyRequest =
      ArtifactApi.PutArtifactRequest.newBuilder().setData(emptyChunk).build();
  requests.onNext(emptyRequest);

  responses.awaitTerminalState();
  assertThat(responses.error, Matchers.notNullValue());
}
/**
 * Stages two artifacts, commits a manifest listing both, and expects the commit call to complete
 * with exactly one response.
 */
@Test
public void commitManifestWithAllArtifactsSucceeds() {
  ArtifactApi.Manifest.Builder manifestBuilder = ArtifactApi.Manifest.newBuilder();
  manifestBuilder.addArtifact(
      stageBytes("first-artifact", "foo, bar, baz, quux".getBytes(UTF_8)));
  manifestBuilder.addArtifact(stageBytes("second-artifact", "spam, ham, eggs".getBytes(UTF_8)));

  ArtifactApi.CommitManifestRequest commitRequest =
      ArtifactApi.CommitManifestRequest.newBuilder()
          .setManifest(manifestBuilder.build())
          .build();
  RecordingStreamObserver<ArtifactApi.CommitManifestResponse> commitResponses =
      new RecordingStreamObserver<>();
  stub.commitManifest(commitRequest, commitResponses);
  commitResponses.awaitTerminalState();

  assertThat(commitResponses.completed, is(true));
  assertThat(commitResponses.responses, Matchers.hasSize(1));
  // NOTE(review): generated proto string getters normally never return null, so this check is
  // probably vacuous; consider asserting a non-empty token instead — confirm.
  assertThat(commitResponses.responses.get(0).getRetrievalToken(), Matchers.notNullValue());
}
/**
 * Commits a manifest that lists one staged artifact and one artifact that was never staged, and
 * expects the commit call to reach a terminal state with an error recorded.
 */
@Test
public void commitManifestWithMissingArtifactFails() {
  ArtifactApi.ArtifactMetadata staged =
      stageBytes("first-artifact", "foo, bar, baz, quux".getBytes(UTF_8));
  // Metadata only: nothing is uploaded under this name.
  ArtifactApi.ArtifactMetadata neverStaged =
      ArtifactApi.ArtifactMetadata.newBuilder().setName("absent").build();

  ArtifactApi.Manifest.Builder manifestBuilder = ArtifactApi.Manifest.newBuilder();
  manifestBuilder.addArtifact(staged);
  manifestBuilder.addArtifact(neverStaged);
  ArtifactApi.CommitManifestRequest commitRequest =
      ArtifactApi.CommitManifestRequest.newBuilder()
          .setManifest(manifestBuilder.build())
          .build();

  RecordingStreamObserver<ArtifactApi.CommitManifestResponse> commitResponses =
      new RecordingStreamObserver<>();
  stub.commitManifest(commitRequest, commitResponses);
  commitResponses.awaitTerminalState();

  assertThat(commitResponses.error, Matchers.notNullValue());
}
/**
 * Stages {@code bytes} under {@code name} using one metadata request followed by one data chunk,
 * and blocks until the staging call has terminated.
 *
 * <p>Waiting for the terminal state keeps callers from issuing a commit while the upload is still
 * in flight, and the error check stops a failed upload from letting a "missing artifact" test
 * pass for the wrong reason.
 *
 * @param name the artifact name to stage under
 * @param bytes the artifact contents
 * @return metadata naming the staged artifact, suitable for adding to a manifest
 */
private ArtifactApi.ArtifactMetadata stageBytes(String name, byte[] bytes) {
  ArtifactApi.ArtifactMetadata metadata =
      ArtifactApi.ArtifactMetadata.newBuilder().setName(name).build();
  RecordingStreamObserver<ArtifactApi.PutArtifactResponse> responseObserver =
      new RecordingStreamObserver<>();
  StreamObserver<ArtifactApi.PutArtifactRequest> requests = stub.putArtifact(responseObserver);
  requests.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setMetadata(
              ArtifactApi.PutArtifactMetadata.newBuilder()
                  .setMetadata(metadata)
                  .setStagingSessionToken("token")
                  .build())
          .build());
  requests.onNext(
      ArtifactApi.PutArtifactRequest.newBuilder()
          .setData(
              ArtifactApi.ArtifactChunk.newBuilder().setData(ByteString.copyFrom(bytes)).build())
          .build());
  requests.onCompleted();

  // Do not return until the service has answered, and surface staging failures immediately.
  responseObserver.awaitTerminalState();
  assertThat(
      "staging artifact '" + name + "' failed", responseObserver.error, Matchers.nullValue());
  return metadata;
}
private static class RecordingStreamObserver<T> implements StreamObserver<T> {
private List<T> responses = new ArrayList<>();
@Nullable private Throwable error = null;
private boolean completed = false;
@Override
public void onNext(T value) {
failIfTerminal();
responses.add(value);
}
@Override
public void onError(Throwable t) {
failIfTerminal();
error = t;
}
@Override
public void onCompleted() {
failIfTerminal();
completed = true;
}
private boolean isTerminal() {
return error != null || completed;
}
private void failIfTerminal() {
if (isTerminal()) {
Assert.fail(
String.format(
"Should have terminated after entering a terminal state: completed %s, error %s",
completed, error));
}
}
void awaitTerminalState() {
while (!isTerminal()) {
Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS);
}
}
}
}