blob: 095b549d76c3d60eb08e043ff5498128178a2671 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import com.google.auto.value.AutoValue;
import com.google.common.io.BaseEncoding;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import io.grpc.Channel;
import io.grpc.stub.StreamObserver;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.security.MessageDigest;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactChunk;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.ArtifactMetadata;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.CommitManifestRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.Manifest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.PutArtifactRequest;
import org.apache.beam.model.jobmanagement.v1.ArtifactApi.PutArtifactResponse;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc.ArtifactStagingServiceBlockingStub;
import org.apache.beam.model.jobmanagement.v1.ArtifactStagingServiceGrpc.ArtifactStagingServiceStub;
/** A client to stage files on an {@link ArtifactStagingServiceGrpc ArtifactService}. */
public class ArtifactServiceStager {
  // 2 MB per file-request
  private static final int DEFAULT_BUFFER_SIZE = 2 * 1024 * 1024;

  /** Creates a stager over the given channel using the default 2 MB chunk size. */
  public static ArtifactServiceStager overChannel(Channel channel) {
    return overChannel(channel, DEFAULT_BUFFER_SIZE);
  }

  /**
   * Create a new ArtifactServiceStager with the specified buffer size. Useful for testing
   * multi-part uploads.
   *
   * @param bufferSize the maximum size of the artifact chunk, in bytes.
   */
  static ArtifactServiceStager overChannel(Channel channel, int bufferSize) {
    return new ArtifactServiceStager(channel, bufferSize);
  }

  private final int bufferSize;
  private final ArtifactStagingServiceStub stub;
  private final ArtifactStagingServiceBlockingStub blockingStub;
  // NOTE(review): this executor is never shut down; an application that creates many stagers will
  // accumulate idle threads. Consider accepting an executor or exposing a close() method.
  private final ListeningExecutorService executorService =
      MoreExecutors.listeningDecorator(Executors.newCachedThreadPool());

  private ArtifactServiceStager(Channel channel, int bufferSize) {
    this.stub = ArtifactStagingServiceGrpc.newStub(channel);
    this.blockingStub = ArtifactStagingServiceGrpc.newBlockingStub(channel);
    this.bufferSize = bufferSize;
  }

  /**
   * Stages all of the provided files concurrently, then commits a manifest containing the metadata
   * of every successfully staged artifact.
   *
   * @throws RuntimeException if any file failed to stage; the per-file causes are attached as
   *     suppressed exceptions.
   * @throws InterruptedException if interrupted while waiting for staging to complete.
   */
  public void stage(Iterable<File> files) throws IOException, InterruptedException {
    final Map<File, ListenableFuture<ArtifactMetadata>> futures = new HashMap<>();
    for (File file : files) {
      futures.put(file, executorService.submit(new StagingCallable(file)));
    }
    // Collect the per-file outcomes once every upload future has finished (success or failure).
    ListenableFuture<StagingResult> stagingResult =
        Futures.whenAllComplete(futures.values()).call(new ExtractStagingResultsCallable(futures));
    stageManifest(stagingResult);
  }

  /**
   * Waits for the aggregated staging result and, on success, commits the manifest to the service.
   * On failure, throws a {@link RuntimeException} summarizing which files failed, with each
   * underlying cause added as a suppressed exception.
   */
  private void stageManifest(ListenableFuture<StagingResult> stagingFuture)
      throws InterruptedException {
    try {
      StagingResult stagingResult = stagingFuture.get();
      if (stagingResult.isSuccess()) {
        Manifest manifest =
            Manifest.newBuilder().addAllArtifact(stagingResult.getMetadata()).build();
        blockingStub.commitManifest(
            CommitManifestRequest.newBuilder().setManifest(manifest).build());
      } else {
        RuntimeException failure =
            new RuntimeException(
                String.format(
                    "Failed to stage %s files: %s",
                    stagingResult.getFailures().size(), stagingResult.getFailures().keySet()));
        for (Throwable t : stagingResult.getFailures().values()) {
          failure.addSuppressed(t);
        }
        throw failure;
      }
    } catch (ExecutionException e) {
      throw new RuntimeException(e);
    }
  }

  /** Uploads a single file over the {@code PutArtifact} stream and returns its metadata. */
  private class StagingCallable implements Callable<ArtifactMetadata> {
    private final File file;

    private StagingCallable(File file) {
      this.file = file;
    }

    @Override
    public ArtifactMetadata call() throws Exception {
      // TODO: Add Retries
      PutArtifactResponseObserver responseObserver = new PutArtifactResponseObserver();
      StreamObserver<PutArtifactRequest> requestObserver = stub.putArtifact(responseObserver);
      ArtifactMetadata metadata = ArtifactMetadata.newBuilder().setName(file.getName()).build();
      // The first message on the stream carries the artifact metadata; all subsequent messages
      // carry data chunks.
      requestObserver.onNext(PutArtifactRequest.newBuilder().setMetadata(metadata).build());
      // MD5 here is an integrity checksum for the staged bytes, not a security measure.
      MessageDigest md5Digest = MessageDigest.getInstance("MD5");
      // try-with-resources: the original leaked this FileChannel (and its FileInputStream) on
      // every call, and especially when the loop aborted early on a server-side error.
      try (FileChannel channel = new FileInputStream(file).getChannel()) {
        ByteBuffer readBuffer = ByteBuffer.allocate(bufferSize);
        // Stop early if the server terminated the stream (e.g. reported an error).
        while (!responseObserver.isTerminal() && channel.position() < channel.size()) {
          readBuffer.clear();
          channel.read(readBuffer);
          readBuffer.flip();
          md5Digest.update(readBuffer);
          // Rewind so the same bytes just digested can be copied into the outgoing chunk.
          readBuffer.rewind();
          PutArtifactRequest request =
              PutArtifactRequest.newBuilder()
                  .setData(
                      ArtifactChunk.newBuilder().setData(ByteString.copyFrom(readBuffer)).build())
                  .build();
          requestObserver.onNext(request);
        }
      }
      requestObserver.onCompleted();
      responseObserver.awaitTermination();
      if (responseObserver.err.get() != null) {
        throw new RuntimeException(responseObserver.err.get());
      }
      return metadata.toBuilder().setMd5(BaseEncoding.base64().encode(md5Digest.digest())).build();
    }

    /**
     * Records stream termination (completion or error) so the uploading thread can stop sending
     * and surface any server-side failure.
     */
    private class PutArtifactResponseObserver implements StreamObserver<PutArtifactResponse> {
      private final CountDownLatch completed = new CountDownLatch(1);
      private final AtomicReference<Throwable> err = new AtomicReference<>(null);

      @Override
      public void onNext(PutArtifactResponse value) {}

      @Override
      public void onError(Throwable t) {
        // Record the failure and release the waiting uploader; it rethrows via err. The original
        // also threw a RuntimeException here, which escapes into gRPC callback machinery rather
        // than reaching the caller, so the throw has been removed.
        err.set(t);
        completed.countDown();
      }

      @Override
      public void onCompleted() {
        completed.countDown();
      }

      /** Returns true once the stream has completed or failed. */
      public boolean isTerminal() {
        return completed.getCount() == 0;
      }

      /** Blocks until the stream has completed or failed. */
      public void awaitTermination() throws InterruptedException {
        completed.await();
      }
    }
  }

  /** Folds the per-file futures into a single {@link StagingResult} once they are all done. */
  private static class ExtractStagingResultsCallable implements Callable<StagingResult> {
    private final Map<File, ListenableFuture<ArtifactMetadata>> futures;

    private ExtractStagingResultsCallable(
        Map<File, ListenableFuture<ArtifactMetadata>> futures) {
      this.futures = futures;
    }

    @Override
    public StagingResult call() throws Exception {
      Set<ArtifactMetadata> metadata = new HashSet<>();
      Map<File, Throwable> failures = new HashMap<>();
      for (Entry<File, ListenableFuture<ArtifactMetadata>> stagedFileResult : futures.entrySet()) {
        try {
          metadata.add(stagedFileResult.getValue().get());
        } catch (ExecutionException ee) {
          // Report the underlying staging failure, not the ExecutionException wrapper.
          failures.put(stagedFileResult.getKey(), ee.getCause());
        } catch (InterruptedException ie) {
          // whenAllComplete guarantees every future is already done, so get() cannot block.
          throw new AssertionError(
              "This should never happen. " + "All of the futures are complete by construction", ie);
        }
      }
      if (failures.isEmpty()) {
        return StagingResult.success(metadata);
      } else {
        return StagingResult.failure(failures);
      }
    }
  }

  /** Outcome of staging a batch of files: either the staged metadata or the per-file failures. */
  @AutoValue
  abstract static class StagingResult {
    static StagingResult success(Set<ArtifactMetadata> metadata) {
      return new AutoValue_ArtifactServiceStager_StagingResult(
          metadata, Collections.<File, Throwable>emptyMap());
    }

    static StagingResult failure(Map<File, Throwable> failures) {
      return new AutoValue_ArtifactServiceStager_StagingResult(
          null, failures);
    }

    // Success is encoded as non-null metadata; failure() passes null metadata by construction.
    boolean isSuccess() {
      return getMetadata() != null;
    }

    @Nullable
    abstract Set<ArtifactMetadata> getMetadata();

    abstract Map<File, Throwable> getFailures();
  }
}