blob: 92d25aeefdc363ff7d1c8cd128236a9e31942540 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.direct.portable.artifact;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactChunk;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactMetadata;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetArtifactRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.GetManifestResponse;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.Manifest;
import org.apache.beam.model.jobmanagement.v1.ArtifactRetrievalServiceGrpc;
import org.apache.beam.runners.core.construction.ArtifactServiceStager;
import org.apache.beam.runners.core.construction.ArtifactServiceStager.StagedFile;
import org.apache.beam.runners.fnexecution.GrpcFnServer;
import org.apache.beam.runners.fnexecution.InProcessServerFactory;
import org.apache.beam.runners.fnexecution.ServerFactory;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.stub.StreamObserver;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
 * Tests for {@link LocalFileSystemArtifactRetrievalService}.
 *
 * <p>Each test stages artifacts into a temporary root directory through {@link
 * LocalFileSystemArtifactStagerService}, then reads them back through a retrieval service created
 * over the same root, talking to both over in-process gRPC channels.
 */
@RunWith(JUnit4.class)
public class LocalFileSystemArtifactRetrievalServiceTest {
  /** Upper bound on how long a test waits for an RPC to terminate before failing. */
  private static final long RPC_TIMEOUT_SECONDS = 30L;

  @Rule public TemporaryFolder tmp = new TemporaryFolder();

  /** Root directory shared by the stager service and the retrieval service under test. */
  private File root;

  private final ServerFactory serverFactory = InProcessServerFactory.create();
  private GrpcFnServer<LocalFileSystemArtifactStagerService> stagerServer;
  private GrpcFnServer<LocalFileSystemArtifactRetrievalService> retrievalServer;
  // Client channels are kept in fields so teardown can shut them down instead of leaking them.
  private ManagedChannel stagerChannel;
  private ManagedChannel retrievalChannel;
  private ArtifactRetrievalServiceGrpc.ArtifactRetrievalServiceStub retrievalStub;

  @Before
  public void setup() throws Exception {
    root = tmp.newFolder();
    stagerServer =
        GrpcFnServer.allocatePortAndCreateFor(
            LocalFileSystemArtifactStagerService.forRootDirectory(root), serverFactory);
  }

  /**
   * Shuts down the client channels and closes both servers.
   *
   * <p>Every resource is null-checked. {@code retrievalServer} and the channels are only created by
   * {@link #stageAndCreateRetrievalService}, so if a test fails earlier an unconditional close would
   * throw a {@link NullPointerException} that hides the original failure.
   */
  @After
  public void teardown() throws Exception {
    if (retrievalChannel != null) {
      retrievalChannel.shutdownNow();
    }
    if (stagerChannel != null) {
      stagerChannel.shutdownNow();
    }
    try {
      if (stagerServer != null) {
        stagerServer.close();
      }
    } finally {
      // Close the retrieval server even if closing the stager server failed.
      if (retrievalServer != null) {
        retrievalServer.close();
      }
    }
  }

  @Test
  public void retrieveManifest() throws Exception {
    Map<String, byte[]> artifacts = new HashMap<>();
    artifacts.put("foo", "bar, baz, quux".getBytes(UTF_8));
    artifacts.put("spam", new byte[] {127, -22, 5});
    stageAndCreateRetrievalService(artifacts);

    final AtomicReference<Manifest> returned = new AtomicReference<>();
    final AtomicReference<Throwable> error = new AtomicReference<>();
    final CountDownLatch completed = new CountDownLatch(1);
    retrievalStub.getManifest(
        GetManifestRequest.getDefaultInstance(),
        new StreamObserver<GetManifestResponse>() {
          @Override
          public void onNext(GetManifestResponse value) {
            returned.set(value.getManifest());
          }

          @Override
          public void onError(Throwable t) {
            // Record the failure so the test thread reports the real cause.
            error.set(t);
            completed.countDown();
          }

          @Override
          public void onCompleted() {
            completed.countDown();
          }
        });
    awaitCompletion(completed);
    assertNoRpcError(error.get());
    assertThat(returned.get(), not(nullValue()));

    List<String> manifestArtifacts = new ArrayList<>();
    for (ArtifactMetadata artifactMetadata : returned.get().getArtifactList()) {
      manifestArtifacts.add(artifactMetadata.getName());
    }
    assertThat(manifestArtifacts, containsInAnyOrder("foo", "spam"));
  }

  @Test
  public void retrieveArtifact() throws Exception {
    Map<String, byte[]> artifacts = new HashMap<>();
    byte[] fooContents = "bar, baz, quux".getBytes(UTF_8);
    artifacts.put("foo", fooContents);
    byte[] spamContents = {127, -22, 5};
    artifacts.put("spam", spamContents);
    stageAndCreateRetrievalService(artifacts);

    // One latch count per getArtifact call issued below.
    final CountDownLatch completed = new CountDownLatch(2);
    ByteArrayOutputStream returnedFooBytes = new ByteArrayOutputStream();
    MultimapChunkAppender fooAppender = new MultimapChunkAppender(returnedFooBytes, completed);
    retrievalStub.getArtifact(GetArtifactRequest.newBuilder().setName("foo").build(), fooAppender);
    ByteArrayOutputStream returnedSpamBytes = new ByteArrayOutputStream();
    MultimapChunkAppender spamAppender = new MultimapChunkAppender(returnedSpamBytes, completed);
    retrievalStub.getArtifact(
        GetArtifactRequest.newBuilder().setName("spam").build(), spamAppender);
    awaitCompletion(completed);

    assertNoRpcError(fooAppender.getError());
    assertNoRpcError(spamAppender.getError());
    assertArrayEquals(fooContents, returnedFooBytes.toByteArray());
    assertArrayEquals(spamContents, returnedSpamBytes.toByteArray());
  }

  @Test
  public void retrieveArtifactNotPresent() throws Exception {
    stageAndCreateRetrievalService(
        Collections.singletonMap("foo", "bar, baz, quux".getBytes(UTF_8)));

    final CountDownLatch completed = new CountDownLatch(1);
    final AtomicReference<Throwable> thrown = new AtomicReference<>();
    final AtomicReference<ArtifactChunk> unexpectedChunk = new AtomicReference<>();
    retrievalStub.getArtifact(
        GetArtifactRequest.newBuilder().setName("spam").build(),
        new StreamObserver<ArtifactChunk>() {
          @Override
          public void onNext(ArtifactChunk value) {
            // Record rather than fail() here: this runs on a gRPC callback thread, so an
            // AssertionError thrown here would not fail the test directly.
            unexpectedChunk.compareAndSet(null, value);
          }

          @Override
          public void onError(Throwable t) {
            thrown.set(t);
            completed.countDown();
          }

          @Override
          public void onCompleted() {
            completed.countDown();
          }
        });
    awaitCompletion(completed);

    if (unexpectedChunk.get() != null) {
      fail(
          "Should never receive an "
              + ArtifactChunk.class.getSimpleName()
              + " for a nonexistent artifact");
    }
    assertThat(thrown.get(), not(nullValue()));
    assertThat(thrown.get().getMessage(), containsString("No such artifact"));
    assertThat(thrown.get().getMessage(), containsString("spam"));
  }

  /**
   * Writes each artifact to a temporary file and stages it under its map key.
   *
   * <p>Then starts the retrieval service over {@link #root} and initializes {@link #retrievalStub}.
   * The channels it opens are stored in fields so {@link #teardown} can shut them down.
   *
   * @param artifacts artifact names mapped to their contents
   */
  private void stageAndCreateRetrievalService(Map<String, byte[]> artifacts) throws Exception {
    List<StagedFile> artifactFiles = new ArrayList<>();
    for (Map.Entry<String, byte[]> artifact : artifacts.entrySet()) {
      File artifactFile = tmp.newFile(artifact.getKey());
      Files.write(artifactFile.toPath(), artifact.getValue());
      artifactFiles.add(StagedFile.of(artifactFile, artifactFile.getName()));
    }

    String stagingSessionToken = "token";
    stagerChannel =
        InProcessChannelBuilder.forName(stagerServer.getApiServiceDescriptor().getUrl()).build();
    ArtifactServiceStager stager = ArtifactServiceStager.overChannel(stagerChannel);
    stager.stage(stagingSessionToken, artifactFiles);

    retrievalServer =
        GrpcFnServer.allocatePortAndCreateFor(
            LocalFileSystemArtifactRetrievalService.forRootDirectory(root), serverFactory);
    retrievalChannel =
        InProcessChannelBuilder.forName(retrievalServer.getApiServiceDescriptor().getUrl())
            .build();
    retrievalStub = ArtifactRetrievalServiceGrpc.newStub(retrievalChannel);
  }

  /** Waits for {@code latch} to reach zero, failing the test instead of hanging forever. */
  private static void awaitCompletion(CountDownLatch latch) throws InterruptedException {
    assertTrue(
        "Timed out after " + RPC_TIMEOUT_SECONDS + " seconds waiting for the RPC to complete",
        latch.await(RPC_TIMEOUT_SECONDS, TimeUnit.SECONDS));
  }

  /** Fails the test with {@code error} as the cause if an observer recorded one. */
  private static void assertNoRpcError(Throwable error) {
    if (error != null) {
      throw new AssertionError("RPC failed unexpectedly", error);
    }
  }

  /**
   * Appends the data of every received {@link ArtifactChunk} to a target stream.
   *
   * <p>When the call terminates, it counts down a latch that may be shared with other calls. If the
   * call terminates with an error, it also records the error so the test thread can report it.
   */
  private static class MultimapChunkAppender implements StreamObserver<ArtifactChunk> {
    private final ByteArrayOutputStream target;
    private final CountDownLatch completed;
    private final AtomicReference<Throwable> error = new AtomicReference<>();

    private MultimapChunkAppender(ByteArrayOutputStream target, CountDownLatch completed) {
      this.target = target;
      this.completed = completed;
    }

    /** Returns the error the call terminated with, or {@code null} if there was none. */
    Throwable getError() {
      return error.get();
    }

    @Override
    public void onNext(ArtifactChunk value) {
      try {
        target.write(value.getData().toByteArray());
      } catch (IOException e) {
        // ByteArrayOutputStream does not actually throw; write(byte[]) only declares IOException.
        throw new AssertionError(e);
      }
    }

    @Override
    public void onError(Throwable t) {
      error.set(t);
      completed.countDown();
    }

    @Override
    public void onCompleted() {
      completed.countDown();
    }
  }
}