| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.runners.dataflow.worker.windmill; |
| |
| import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.PrintWriter; |
| import java.io.SequenceInputStream; |
| import java.net.URI; |
| import java.util.ArrayList; |
| import java.util.Collections; |
| import java.util.Deque; |
| import java.util.Enumeration; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.NoSuchElementException; |
| import java.util.Random; |
| import java.util.Set; |
| import java.util.concurrent.BlockingDeque; |
| import java.util.concurrent.CancellationException; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.ConcurrentLinkedDeque; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.Executor; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.LinkedBlockingDeque; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicInteger; |
| import java.util.concurrent.atomic.AtomicLong; |
| import java.util.function.Consumer; |
| import java.util.function.Function; |
| import java.util.function.Supplier; |
| import javax.annotation.Nullable; |
| import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils; |
| import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk; |
| import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; |
| import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; |
| import org.apache.beam.sdk.options.PipelineOptionsFactory; |
| import org.apache.beam.sdk.util.BackOff; |
| import org.apache.beam.sdk.util.BackOffUtils; |
| import org.apache.beam.sdk.util.FluentBackoff; |
| import org.apache.beam.sdk.util.Sleeper; |
| import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.CallCredentials; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.Channel; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.Status; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.StatusRuntimeException; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.auth.MoreCallCredentials; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.inprocess.InProcessChannelBuilder; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.netty.GrpcSslContexts; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.netty.NegotiationType; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.netty.NettyChannelBuilder; |
| import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Verify; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; |
| import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.net.HostAndPort; |
| import org.joda.time.Duration; |
| import org.joda.time.Instant; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** gRPC client for communicating with Windmill Service. */ |
| // Very likely real potential for bugs - https://issues.apache.org/jira/browse/BEAM-6562 |
| // Very likely real potential for bugs - https://issues.apache.org/jira/browse/BEAM-6564 |
| @SuppressFBWarnings({"JLM_JSR166_UTILCONCURRENT_MONITORENTER", "IS2_INCONSISTENT_SYNC"}) |
| public class GrpcWindmillServer extends WindmillServerStub { |
  private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class);

  // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively
  // high.
  private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300;
  private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300;
  // Stream clean close seconds must be set lower than the stream deadline seconds.
  private static final long DEFAULT_STREAM_CLEAN_CLOSE_SECONDS = 180;

  // Bounds for the exponential backoff applied to failed unary RPCs and broken streams.
  private static final Duration MIN_BACKOFF = Duration.millis(1);
  private static final Duration MAX_BACKOFF = Duration.standardSeconds(30);
  // Default gRPC streams to 2MB chunks, which has shown to be a large enough chunk size to reduce
  // per-chunk overhead, and small enough that we can still granularly flow-control.
  private static final int COMMIT_STREAM_CHUNK_SIZE = 2 << 20;
  private static final int GET_DATA_STREAM_CHUNK_SIZE = 2 << 20;

  // Process-wide monotonically increasing counter backing uniqueId().
  private static final AtomicLong nextId = new AtomicLong(0);

  private final StreamingDataflowWorkerOptions options;
  private final int streamingRpcBatchLimit;
  // Async and blocking Streaming Engine stubs; a stub is picked at random per call.
  private final List<CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub> stubList =
      new ArrayList<>();
  private final List<CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub>
      syncStubList = new ArrayList<>();
  // Non-null only when talking to a local/appliance windmill server (streaming engine disabled).
  private WindmillApplianceGrpc.WindmillApplianceBlockingStub syncApplianceStub = null;
  private long unaryDeadlineSeconds = DEFAULT_UNARY_RPC_DEADLINE_SECONDS;
  private ImmutableSet<HostAndPort> endpoints;
  // How many stream failures between warning logs; lowered to 1 for local testing.
  private int logEveryNStreamFailures = 20;
  private Duration maxBackoff = MAX_BACKOFF;
  // Per-RPC-type timers accumulating time spent throttled by RESOURCE_EXHAUSTED errors.
  private final ThrottleTimer getWorkThrottleTimer = new ThrottleTimer();
  private final ThrottleTimer getDataThrottleTimer = new ThrottleTimer();
  private final ThrottleTimer commitWorkThrottleTimer = new ThrottleTimer();
  Random rand = new Random();

  // All live streams, tracked so appendSummaryHtml can render the status page.
  private final Set<AbstractWindmillStream<?, ?>> streamRegistry =
      Collections.newSetFromMap(new ConcurrentHashMap<AbstractWindmillStream<?, ?>, Boolean>());
| |
  /**
   * Creates a client from worker {@code options}.
   *
   * <p>If {@code windmillServiceEndpoint} is set, Streaming Engine stubs are created for each
   * comma-separated endpoint. Otherwise, if streaming engine is disabled and a local
   * {@code grpc:localhost:<port>} hostport is configured, a local appliance stub is created.
   * If neither applies, the client is left without stubs (see {@link #isReady()}).
   */
  public GrpcWindmillServer(StreamingDataflowWorkerOptions options) throws IOException {
    this.options = options;
    this.streamingRpcBatchLimit = options.getWindmillServiceStreamingRpcBatchLimit();
    this.endpoints = ImmutableSet.of();
    if (options.getWindmillServiceEndpoint() != null) {
      Set<HostAndPort> endpoints = new HashSet<>();
      for (String endpoint : Splitter.on(',').split(options.getWindmillServiceEndpoint())) {
        endpoints.add(
            HostAndPort.fromString(endpoint).withDefaultPort(options.getWindmillServicePort()));
      }
      initializeWindmillService(endpoints);
    } else if (!streamingEngineEnabled() && options.getLocalWindmillHostport() != null) {
      int portStart = options.getLocalWindmillHostport().lastIndexOf(':');
      String endpoint = options.getLocalWindmillHostport().substring(0, portStart);
      // NOTE(review): this check is a no-op unless the JVM runs with -ea; confirm whether a
      // malformed hostport should instead fail with an explicit precondition check.
      assert ("grpc:localhost".equals(endpoint));
      int port = Integer.parseInt(options.getLocalWindmillHostport().substring(portStart + 1));
      this.endpoints = ImmutableSet.<HostAndPort>of(HostAndPort.fromParts("localhost", port));
      initializeLocalHost(port);
    }
  }
| |
| private GrpcWindmillServer(String name, boolean enableStreamingEngine) { |
| this.options = PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class); |
| this.streamingRpcBatchLimit = Integer.MAX_VALUE; |
| options.setProject("project"); |
| options.setJobId("job"); |
| options.setWorkerId("worker"); |
| if (enableStreamingEngine) { |
| List<String> experiments = this.options.getExperiments(); |
| if (experiments == null) { |
| experiments = new ArrayList<>(); |
| } |
| experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT); |
| options.setExperiments(experiments); |
| } |
| this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel(name))); |
| } |
| |
  /** Returns whether this worker is configured to use the Streaming Engine service. */
  private boolean streamingEngineEnabled() {
    return options.isEnableStreamingEngine();
  }
| |
| @Override |
| public synchronized void setWindmillServiceEndpoints(Set<HostAndPort> endpoints) |
| throws IOException { |
| Preconditions.checkNotNull(endpoints); |
| if (endpoints.equals(this.endpoints)) { |
| // The endpoints are equal don't recreate the stubs. |
| return; |
| } |
| LOG.info("Creating a new windmill stub, endpoints: {}", endpoints); |
| if (this.endpoints != null) { |
| LOG.info("Previous windmill stub endpoints: {}", this.endpoints); |
| } |
| initializeWindmillService(endpoints); |
| } |
| |
  /** Returns true once at least one Streaming Engine stub has been created. */
  @Override
  public synchronized boolean isReady() {
    return !stubList.isEmpty();
  }
| |
| private synchronized void initializeLocalHost(int port) throws IOException { |
| this.logEveryNStreamFailures = 1; |
| this.maxBackoff = Duration.millis(500); |
| this.unaryDeadlineSeconds = 10; // For local testing use a short deadline. |
| Channel channel = localhostChannel(port); |
| if (streamingEngineEnabled()) { |
| this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(channel)); |
| this.syncStubList.add(CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(channel)); |
| } else { |
| this.syncApplianceStub = WindmillApplianceGrpc.newBlockingStub(channel); |
| } |
| } |
| |
| /** |
| * Create a wrapper around credentials callback that delegates to the underlying vendored {@link |
| * com.google.auth.RequestMetadataCallback}. Note that this class should override every method |
| * that is not final and not static and call the delegate directly. |
| * |
| * <p>TODO: Replace this with an auto generated proxy which calls the underlying implementation |
| * delegate to reduce maintenance burden. |
| */ |
| private static class VendoredRequestMetadataCallbackAdapter |
| implements com.google.auth.RequestMetadataCallback { |
| private final org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.RequestMetadataCallback |
| callback; |
| |
| private VendoredRequestMetadataCallbackAdapter( |
| org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.RequestMetadataCallback callback) { |
| this.callback = callback; |
| } |
| |
| @Override |
| public void onSuccess(Map<String, List<String>> metadata) { |
| callback.onSuccess(metadata); |
| } |
| |
| @Override |
| public void onFailure(Throwable exception) { |
| callback.onFailure(exception); |
| } |
| } |
| |
| /** |
| * Create a wrapper around credentials that delegates to the underlying {@link |
| * com.google.auth.Credentials}. Note that this class should override every method that is not |
| * final and not static and call the delegate directly. |
| * |
| * <p>TODO: Replace this with an auto generated proxy which calls the underlying implementation |
| * delegate to reduce maintenance burden. |
| */ |
| private static class VendoredCredentialsAdapter |
| extends org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.Credentials { |
| private final com.google.auth.Credentials credentials; |
| |
| private VendoredCredentialsAdapter(com.google.auth.Credentials credentials) { |
| this.credentials = credentials; |
| } |
| |
| @Override |
| public String getAuthenticationType() { |
| return credentials.getAuthenticationType(); |
| } |
| |
| @Override |
| public Map<String, List<String>> getRequestMetadata() throws IOException { |
| return credentials.getRequestMetadata(); |
| } |
| |
| @Override |
| public void getRequestMetadata( |
| final URI uri, |
| Executor executor, |
| final org.apache.beam.vendor.grpc.v1p21p0.com.google.auth.RequestMetadataCallback |
| callback) { |
| credentials.getRequestMetadata( |
| uri, executor, new VendoredRequestMetadataCallbackAdapter(callback)); |
| } |
| |
| @Override |
| public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException { |
| return credentials.getRequestMetadata(uri); |
| } |
| |
| @Override |
| public boolean hasRequestMetadata() { |
| return credentials.hasRequestMetadata(); |
| } |
| |
| @Override |
| public boolean hasRequestMetadataOnly() { |
| return credentials.hasRequestMetadataOnly(); |
| } |
| |
| @Override |
| public void refresh() throws IOException { |
| credentials.refresh(); |
| } |
| } |
| |
| private synchronized void initializeWindmillService(Set<HostAndPort> endpoints) |
| throws IOException { |
| LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", endpoints); |
| this.stubList.clear(); |
| this.syncStubList.clear(); |
| this.endpoints = ImmutableSet.<HostAndPort>copyOf(endpoints); |
| for (HostAndPort endpoint : this.endpoints) { |
| if ("localhost".equals(endpoint.getHost())) { |
| initializeLocalHost(endpoint.getPort()); |
| } else { |
| CallCredentials creds = |
| MoreCallCredentials.from(new VendoredCredentialsAdapter(options.getGcpCredential())); |
| this.stubList.add( |
| CloudWindmillServiceV1Alpha1Grpc.newStub(remoteChannel(endpoint)) |
| .withCallCredentials(creds)); |
| this.syncStubList.add( |
| CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(remoteChannel(endpoint)) |
| .withCallCredentials(creds)); |
| } |
| } |
| } |
| |
  /** Creates a test instance backed by an in-process channel identified by {@code name}. */
  @VisibleForTesting
  static GrpcWindmillServer newTestInstance(String name, boolean enableStreamingEngine) {
    return new GrpcWindmillServer(name, enableStreamingEngine);
  }
| |
  /** Builds an in-process channel (test use only) with the given registered {@code name}. */
  private Channel inProcessChannel(String name) {
    return InProcessChannelBuilder.forName(name).directExecutor().build();
  }
| |
| private Channel localhostChannel(int port) { |
| return NettyChannelBuilder.forAddress("localhost", port) |
| .maxInboundMessageSize(java.lang.Integer.MAX_VALUE) |
| .negotiationType(NegotiationType.PLAINTEXT) |
| .build(); |
| } |
| |
| private Channel remoteChannel(HostAndPort endpoint) throws IOException { |
| return NettyChannelBuilder.forAddress(endpoint.getHost(), endpoint.getPort()) |
| .maxInboundMessageSize(java.lang.Integer.MAX_VALUE) |
| .negotiationType(NegotiationType.TLS) |
| // Set ciphers(null) to not use GCM, which is disabled for Dataflow |
| // due to it being horribly slow. |
| .sslContext(GrpcSslContexts.forClient().ciphers(null).build()) |
| .build(); |
| } |
| |
| private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub() { |
| if (stubList.isEmpty()) { |
| throw new RuntimeException("windmillServiceEndpoint has not been set"); |
| } |
| if (stubList.size() == 1) { |
| return stubList.get(0); |
| } |
| return stubList.get(rand.nextInt(stubList.size())); |
| } |
| |
| private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub |
| syncStub() { |
| if (syncStubList.isEmpty()) { |
| throw new RuntimeException("windmillServiceEndpoint has not been set"); |
| } |
| if (syncStubList.size() == 1) { |
| return syncStubList.get(0); |
| } |
| return syncStubList.get(rand.nextInt(syncStubList.size())); |
| } |
| |
| @Override |
| public void appendSummaryHtml(PrintWriter writer) { |
| writer.write("Active Streams:<br>"); |
| for (AbstractWindmillStream<?, ?> stream : streamRegistry) { |
| stream.appendSummaryHtml(writer); |
| writer.write("<br>"); |
| } |
| } |
| |
  // Configure backoff to retry calls forever, with a maximum sane retry interval.
  // maxBackoff is an instance field so local testing can use a shorter ceiling.
  private BackOff grpcBackoff() {
    return FluentBackoff.DEFAULT
        .withInitialBackoff(MIN_BACKOFF)
        .withMaxBackoff(maxBackoff)
        .backoff();
  }
| |
| private <ResponseT> ResponseT callWithBackoff(Supplier<ResponseT> function) { |
| BackOff backoff = grpcBackoff(); |
| int rpcErrors = 0; |
| while (true) { |
| try { |
| return function.get(); |
| } catch (StatusRuntimeException e) { |
| try { |
| if (++rpcErrors % 20 == 0) { |
| LOG.warn( |
| "Many exceptions calling gRPC. Last exception: {} with status {}", |
| e, |
| e.getStatus()); |
| } |
| if (!BackOffUtils.next(Sleeper.DEFAULT, backoff)) { |
| throw new WindmillServerStub.RpcException(e); |
| } |
| } catch (IOException | InterruptedException i) { |
| if (i instanceof InterruptedException) { |
| Thread.currentThread().interrupt(); |
| } |
| WindmillServerStub.RpcException rpcException = new WindmillServerStub.RpcException(e); |
| rpcException.addSuppressed(i); |
| throw rpcException; |
| } |
| } |
| } |
| } |
| |
| @Override |
| public GetWorkResponse getWork(GetWorkRequest request) { |
| if (syncApplianceStub == null) { |
| return callWithBackoff( |
| () -> |
| syncStub() |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .getWork( |
| request |
| .toBuilder() |
| .setJobId(options.getJobId()) |
| .setProjectId(options.getProject()) |
| .setWorkerId(options.getWorkerId()) |
| .build())); |
| } else { |
| return callWithBackoff( |
| () -> |
| syncApplianceStub |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .getWork(request)); |
| } |
| } |
| |
| @Override |
| public GetDataResponse getData(GetDataRequest request) { |
| if (syncApplianceStub == null) { |
| return callWithBackoff( |
| () -> |
| syncStub() |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .getData( |
| request |
| .toBuilder() |
| .setJobId(options.getJobId()) |
| .setProjectId(options.getProject()) |
| .build())); |
| } else { |
| return callWithBackoff( |
| () -> |
| syncApplianceStub |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .getData(request)); |
| } |
| } |
| |
| @Override |
| public CommitWorkResponse commitWork(CommitWorkRequest request) { |
| if (syncApplianceStub == null) { |
| return callWithBackoff( |
| () -> |
| syncStub() |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .commitWork( |
| request |
| .toBuilder() |
| .setJobId(options.getJobId()) |
| .setProjectId(options.getProject()) |
| .build())); |
| } else { |
| return callWithBackoff( |
| () -> |
| syncApplianceStub |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .commitWork(request)); |
| } |
| } |
| |
| @Override |
| public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) { |
| return new GrpcGetWorkStream( |
| GetWorkRequest.newBuilder(request) |
| .setJobId(options.getJobId()) |
| .setProjectId(options.getProject()) |
| .setWorkerId(options.getWorkerId()) |
| .build(), |
| receiver); |
| } |
| |
  /** Opens a new streaming GetData RPC to windmill. */
  @Override
  public GetDataStream getDataStream() {
    return new GrpcGetDataStream();
  }
| |
  /** Opens a new streaming CommitWork RPC to windmill. */
  @Override
  public CommitWorkStream commitWorkStream() {
    return new GrpcCommitWorkStream();
  }
| |
| @Override |
| public GetConfigResponse getConfig(GetConfigRequest request) { |
| if (syncApplianceStub == null) { |
| throw new RpcException( |
| new UnsupportedOperationException("GetConfig not supported with windmill service.")); |
| } else { |
| return callWithBackoff( |
| () -> |
| syncApplianceStub |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .getConfig(request)); |
| } |
| } |
| |
| @Override |
| public ReportStatsResponse reportStats(ReportStatsRequest request) { |
| if (syncApplianceStub == null) { |
| throw new RpcException( |
| new UnsupportedOperationException("ReportStats not supported with windmill service.")); |
| } else { |
| return callWithBackoff( |
| () -> |
| syncApplianceStub |
| .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS) |
| .reportStats(request)); |
| } |
| } |
| |
| @Override |
| public long getAndResetThrottleTime() { |
| return getWorkThrottleTimer.getAndResetThrottleTime() |
| + getDataThrottleTimer.getAndResetThrottleTime() |
| + commitWorkThrottleTimer.getAndResetThrottleTime(); |
| } |
| |
  /** Builds the job/project/worker header attached to streaming requests. */
  private JobHeader makeHeader() {
    return JobHeader.newBuilder()
        .setJobId(options.getJobId())
        .setProjectId(options.getProject())
        .setWorkerId(options.getWorkerId())
        .build();
  }
| |
  /** Returns a long that is unique to this process (monotonically increasing, starting at 1). */
  private static long uniqueId() {
    return nextId.incrementAndGet();
  }
| |
  /**
   * Base class for persistent streams connecting to Windmill.
   *
   * <p>This class handles the underlying gRPC StreamObservers, and automatically reconnects the
   * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on
   * a broken stream.
   *
   * <p>Subclasses should override onResponse to handle responses from the server, and onNewStream
   * to perform any work that must be done when a new stream is created, such as sending headers or
   * retrying requests.
   *
   * <p>send and startStream should not be called from onResponse; use executor() instead.
   *
   * <p>Synchronization on this is used to synchronize the gRpc stream state and internal data
   * structures. Since grpc channel operations may block, synchronization on this stream may also
   * block. This is generally not a problem since streams are used in a single-threaded manner.
   * However some accessors used for status page and other debugging need to take care not to
   * require synchronizing on this.
   */
  private abstract class AbstractWindmillStream<RequestT, ResponseT> implements WindmillStream {
    private final StreamObserverFactory streamObserverFactory = StreamObserverFactory.direct();
    // Creates a request observer from a response observer; invoked for each (re)connect.
    private final Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory;
    // Single thread used to restart streams and to send from response callbacks.
    private final Executor executor = Executors.newSingleThreadExecutor();

    // The following should be protected by synchronizing on this, except for
    // the atomics which may be read atomically for status pages.
    private StreamObserver<RequestT> requestObserver;
    // Millis timestamp at which the current physical stream was started.
    private final AtomicLong startTimeMs = new AtomicLong();
    // Total stream errors observed; logged every logEveryNStreamFailures occurrences.
    private final AtomicInteger errorCount = new AtomicInteger();
    private final BackOff backoff = grpcBackoff();
    // Millis timestamp until which the stream sleeps in backoff (read by the status page).
    private final AtomicLong sleepUntil = new AtomicLong();
    // Set once close() has been called by the client.
    protected final AtomicBoolean clientClosed = new AtomicBoolean();
    // Released when the stream finishes with no pending work after a client close.
    private final CountDownLatch finishLatch = new CountDownLatch(1);

    protected AbstractWindmillStream(
        Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory) {
      this.clientFactory = clientFactory;
    }

    /** Called on each response from the server */
    protected abstract void onResponse(ResponseT response);
    /** Called when a new underlying stream to the server has been opened. */
    protected abstract void onNewStream();
    /** Returns whether there are any pending requests that should be retried on a stream break. */
    protected abstract boolean hasPendingRequests();
    /**
     * Called when the stream is throttled due to resource exhausted errors. Will be called for each
     * resource exhausted error not just the first. onResponse() must stop throttling on receipt of
     * the first good message.
     */
    protected abstract void startThrottleTimer();
    /** Send a request to the server. */
    protected final synchronized void send(RequestT request) {
      requestObserver.onNext(request);
    }

    /** Starts the underlying stream, retrying forever with backoff until it is created. */
    protected final void startStream() {
      // Add the stream to the registry after it has been fully constructed.
      streamRegistry.add(this);
      BackOff backoff = grpcBackoff();
      while (true) {
        try {
          synchronized (this) {
            startTimeMs.set(Instant.now().getMillis());
            requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver());
            onNewStream();
            // If the client closed while we were (re)connecting, immediately half-close the new
            // physical stream as well.
            if (clientClosed.get()) {
              close();
            }
            return;
          }
        } catch (Exception e) {
          LOG.error("Failed to create new stream, retrying: ", e);
          try {
            long sleep = backoff.nextBackOffMillis();
            sleepUntil.set(Instant.now().getMillis() + sleep);
            Thread.sleep(sleep);
          } catch (InterruptedException i) {
            // Keep trying to create the stream.
            // NOTE(review): the interrupt status is swallowed here; confirm whether callers rely
            // on interruption being able to abort stream creation.
          } catch (IOException i) {
            // Ignore.
          }
        }
      }
    }

    /** Executor on which send/startStream may be safely invoked from response callbacks. */
    protected final Executor executor() {
      return executor;
    }

    // Care is taken that synchronization on this is unnecessary for all status page information.
    // Blocking sends are made beneath this stream object's lock which could block status page
    // rendering.
    public final void appendSummaryHtml(PrintWriter writer) {
      appendSpecificHtml(writer);
      if (errorCount.get() > 0) {
        writer.format(", %d errors", errorCount.get());
      }
      if (clientClosed.get()) {
        writer.write(", client closed");
      }
      long sleepLeft = sleepUntil.get() - Instant.now().getMillis();
      if (sleepLeft > 0) {
        writer.format(", %dms backoff remaining", sleepLeft);
      }
      writer.format(", current stream is %dms old", Instant.now().getMillis() - startTimeMs.get());
    }

    // Don't require synchronization on stream, see the appendSummaryHtml comment.
    protected abstract void appendSpecificHtml(PrintWriter writer);

    /** Receives server responses and drives reconnect/backoff when the stream terminates. */
    private class ResponseObserver implements StreamObserver<ResponseT> {
      @Override
      public void onNext(ResponseT response) {
        try {
          // Any response means the connection is healthy; reset the reconnect backoff.
          backoff.reset();
        } catch (IOException e) {
          // Ignore.
        }
        onResponse(response);
      }

      @Override
      public void onError(Throwable t) {
        onStreamFinished(t);
      }

      @Override
      public void onCompleted() {
        onStreamFinished(null);
      }

      // Handles both error and clean termination of the physical stream. Unless the client has
      // closed and nothing is pending, schedules a reconnect (after backoff on error).
      private void onStreamFinished(@Nullable Throwable t) {
        // NOTE(review): this synchronizes on the ResponseObserver instance, not on the enclosing
        // AbstractWindmillStream whose lock guards the stream state; confirm which monitor is
        // intended here (see BEAM-6562/BEAM-6564 referenced at the top of the file).
        synchronized (this) {
          if (clientClosed.get() && !hasPendingRequests()) {
            // Clean shutdown: deregister and release awaitTermination waiters.
            streamRegistry.remove(AbstractWindmillStream.this);
            finishLatch.countDown();
            return;
          }
        }
        if (t != null) {
          Status status = null;
          if (t instanceof StatusRuntimeException) {
            status = ((StatusRuntimeException) t).getStatus();
          }
          if (errorCount.incrementAndGet() % logEveryNStreamFailures == 0) {
            LOG.warn(
                "{} streaming Windmill RPC errors for a stream, last was: {} with status {}",
                errorCount.get(),
                t.toString(),
                status);
          }
          // If the stream was stopped due to a resource exhausted error then we are throttled.
          if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) {
            startThrottleTimer();
          }

          try {
            long sleep = backoff.nextBackOffMillis();
            sleepUntil.set(Instant.now().getMillis() + sleep);
            Thread.sleep(sleep);
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
          } catch (IOException e) {
            // Ignore.
          }
        }
        // Reconnect on the stream's executor; startStream must not run on the gRPC callback
        // thread.
        executor.execute(AbstractWindmillStream.this::startStream);
      }
    }

    /** Half-closes the stream from the client side. */
    @Override
    public final synchronized void close() {
      // Synchronization of close and onCompleted necessary for correct retry logic in onNewStream.
      clientClosed.set(true);
      requestObserver.onCompleted();
    }

    /** Waits up to the given duration for the stream to finish; returns whether it did. */
    @Override
    public final boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException {
      return finishLatch.await(time, unit);
    }

    /** Waits DEFAULT_STREAM_CLEAN_CLOSE_SECONDS for termination, then half-closes if needed. */
    @Override
    public final void closeAfterDefaultTimeout() throws InterruptedException {
      if (!finishLatch.await(DEFAULT_STREAM_CLEAN_CLOSE_SECONDS, TimeUnit.SECONDS)) {
        // If the stream did not close due to error in the specified amount of time, half-close
        // the stream cleanly.
        close();
      }
    }

    /** Returns the start time of the current physical stream (for status pages). */
    @Override
    public final Instant startTime() {
      return new Instant(startTimeMs.get());
    }
  }
| |
| private class GrpcGetWorkStream |
| extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk> |
| implements GetWorkStream { |
| private final GetWorkRequest request; |
| private final WorkItemReceiver receiver; |
| private final Map<Long, WorkItemBuffer> buffers = new ConcurrentHashMap<>(); |
| private final AtomicLong inflightMessages = new AtomicLong(); |
| private final AtomicLong inflightBytes = new AtomicLong(); |
| |
    /**
     * Opens a streaming GetWork RPC using the fully-populated {@code request}; complete work
     * items are delivered to {@code receiver} as their chunks finish arriving.
     */
    private GrpcGetWorkStream(GetWorkRequest request, WorkItemReceiver receiver) {
      super(
          responseObserver ->
              stub()
                  .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                  .getWorkStream(responseObserver));
      this.request = request;
      this.receiver = receiver;
      // NOTE(review): starting the stream here publishes "this" before construction completes;
      // confirm no response callback can reach subclass state prematurely.
      startStream();
    }
| |
    /** Resets per-stream flow-control state and re-sends the initial GetWork request. */
    @Override
    protected synchronized void onNewStream() {
      buffers.clear();
      // The inflight budget resets to the request maximums on each physical reconnect.
      inflightMessages.set(request.getMaxItems());
      inflightBytes.set(request.getMaxBytes());
      send(StreamingGetWorkRequest.newBuilder().setRequest(request).build());
    }
| |
    /** GetWork has no requests to retry on a break; reconnecting restarts the whole request. */
    @Override
    protected boolean hasPendingRequests() {
      return false;
    }
| |
| @Override |
| public void appendSpecificHtml(PrintWriter writer) { |
| // Number of buffers is same as distict workers that sent work on this stream. |
| writer.format( |
| "GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed", |
| buffers.size(), inflightMessages.intValue(), inflightBytes.intValue()); |
| } |
| |
    /**
     * Accumulates a response chunk into the per-stream-id buffer, dispatching the buffered work
     * item when its final chunk arrives, and tops up the inflight items/bytes budget with a
     * request extension when either budget drops below half of its maximum.
     */
    @Override
    protected void onResponse(StreamingGetWorkResponseChunk chunk) {
      // A good message means we are no longer throttled.
      getWorkThrottleTimer.stop();
      long id = chunk.getStreamId();

      WorkItemBuffer buffer = buffers.computeIfAbsent(id, (Long l) -> new WorkItemBuffer());
      buffer.append(chunk);

      // Zero remaining bytes marks the final chunk of this work item.
      if (chunk.getRemainingBytesForWorkItem() == 0) {
        long size = buffer.bufferedSize();
        buffer.runAndReset();

        // Record the fact that there are now fewer outstanding messages and bytes on the stream.
        long numInflight = inflightMessages.decrementAndGet();
        long bytesInflight = inflightBytes.addAndGet(-size);

        // If the outstanding items or bytes limit has gotten too low, top both off with a
        // GetWorkExtension. The goal is to keep the limits relatively close to their maximum
        // values without sending too many extension requests.
        if (numInflight < request.getMaxItems() / 2 || bytesInflight < request.getMaxBytes() / 2) {
          long moreItems = request.getMaxItems() - numInflight;
          long moreBytes = request.getMaxBytes() - bytesInflight;
          inflightMessages.getAndAdd(moreItems);
          inflightBytes.getAndAdd(moreBytes);
          final StreamingGetWorkRequest extension =
              StreamingGetWorkRequest.newBuilder()
                  .setRequestExtension(
                      StreamingGetWorkRequestExtension.newBuilder()
                          .setMaxItems(moreItems)
                          .setMaxBytes(moreBytes))
                  .build();
          // send() must not be called from onResponse; hand it to the stream's executor.
          executor()
              .execute(
                  () -> {
                    try {
                      send(extension);
                    } catch (IllegalStateException e) {
                      // Stream was closed.
                    }
                  });
        }
      }
    }
| |
| @Override |
| protected void startThrottleTimer() { |
| getWorkThrottleTimer.start(); |
| } |
| |
| private class WorkItemBuffer { |
| private String computation; |
| private Instant inputDataWatermark; |
| private Instant synchronizedProcessingTime; |
| private ByteString data = ByteString.EMPTY; |
| private long bufferedSize = 0; |
| |
| private void setMetadata(Windmill.ComputationWorkItemMetadata metadata) { |
| this.computation = metadata.getComputationId(); |
| this.inputDataWatermark = |
| WindmillTimeUtils.windmillToHarnessWatermark(metadata.getInputDataWatermark()); |
| this.synchronizedProcessingTime = |
| WindmillTimeUtils.windmillToHarnessWatermark( |
| metadata.getDependentRealtimeInputWatermark()); |
| } |
| |
| public void append(StreamingGetWorkResponseChunk chunk) { |
| if (chunk.hasComputationMetadata()) { |
| setMetadata(chunk.getComputationMetadata()); |
| } |
| |
| this.data = data.concat(chunk.getSerializedWorkItem()); |
| this.bufferedSize += chunk.getSerializedWorkItem().size(); |
| } |
| |
| public long bufferedSize() { |
| return bufferedSize; |
| } |
| |
| public void runAndReset() { |
| try { |
| receiver.receiveWork( |
| computation, |
| inputDataWatermark, |
| synchronizedProcessingTime, |
| Windmill.WorkItem.parseFrom(data.newInput())); |
| } catch (IOException e) { |
| LOG.error("Failed to parse work item from stream: ", e); |
| } |
| data = ByteString.EMPTY; |
| bufferedSize = 0; |
| } |
| } |
| } |
| |
  /**
   * gRPC client for the streaming GetData RPC.
   *
   * <p>Callers block in {@link #requestKeyedData} / {@link #requestGlobalData}. Requests are
   * coalesced into size- and count-bounded batches (see {@link #queueRequestAndWait}) that are
   * sent strictly one after another; response bytes for each request are streamed into a
   * per-request {@link AppendableInputStream}, keyed by request id in {@code pending}.
   */
  private class GrpcGetDataStream
      extends AbstractWindmillStream<StreamingGetDataRequest, StreamingGetDataResponse>
      implements GetDataStream {
    /** One queued request: exactly one of keyed data request or global data request. */
    private class QueuedRequest {
      public QueuedRequest(String computation, KeyedGetDataRequest request) {
        this.id = uniqueId();
        this.globalDataRequest = null;
        this.dataRequest =
            ComputationGetDataRequest.newBuilder()
                .setComputationId(computation)
                .addRequests(request)
                .build();
        this.byteSize = this.dataRequest.getSerializedSize();
      }

      public QueuedRequest(GlobalDataRequest request) {
        this.id = uniqueId();
        this.globalDataRequest = request;
        this.dataRequest = null;
        this.byteSize = this.globalDataRequest.getSerializedSize();
      }

      // Unique id correlating response chunks with this request.
      final long id;
      // Serialized size, used to enforce the per-batch byte limit.
      final long byteSize;
      // Exactly one of globalDataRequest/dataRequest is non-null.
      final GlobalDataRequest globalDataRequest;
      final ComputationGetDataRequest dataRequest;
      // Receives the streamed response bytes; recreated on each retry in issueRequest.
      AppendableInputStream responseStream = null;
    }

    /** A batch of requests that will be sent together in one StreamingGetDataRequest. */
    private class QueuedBatch {
      public QueuedBatch() {}

      final List<QueuedRequest> requests = new ArrayList<>();
      // Sum of byteSize over requests, bounded by GET_DATA_STREAM_CHUNK_SIZE.
      long byteSize = 0;
      // Once true, no further requests may be added to this batch.
      boolean finalized = false;
      // Counted down once the batch has been sent; waiters may then read their responses.
      final CountDownLatch sent = new CountDownLatch(1);
    };

    // FIFO of batches awaiting send; the head batch is the one currently being sent.
    private final Deque<QueuedBatch> batches = new ConcurrentLinkedDeque<>();
    // Requests that have been sent and are awaiting response chunks, keyed by request id.
    private final Map<Long, AppendableInputStream> pending = new ConcurrentHashMap<>();

    @Override
    public void appendSpecificHtml(PrintWriter writer) {
      writer.format(
          "GetDataStream: %d pending on-wire, %d queued batches", pending.size(), batches.size());
    }

    GrpcGetDataStream() {
      super(
          responseObserver ->
              stub()
                  .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                  .getDataStream(responseObserver));
      startStream();
    }

    /**
     * On a fresh physical stream, sends the header and cancels any requests that were in flight
     * on the broken stream so their callers retry (see the CancellationException handling in
     * {@link #issueRequest}).
     */
    @Override
    protected synchronized void onNewStream() {
      send(StreamingGetDataRequest.newBuilder().setHeader(makeHeader()).build());

      if (clientClosed.get()) {
        // We rely on close only occurring after all methods on the stream have returned.
        // Since the requestKeyedData and requestGlobalData methods are blocking this
        // means there should be no pending requests.
        Verify.verify(!hasPendingRequests());
      } else {
        for (AppendableInputStream responseStream : pending.values()) {
          responseStream.cancel();
        }
      }
    }

    @Override
    protected boolean hasPendingRequests() {
      return !pending.isEmpty() || !batches.isEmpty();
    }

    /**
     * Routes each response in the chunk to its request's stream. A chunk either carries whole
     * responses for several request ids, or one partial response (remainingBytes > 0) for a
     * single id; the argument checks below enforce that invariant.
     */
    @Override
    protected void onResponse(StreamingGetDataResponse chunk) {
      Preconditions.checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount());
      Preconditions.checkArgument(
          chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1);
      getDataThrottleTimer.stop();

      for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
        AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
        Verify.verify(responseStream != null, "No pending response stream");
        responseStream.append(chunk.getSerializedResponse(i).newInput());
        if (chunk.getRemainingBytesForResponse() == 0) {
          responseStream.complete();
        }
      }
    }

    @Override
    protected void startThrottleTimer() {
      getDataThrottleTimer.start();
    }

    /** Blocking fetch of keyed state data for {@code computation}. */
    @Override
    public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) {
      return issueRequest(new QueuedRequest(computation, request), KeyedGetDataResponse::parseFrom);
    }

    /** Blocking fetch of global (side input) data. */
    @Override
    public GlobalData requestGlobalData(GlobalDataRequest request) {
      return issueRequest(new QueuedRequest(request), GlobalData::parseFrom);
    }

    /**
     * Sends fire-and-forget heartbeat state requests for active work, chunked to respect the
     * stream's byte and batch-count limits. No responses are tracked for these.
     */
    @Override
    public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
      long builderBytes = 0;
      StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
      for (Map.Entry<String, List<KeyedGetDataRequest>> entry : active.entrySet()) {
        for (KeyedGetDataRequest request : entry.getValue()) {
          // Calculate the bytes with some overhead for proto encoding.
          long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
          if (builderBytes > 0
              && (builderBytes + bytes > GET_DATA_STREAM_CHUNK_SIZE
                  || builder.getRequestIdCount() >= streamingRpcBatchLimit)) {
            send(builder.build());
            builderBytes = 0;
            builder.clear();
          }
          builderBytes += bytes;
          builder.addStateRequest(
              ComputationGetDataRequest.newBuilder()
                  .setComputationId(entry.getKey())
                  .addRequests(request));
        }
      }
      if (builderBytes > 0) {
        send(builder.build());
      }
    }

    /**
     * Sends {@code request} (possibly batched with others) and blocks until its response parses.
     * Retries from scratch with a fresh response stream if the stream was cancelled (broken
     * connection) or the response failed to parse.
     */
    private <ResponseT> ResponseT issueRequest(QueuedRequest request, ParseFn<ResponseT> parseFn) {
      while (true) {
        request.responseStream = new AppendableInputStream();
        try {
          queueRequestAndWait(request);
          return parseFn.parse(request.responseStream);
        } catch (CancellationException e) {
          // Retry issuing the request since the response stream was cancelled.
          continue;
        } catch (IOException e) {
          LOG.error("Parsing GetData response failed: ", e);
          continue;
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          throw new RuntimeException(e);
        } finally {
          pending.remove(request.id);
        }
      }
    }

    /**
     * Adds {@code request} to the tail batch (starting a new one if the tail is full or
     * finalized) and blocks until that batch has been sent. The thread that starts a batch is
     * responsible for sending it, after waiting for the previous batch to complete so batches go
     * out strictly in order.
     */
    private void queueRequestAndWait(QueuedRequest request) throws InterruptedException {
      QueuedBatch batch;
      boolean responsibleForSend = false;
      CountDownLatch waitForSendLatch = null;
      synchronized (batches) {
        batch = batches.isEmpty() ? null : batches.getLast();
        if (batch == null
            || batch.finalized
            || batch.requests.size() >= streamingRpcBatchLimit
            || batch.byteSize + request.byteSize > GET_DATA_STREAM_CHUNK_SIZE) {
          if (batch != null) {
            waitForSendLatch = batch.sent;
          }
          batch = new QueuedBatch();
          batches.addLast(batch);
          responsibleForSend = true;
        }
        batch.requests.add(request);
        batch.byteSize += request.byteSize;
      }
      if (responsibleForSend) {
        if (waitForSendLatch == null) {
          // If there was not a previous batch wait a little while to improve
          // batching.
          Thread.sleep(1);
        } else {
          waitForSendLatch.await();
        }
        // Finalize the batch so that no additional requests will be added. Leave the batch in the
        // queue so that a subsequent batch will wait for its completion.
        synchronized (batches) {
          Verify.verify(batch == batches.peekFirst());
          batch.finalized = true;
        }
        sendBatch(batch.requests);
        synchronized (batches) {
          Verify.verify(batch == batches.pollFirst());
        }
        // Notify all waiters with requests in this batch as well as the sender
        // of the next batch (if one exists).
        batch.sent.countDown();
      } else {
        // Wait for this batch to be sent before parsing the response.
        batch.sent.await();
      }
    }

    /** Registers each request in {@code pending} and sends the combined proto on the wire. */
    private void sendBatch(List<QueuedRequest> requests) {
      StreamingGetDataRequest batchedRequest = flushToBatch(requests);
      synchronized (this) {
        // Synchronization of pending inserts is necessary with send to ensure duplicates are not
        // sent on stream reconnect.
        for (QueuedRequest request : requests) {
          Verify.verify(pending.put(request.id, request.responseStream) == null);
        }
        try {
          send(batchedRequest);
        } catch (IllegalStateException e) {
          // The stream broke before this call went through; onNewStream will retry the fetch.
        }
      }
    }

    /** Serializes a batch into a single StreamingGetDataRequest proto. */
    private StreamingGetDataRequest flushToBatch(List<QueuedRequest> requests) {
      // Put all global data requests first because there is only a single repeated field for
      // request ids and the initial ids correspond to global data requests if they are present.
      requests.sort(
          (QueuedRequest r1, QueuedRequest r2) -> {
            boolean r1gd = r1.globalDataRequest != null;
            boolean r2gd = r2.globalDataRequest != null;
            return r1gd == r2gd ? 0 : (r1gd ? -1 : 1);
          });
      StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
      for (QueuedRequest request : requests) {
        builder.addRequestId(request.id);
        if (request.globalDataRequest == null) {
          builder.addStateRequest(request.dataRequest);
        } else {
          builder.addGlobalDataRequest(request.globalDataRequest);
        }
      }
      return builder.build();
    }
  }
| |
  /**
   * gRPC client for the streaming CommitWork RPC.
   *
   * <p>Commits are accumulated in a {@link Batcher} via {@link #commitWorkItem} and sent when
   * {@link #flush} is called. Sent-but-unacknowledged commits are kept in {@code pending} keyed by
   * request id and are re-sent by {@link #onNewStream} if the stream reconnects; they are removed
   * when the server acknowledges them in {@link #onResponse}.
   */
  private class GrpcCommitWorkStream
      extends AbstractWindmillStream<StreamingCommitWorkRequest, StreamingCommitResponse>
      implements CommitWorkStream {
    /** A commit awaiting acknowledgement, plus the callback to invoke with its status. */
    private class PendingRequest {
      private final String computation;
      private final WorkItemCommitRequest request;
      private final Consumer<CommitStatus> onDone;

      PendingRequest(
          String computation, WorkItemCommitRequest request, Consumer<CommitStatus> onDone) {
        this.computation = computation;
        this.request = request;
        this.onDone = onDone;
      }

      // Approximate wire size, used to enforce the per-batch byte limit.
      long getBytes() {
        return (long) request.getSerializedSize() + computation.length();
      }
    }

    // Commits that have been sent but not yet acknowledged, keyed by request id.
    private final Map<Long, PendingRequest> pending = new ConcurrentHashMap<>();

    /** Accumulates commits until they exceed the batch count or byte limits. */
    private class Batcher {
      long queuedBytes = 0;
      Map<Long, PendingRequest> queue = new HashMap<>();

      // A batch always accepts at least one request, even one larger than the chunk size
      // (such a request is later split by issueMultiChunkRequest).
      boolean canAccept(PendingRequest request) {
        return queue.isEmpty()
            || (queue.size() < streamingRpcBatchLimit
                && (request.getBytes() + queuedBytes) < COMMIT_STREAM_CHUNK_SIZE);
      }

      void add(long id, PendingRequest request) {
        assert (canAccept(request));
        queuedBytes += request.getBytes();
        queue.put(id, request);
      }

      // flushInternal clears the queue map after sending.
      void flush() {
        flushInternal(queue);
        queuedBytes = 0;
      }
    }

    private final Batcher batcher = new Batcher();

    GrpcCommitWorkStream() {
      super(
          responseObserver ->
              stub()
                  .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                  .commitWorkStream(responseObserver));
      startStream();
    }

    @Override
    public void appendSpecificHtml(PrintWriter writer) {
      writer.format("CommitWorkStream: %d pending", pending.size());
    }

    /** Sends the header and re-sends every unacknowledged commit on the fresh stream. */
    @Override
    protected synchronized void onNewStream() {
      send(StreamingCommitWorkRequest.newBuilder().setHeader(makeHeader()).build());
      Batcher resendBatcher = new Batcher();
      for (Map.Entry<Long, PendingRequest> entry : pending.entrySet()) {
        if (!resendBatcher.canAccept(entry.getValue())) {
          resendBatcher.flush();
        }
        resendBatcher.add(entry.getKey(), entry.getValue());
      }
      resendBatcher.flush();
    }

    @Override
    protected boolean hasPendingRequests() {
      return !pending.isEmpty();
    }

    /** Completes each acknowledged commit with its status (OK if the server omitted one). */
    @Override
    protected void onResponse(StreamingCommitResponse response) {
      commitWorkThrottleTimer.stop();

      for (int i = 0; i < response.getRequestIdCount(); ++i) {
        long requestId = response.getRequestId(i);
        PendingRequest done = pending.remove(requestId);
        if (done == null) {
          LOG.error("Got unknown commit request ID: {}", requestId);
        } else {
          done.onDone.accept(
              (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK);
        }
      }
    }

    @Override
    protected void startThrottleTimer() {
      commitWorkThrottleTimer.start();
    }

    /**
     * Queues a commit for the next flush. Returns false (without queuing) if the current batch
     * is full; the caller should flush and retry.
     */
    @Override
    public boolean commitWorkItem(
        String computation, WorkItemCommitRequest commitRequest, Consumer<CommitStatus> onDone) {
      PendingRequest request = new PendingRequest(computation, commitRequest, onDone);
      if (!batcher.canAccept(request)) {
        return false;
      }
      batcher.add(uniqueId(), request);
      return true;
    }

    /** Sends all queued commits. */
    @Override
    public void flush() {
      batcher.flush();
    }

    /**
     * Sends the queued requests, choosing single, batched, or multi-chunk encoding, and clears
     * {@code requests}. Note: a lone request larger than COMMIT_STREAM_CHUNK_SIZE must be split
     * across chunks.
     */
    private final void flushInternal(Map<Long, PendingRequest> requests) {
      if (requests.isEmpty()) {
        return;
      }
      if (requests.size() == 1) {
        Map.Entry<Long, PendingRequest> elem = requests.entrySet().iterator().next();
        if (elem.getValue().request.getSerializedSize() > COMMIT_STREAM_CHUNK_SIZE) {
          issueMultiChunkRequest(elem.getKey(), elem.getValue());
        } else {
          issueSingleRequest(elem.getKey(), elem.getValue());
        }
      } else {
        issueBatchedRequest(requests);
      }
      requests.clear();
    }

    private void issueSingleRequest(final long id, PendingRequest pendingRequest) {
      StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
      requestBuilder
          .addCommitChunkBuilder()
          .setComputationId(pendingRequest.computation)
          .setRequestId(id)
          .setShardingKey(pendingRequest.request.getShardingKey())
          .setSerializedWorkItemCommit(pendingRequest.request.toByteString());
      StreamingCommitWorkRequest chunk = requestBuilder.build();
      try {
        // Insert into pending under the same lock as send so a reconnect cannot re-send a
        // request that is also sent here (see onNewStream).
        synchronized (this) {
          pending.put(id, pendingRequest);
          send(chunk);
        }
      } catch (IllegalStateException e) {
        // Stream was broken, request will be retried when stream is reopened.
      }
    }

    private void issueBatchedRequest(Map<Long, PendingRequest> requests) {
      StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
      String lastComputation = null;
      for (Map.Entry<Long, PendingRequest> entry : requests.entrySet()) {
        PendingRequest request = entry.getValue();
        StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder();
        // The computation id is only set when it changes; the server carries it forward for
        // subsequent chunks.
        if (lastComputation == null || !lastComputation.equals(request.computation)) {
          chunkBuilder.setComputationId(request.computation);
          lastComputation = request.computation;
        }
        chunkBuilder.setRequestId(entry.getKey());
        chunkBuilder.setShardingKey(request.request.getShardingKey());
        chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString());
      }
      StreamingCommitWorkRequest request = requestBuilder.build();
      try {
        synchronized (this) {
          pending.putAll(requests);
          send(request);
        }
      } catch (IllegalStateException e) {
        // Stream was broken, request will be retried when stream is reopened.
      }
    }

    /** Splits one oversized commit into COMMIT_STREAM_CHUNK_SIZE-sized chunks sent in order. */
    private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) {
      Preconditions.checkNotNull(pendingRequest.computation);
      final ByteString serializedCommit = pendingRequest.request.toByteString();

      synchronized (this) {
        pending.put(id, pendingRequest);
        for (int i = 0; i < serializedCommit.size(); i += COMMIT_STREAM_CHUNK_SIZE) {
          int end = i + COMMIT_STREAM_CHUNK_SIZE;
          ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size()));

          StreamingCommitRequestChunk.Builder chunkBuilder =
              StreamingCommitRequestChunk.newBuilder()
                  .setRequestId(id)
                  .setSerializedWorkItemCommit(chunk)
                  .setComputationId(pendingRequest.computation)
                  .setShardingKey(pendingRequest.request.getShardingKey());
          int remaining = serializedCommit.size() - end;
          if (remaining > 0) {
            chunkBuilder.setRemainingBytesForWorkItem(remaining);
          }

          StreamingCommitWorkRequest requestChunk =
              StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build();
          try {
            send(requestChunk);
          } catch (IllegalStateException e) {
            // Stream was broken, request will be retried when stream is reopened.
            break;
          }
        }
      }
    }
  }
| |
  /** Parses one complete response message from a stream of its serialized bytes. */
  @FunctionalInterface
  private interface ParseFn<ResponseT> {
    ResponseT parse(InputStream input) throws IOException;
  }
| |
  /**
   * An InputStream that can be dynamically extended with additional InputStreams.
   *
   * <p>Producers call {@link #append} as chunks arrive and {@link #complete} once all chunks have
   * been delivered, or {@link #cancel} to abort; a single consumer blocks in the read methods
   * until data or the poison pill arrives. Termination is signalled by enqueueing {@code
   * POISON_PILL}; whether it means "done" or "cancelled" is decided by the {@code complete} /
   * {@code cancelled} flags when the pill is taken. Read methods throw {@link
   * CancellationException} (unchecked) after {@link #cancel}.
   */
  @SuppressWarnings("JdkObsolete")
  private static class AppendableInputStream extends InputStream {
    private static final InputStream POISON_PILL = ByteString.EMPTY.newInput();
    private final AtomicBoolean cancelled = new AtomicBoolean(false);
    private final AtomicBoolean complete = new AtomicBoolean(false);
    // Bounded queue of pending chunks; append() blocks when 10 chunks are already buffered,
    // providing back-pressure on the producer.
    private final BlockingDeque<InputStream> queue = new LinkedBlockingDeque<>(10);
    // Lazily concatenates the queued chunks; the enumeration blocks in hasMoreElements()
    // waiting for the next chunk or the poison pill.
    private final InputStream stream =
        new SequenceInputStream(
            new Enumeration<InputStream>() {
              InputStream current = ByteString.EMPTY.newInput();

              @Override
              public boolean hasMoreElements() {
                if (current != null) {
                  return true;
                }

                try {
                  current = queue.take();
                  if (current != POISON_PILL) {
                    return true;
                  }
                  if (cancelled.get()) {
                    throw new CancellationException();
                  }
                  if (complete.get()) {
                    return false;
                  }
                  throw new IllegalStateException("Got poison pill but stream is not done.");
                } catch (InterruptedException e) {
                  Thread.currentThread().interrupt();
                  throw new CancellationException();
                }
              }

              @Override
              public InputStream nextElement() {
                if (!hasMoreElements()) {
                  throw new NoSuchElementException();
                }
                InputStream next = current;
                current = null;
                return next;
              }
            });

    /** Appends a new InputStream to the tail of this stream. May block for back-pressure. */
    public synchronized void append(InputStream chunk) {
      try {
        queue.put(chunk);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }

    /** Cancels the stream. Future calls to InputStream methods will throw CancellationException. */
    public synchronized void cancel() {
      // Set the flag before enqueueing the pill so the consumer observes it when the pill is
      // taken.
      cancelled.set(true);
      try {
        // Put the poison pill at the head of the queue to cancel as quickly as possible.
        queue.clear();
        queue.putFirst(POISON_PILL);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }

    /** Signals that no new InputStreams will be added to this stream. */
    public synchronized void complete() {
      complete.set(true);
      try {
        queue.put(POISON_PILL);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }

    @Override
    public int read() throws IOException {
      if (cancelled.get()) {
        throw new CancellationException();
      }
      return stream.read();
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
      if (cancelled.get()) {
        throw new CancellationException();
      }
      return stream.read(b, off, len);
    }

    @Override
    public int available() throws IOException {
      if (cancelled.get()) {
        throw new CancellationException();
      }
      return stream.available();
    }

    @Override
    public void close() throws IOException {
      stream.close();
    }
  }
| |
| /** |
| * A stopwatch used to track the amount of time spent throttled due to Resource Exhausted errors. |
| * Throttle time is cumulative for all three rpcs types but not for all streams. So if GetWork and |
| * CommitWork are both blocked for x, totalTime will be 2x. However, if 2 GetWork streams are both |
| * blocked for x totalTime will be x. All methods are thread safe. |
| */ |
| private static class ThrottleTimer { |
| |
| // This is -1 if not currently being throttled or the time in |
| // milliseconds when throttling for this type started. |
| private long startTime = -1; |
| // This is the collected total throttle times since the last poll. Throttle times are |
| // reported as a delta so this is cleared whenever it gets reported. |
| private long totalTime = 0; |
| |
| /** |
| * Starts the timer if it has not been started and does nothing if it has already been started. |
| */ |
| public synchronized void start() { |
| if (!throttled()) { // This timer is not started yet so start it now. |
| startTime = Instant.now().getMillis(); |
| } |
| } |
| |
| /** Stops the timer if it has been started and does nothing if it has not been started. */ |
| public synchronized void stop() { |
| if (throttled()) { // This timer has been started already so stop it now. |
| totalTime += Instant.now().getMillis() - startTime; |
| startTime = -1; |
| } |
| } |
| |
| /** Returns if the specified type is currently being throttled */ |
| public synchronized boolean throttled() { |
| return startTime != -1; |
| } |
| |
| /** Returns the combined total of all throttle times and resets those times to 0. */ |
| public synchronized long getAndResetThrottleTime() { |
| if (throttled()) { |
| stop(); |
| start(); |
| } |
| long toReturn = totalTime; |
| totalTime = 0; |
| return toReturn; |
| } |
| } |
| } |