/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker.windmill;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.SequenceInputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.WindmillTimeUtils;
import org.apache.beam.runners.dataflow.worker.options.StreamingDataflowWorkerOptions;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitStatus;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.CommitWorkResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetConfigResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GetWorkResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.JobHeader;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ReportStatsResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitRequestChunk;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingCommitWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetDataResponse;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequest;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkRequestExtension;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.StreamingGetWorkResponseChunk;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest;
import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.CallCredentials;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Channel;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.Status;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.StatusRuntimeException;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.auth.MoreCallCredentials;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.inprocess.InProcessChannelBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.netty.GrpcSslContexts;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.netty.NegotiationType;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.netty.NettyChannelBuilder;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Verify;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v20_0.com.google.common.net.HostAndPort;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** gRPC client for communicating with Windmill Service. */
// Very likely real potential for bugs - https://issues.apache.org/jira/browse/BEAM-6562
// Very likely real potential for bugs - https://issues.apache.org/jira/browse/BEAM-6564
@SuppressFBWarnings({"JLM_JSR166_UTILCONCURRENT_MONITORENTER", "IS2_INCONSISTENT_SYNC"})
public class GrpcWindmillServer extends WindmillServerStub {
  private static final Logger LOG = LoggerFactory.getLogger(GrpcWindmillServer.class);

  // Deadlines, in seconds, for unary and streaming RPCs respectively.
  // If a connection cannot be established, gRPC will fail fast so this deadline can be relatively
  // high.
  private static final long DEFAULT_UNARY_RPC_DEADLINE_SECONDS = 300;
  private static final long DEFAULT_STREAM_RPC_DEADLINE_SECONDS = 300;
  // How long closeAfterDefaultTimeout() waits for a stream to finish before half-closing it.
  // Stream clean close seconds must be set lower than the stream deadline seconds.
  private static final long DEFAULT_STREAM_CLEAN_CLOSE_SECONDS = 180;

  // Bounds for the exponential backoff built by grpcBackoff(); the upper bound can be overridden
  // per instance through maxBackoff.
  private static final Duration MIN_BACKOFF = Duration.millis(1);
  private static final Duration MAX_BACKOFF = Duration.standardSeconds(30);
  // Internal gRPC batch size is 64KB, so pick something slightly smaller to account for other
  // fields in the commit.
  private static final int COMMIT_STREAM_CHUNK_SIZE = 63 * 1024;
  private static final int GET_DATA_STREAM_CHUNK_SIZE = 63 * 1024;

  // Source of the process-wide ids handed out by uniqueId().
  private static final AtomicLong nextId = new AtomicLong(0);

  private final StreamingDataflowWorkerOptions options;
  private final int streamingRpcBatchLimit;
  // Async and blocking service stubs; stub() and syncStub() pick one entry from these lists.
  private final List<CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub> stubList =
      new ArrayList<>();
  private final List<CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub>
      syncStubList = new ArrayList<>();
  // Set only by initializeLocalHost() when streaming engine is disabled. When non-null, the unary
  // calls (getWork, getData, commitWork, getConfig, reportStats) go to this stub.
  private WindmillApplianceGrpc.WindmillApplianceBlockingStub syncApplianceStub = null;
  private long unaryDeadlineSeconds = DEFAULT_UNARY_RPC_DEADLINE_SECONDS;
  // Endpoints the current stubs were created for. Not assigned by the test-only constructor, so
  // it may be null on test instances.
  private ImmutableSet<HostAndPort> endpoints;
  // A stream logs a warning on every Nth error; initializeLocalHost() lowers this to 1.
  private int logEveryNStreamFailures = 20;
  private Duration maxBackoff = MAX_BACKOFF;
  // Per-RPC-kind throttle timers; getAndResetThrottleTime() reports their sum.
  private final ThrottleTimer getWorkThrottleTimer = new ThrottleTimer();
  private final ThrottleTimer getDataThrottleTimer = new ThrottleTimer();
  private final ThrottleTimer commitWorkThrottleTimer = new ThrottleTimer();
  // Used to pick a stub at random when more than one is available.
  Random rand = new Random();

  // Streams add themselves here in startStream() and are removed once they finish after a client
  // close; iterated by appendSummaryHtml() for the status page.
  private final Set<AbstractWindmillStream<?, ?>> streamRegistry =
      Collections.newSetFromMap(new ConcurrentHashMap<AbstractWindmillStream<?, ?>, Boolean>());

  /**
   * Creates a client configured from the given worker options.
   *
   * <p>If a Windmill service endpoint is configured, it is split on commas, each entry is parsed
   * as a host with an optional port (defaulting to the configured service port), and service
   * stubs are created for the resulting endpoints. Otherwise, if streaming engine is disabled and
   * a local Windmill hostport is configured, a client for that localhost port is created. If
   * neither applies, no stubs are created here.
   *
   * @param options worker options supplying endpoints, ports and job identity
   * @throws IOException if creating the service stubs fails
   */
  public GrpcWindmillServer(StreamingDataflowWorkerOptions options) throws IOException {
    this.options = options;
    this.streamingRpcBatchLimit = options.getWindmillServiceStreamingRpcBatchLimit();
    this.endpoints = ImmutableSet.of();

    final String serviceEndpoint = options.getWindmillServiceEndpoint();
    if (serviceEndpoint != null) {
      Set<HostAndPort> parsedEndpoints = new HashSet<>();
      for (String hostPort : Splitter.on(',').split(serviceEndpoint)) {
        HostAndPort parsed = HostAndPort.fromString(hostPort);
        parsedEndpoints.add(parsed.withDefaultPort(options.getWindmillServicePort()));
      }
      initializeWindmillService(parsedEndpoints);
      return;
    }

    if (streamingEngineEnabled() || options.getLocalWindmillHostport() == null) {
      // Nothing to connect to yet; endpoints may be supplied later.
      return;
    }

    // The local hostport is expected to look like "grpc:localhost:<port>".
    final String localHostport = options.getLocalWindmillHostport();
    final int separator = localHostport.lastIndexOf(':');
    final String prefix = localHostport.substring(0, separator);
    assert ("grpc:localhost".equals(prefix));
    final int port = Integer.parseInt(localHostport.substring(separator + 1));
    this.endpoints = ImmutableSet.<HostAndPort>of(HostAndPort.fromParts("localhost", port));
    initializeLocalHost(port);
  }

  private GrpcWindmillServer(String name, boolean enableStreamingEngine) {
    this.options = PipelineOptionsFactory.create().as(StreamingDataflowWorkerOptions.class);
    this.streamingRpcBatchLimit = Integer.MAX_VALUE;
    options.setProject("project");
    options.setJobId("job");
    options.setWorkerId("worker");
    if (enableStreamingEngine) {
      List<String> experiments = this.options.getExperiments();
      if (experiments == null) {
        experiments = new ArrayList<>();
      }
      experiments.add(GcpOptions.STREAMING_ENGINE_EXPERIMENT);
      options.setExperiments(experiments);
    }
    this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(inProcessChannel(name)));
  }

  /** Returns whether the worker options have streaming engine enabled. */
  private boolean streamingEngineEnabled() {
    return this.options.isEnableStreamingEngine();
  }

  /**
   * Replaces the set of Windmill service endpoints, recreating the stubs unless the new set equals
   * the current one.
   *
   * @param endpoints the new endpoints; must not be null
   * @throws IOException if creating the new stubs fails
   */
  @Override
  public synchronized void setWindmillServiceEndpoints(Set<HostAndPort> endpoints)
      throws IOException {
    Preconditions.checkNotNull(endpoints);
    final ImmutableSet<HostAndPort> previous = this.endpoints;
    if (!endpoints.equals(previous)) {
      LOG.info("Creating a new windmill stub, endpoints: {}", endpoints);
      if (previous != null) {
        LOG.info("Previous windmill stub endpoints: {}", previous);
      }
      initializeWindmillService(endpoints);
    }
    // Otherwise the endpoints are unchanged, so the existing stubs are kept.
  }

  /** Returns true once at least one async service stub has been created. */
  @Override
  public synchronized boolean isReady() {
    final boolean noStubs = stubList.isEmpty();
    return !noStubs;
  }

  /**
   * Configures this client to talk to a Windmill instance on localhost at the given port.
   *
   * <p>With streaming engine enabled, an async and a blocking service stub sharing one channel are
   * registered; otherwise the appliance blocking stub is set.
   */
  private synchronized void initializeLocalHost(int port) throws IOException {
    // Local settings: log every stream failure, cap backoff at 500ms, and use a short unary
    // deadline (local testing).
    this.logEveryNStreamFailures = 1;
    this.maxBackoff = Duration.millis(500);
    this.unaryDeadlineSeconds = 10;

    final Channel localChannel = localhostChannel(port);
    if (!streamingEngineEnabled()) {
      this.syncApplianceStub = WindmillApplianceGrpc.newBlockingStub(localChannel);
      return;
    }
    this.stubList.add(CloudWindmillServiceV1Alpha1Grpc.newStub(localChannel));
    this.syncStubList.add(CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(localChannel));
  }

  /**
   * Adapts a vendored gRPC {@code RequestMetadataCallback} to the non-vendored {@link
   * com.google.auth.RequestMetadataCallback} interface by forwarding every callback to the wrapped
   * vendored instance. Note that this class should implement every method of the interface and
   * call the delegate directly.
   *
   * <p>TODO: Replace this with an auto generated proxy which calls the underlying implementation
   * delegate to reduce maintenance burden.
   */
  private static class VendoredRequestMetadataCallbackAdapter
      implements com.google.auth.RequestMetadataCallback {
    /** The vendored callback that receives every forwarded notification. */
    private final org.apache.beam.vendor.grpc.v1p13p1.com.google.auth.RequestMetadataCallback
        delegate;

    private VendoredRequestMetadataCallbackAdapter(
        org.apache.beam.vendor.grpc.v1p13p1.com.google.auth.RequestMetadataCallback callback) {
      this.delegate = callback;
    }

    /** Forwards the successfully obtained request metadata to the vendored callback. */
    @Override
    public void onSuccess(Map<String, List<String>> metadata) {
      delegate.onSuccess(metadata);
    }

    /** Forwards the failure to the vendored callback. */
    @Override
    public void onFailure(Throwable exception) {
      delegate.onFailure(exception);
    }
  }

  /**
   * Exposes non-vendored {@link com.google.auth.Credentials} through the vendored gRPC {@code
   * Credentials} type by forwarding every call to the wrapped instance. Note that this class
   * should override every method that is not final and not static and call the delegate directly.
   *
   * <p>TODO: Replace this with an auto generated proxy which calls the underlying implementation
   * delegate to reduce maintenance burden.
   */
  private static class VendoredCredentialsAdapter
      extends org.apache.beam.vendor.grpc.v1p13p1.com.google.auth.Credentials {
    /** The non-vendored credentials that answer every call. */
    private final com.google.auth.Credentials delegate;

    private VendoredCredentialsAdapter(com.google.auth.Credentials credentials) {
      this.delegate = credentials;
    }

    @Override
    public String getAuthenticationType() {
      return delegate.getAuthenticationType();
    }

    @Override
    public Map<String, List<String>> getRequestMetadata() throws IOException {
      return delegate.getRequestMetadata();
    }

    /**
     * Asynchronous variant: the vendored callback is wrapped so the non-vendored credentials can
     * invoke it.
     */
    @Override
    public void getRequestMetadata(
        final URI uri,
        Executor executor,
        final org.apache.beam.vendor.grpc.v1p13p1.com.google.auth.RequestMetadataCallback
            callback) {
      final VendoredRequestMetadataCallbackAdapter adaptedCallback =
          new VendoredRequestMetadataCallbackAdapter(callback);
      delegate.getRequestMetadata(uri, executor, adaptedCallback);
    }

    @Override
    public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
      return delegate.getRequestMetadata(uri);
    }

    @Override
    public boolean hasRequestMetadata() {
      return delegate.hasRequestMetadata();
    }

    @Override
    public boolean hasRequestMetadataOnly() {
      return delegate.hasRequestMetadataOnly();
    }

    @Override
    public void refresh() throws IOException {
      delegate.refresh();
    }
  }

  /**
   * Discards the existing service stubs and creates new ones for the given endpoints.
   *
   * <p>Endpoints whose host is "localhost" are handled by {@link #initializeLocalHost(int)}. For
   * every other endpoint an async and a blocking stub are created over TLS channels and given call
   * credentials derived from the worker's GCP credential.
   *
   * @param endpoints the endpoints to connect to; copied into an immutable set
   * @throws IOException if a channel cannot be created
   */
  private synchronized void initializeWindmillService(Set<HostAndPort> endpoints)
      throws IOException {
    LOG.info("Initializing Streaming Engine GRPC client for endpoints: {}", endpoints);
    this.stubList.clear();
    this.syncStubList.clear();
    this.endpoints = ImmutableSet.<HostAndPort>copyOf(endpoints);
    for (HostAndPort endpoint : this.endpoints) {
      if ("localhost".equals(endpoint.getHostText())) {
        initializeLocalHost(endpoint.getPort());
        continue;
      }
      final CallCredentials callCredentials =
          MoreCallCredentials.from(new VendoredCredentialsAdapter(options.getGcpCredential()));
      // NOTE(review): a separate channel is built for the async and the blocking stub, whereas
      // initializeLocalHost() shares one channel between both. Confirm whether that is intended.
      final CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub asyncStub =
          CloudWindmillServiceV1Alpha1Grpc.newStub(remoteChannel(endpoint))
              .withCallCredentials(callCredentials);
      this.stubList.add(asyncStub);
      final CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub
          blockingStub =
              CloudWindmillServiceV1Alpha1Grpc.newBlockingStub(remoteChannel(endpoint))
                  .withCallCredentials(callCredentials);
      this.syncStubList.add(blockingStub);
    }
  }

  /**
   * Creates an instance for tests that connects to the in-process gRPC server with the given
   * name.
   */
  @VisibleForTesting
  static GrpcWindmillServer newTestInstance(String name, boolean enableStreamingEngine) {
    final GrpcWindmillServer testServer = new GrpcWindmillServer(name, enableStreamingEngine);
    return testServer;
  }

  /** Builds an in-process channel to the named server that runs callbacks on a direct executor. */
  private Channel inProcessChannel(String name) {
    final InProcessChannelBuilder builder = InProcessChannelBuilder.forName(name).directExecutor();
    return builder.build();
  }

  /**
   * Builds a plaintext Netty channel to localhost on the given port, with no practical limit on
   * inbound message size.
   */
  private Channel localhostChannel(int port) {
    final NettyChannelBuilder builder =
        NettyChannelBuilder.forAddress("localhost", port)
            .maxInboundMessageSize(Integer.MAX_VALUE)
            .negotiationType(NegotiationType.PLAINTEXT);
    return builder.build();
  }

  /**
   * Builds a TLS Netty channel to the given endpoint, with no practical limit on inbound message
   * size.
   *
   * @throws IOException if the client SSL context cannot be built
   */
  private Channel remoteChannel(HostAndPort endpoint) throws IOException {
    final NettyChannelBuilder builder =
        NettyChannelBuilder.forAddress(endpoint.getHostText(), endpoint.getPort())
            .maxInboundMessageSize(Integer.MAX_VALUE)
            .negotiationType(NegotiationType.TLS);
    // Set ciphers(null) to not use GCM, which is disabled for Dataflow
    // due to it being horribly slow.
    return builder.sslContext(GrpcSslContexts.forClient().ciphers(null).build()).build();
  }

  /**
   * Returns an async service stub: the only one if there is exactly one, otherwise one chosen at
   * random.
   *
   * @throws RuntimeException if no stub has been created yet
   */
  private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1Stub stub() {
    final int count = stubList.size();
    if (count == 0) {
      throw new RuntimeException("windmillServiceEndpoint has not been set");
    }
    // Avoid consuming a random number when there is nothing to choose between.
    final int index = (count == 1) ? 0 : rand.nextInt(count);
    return stubList.get(index);
  }

  /**
   * Returns a blocking service stub: the only one if there is exactly one, otherwise one chosen at
   * random.
   *
   * @throws RuntimeException if no blocking stub has been created yet
   */
  private synchronized CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub
      syncStub() {
    final int count = syncStubList.size();
    if (count == 0) {
      throw new RuntimeException("windmillServiceEndpoint has not been set");
    }
    // Avoid consuming a random number when there is nothing to choose between.
    final int index = (count == 1) ? 0 : rand.nextInt(count);
    return syncStubList.get(index);
  }

  /** Writes one HTML line per registered stream, under an "Active Streams" heading. */
  @Override
  public void appendSummaryHtml(PrintWriter writer) {
    writer.write("Active Streams:<br>");
    streamRegistry.forEach(
        registered -> {
          registered.appendSummaryHtml(writer);
          writer.write("<br>");
        });
  }

  /**
   * Creates a fresh backoff for retrying gRPC calls, starting at {@code MIN_BACKOFF} and capped
   * at the current {@code maxBackoff}. Intended to retry calls forever, with a maximum sane retry
   * interval.
   */
  private BackOff grpcBackoff() {
    final FluentBackoff policy =
        FluentBackoff.DEFAULT.withInitialBackoff(MIN_BACKOFF).withMaxBackoff(maxBackoff);
    return policy.backoff();
  }

  /**
   * Invokes the given call, retrying with backoff for as long as it fails with a {@link
   * StatusRuntimeException}.
   *
   * <p>Every 20th failure is logged. If the backoff is exhausted, or waiting for the next attempt
   * fails or is interrupted, an {@code RpcException} wrapping the last gRPC failure is thrown; on
   * interruption the thread's interrupt status is restored first.
   *
   * @param function the call to attempt
   * @return the first successful result
   */
  private <ResponseT> ResponseT callWithBackoff(Supplier<ResponseT> function) {
    final BackOff retryPolicy = grpcBackoff();
    int failures = 0;
    while (true) {
      try {
        return function.get();
      } catch (StatusRuntimeException rpcFailure) {
        failures++;
        if (failures % 20 == 0) {
          LOG.warn(
              "Many exceptions calling gRPC. Last exception: {} with status {}",
              rpcFailure,
              rpcFailure.getStatus());
        }
        final boolean retry;
        try {
          retry = BackOffUtils.next(Sleeper.DEFAULT, retryPolicy);
        } catch (IOException | InterruptedException waitFailure) {
          if (waitFailure instanceof InterruptedException) {
            Thread.currentThread().interrupt();
          }
          WindmillServerStub.RpcException wrapped =
              new WindmillServerStub.RpcException(rpcFailure);
          wrapped.addSuppressed(waitFailure);
          throw wrapped;
        }
        if (!retry) {
          throw new WindmillServerStub.RpcException(rpcFailure);
        }
      }
    }
  }

  /**
   * Issues a unary GetWork call with retries. With the appliance stub the request is sent as-is;
   * otherwise it goes to a service stub with the job, project and worker ids filled in from the
   * options.
   */
  @Override
  public GetWorkResponse getWork(GetWorkRequest request) {
    if (syncApplianceStub != null) {
      return callWithBackoff(
          () ->
              syncApplianceStub
                  .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                  .getWork(request));
    }
    return callWithBackoff(
        () -> {
          // A stub is picked on every attempt, so retries may go to a different endpoint.
          final CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub
              serviceStub = syncStub().withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS);
          return serviceStub.getWork(
              request
                  .toBuilder()
                  .setJobId(options.getJobId())
                  .setProjectId(options.getProject())
                  .setWorkerId(options.getWorkerId())
                  .build());
        });
  }

  /**
   * Issues a unary GetData call with retries. With the appliance stub the request is sent as-is;
   * otherwise it goes to a service stub with the job and project ids filled in from the options.
   */
  @Override
  public GetDataResponse getData(GetDataRequest request) {
    if (syncApplianceStub != null) {
      return callWithBackoff(
          () ->
              syncApplianceStub
                  .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                  .getData(request));
    }
    return callWithBackoff(
        () -> {
          // A stub is picked on every attempt, so retries may go to a different endpoint.
          final CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub
              serviceStub = syncStub().withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS);
          return serviceStub.getData(
              request
                  .toBuilder()
                  .setJobId(options.getJobId())
                  .setProjectId(options.getProject())
                  .build());
        });
  }

  /**
   * Issues a unary CommitWork call with retries. With the appliance stub the request is sent
   * as-is; otherwise it goes to a service stub with the job and project ids filled in from the
   * options.
   */
  @Override
  public CommitWorkResponse commitWork(CommitWorkRequest request) {
    if (syncApplianceStub != null) {
      return callWithBackoff(
          () ->
              syncApplianceStub
                  .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                  .commitWork(request));
    }
    return callWithBackoff(
        () -> {
          // A stub is picked on every attempt, so retries may go to a different endpoint.
          final CloudWindmillServiceV1Alpha1Grpc.CloudWindmillServiceV1Alpha1BlockingStub
              serviceStub = syncStub().withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS);
          return serviceStub.commitWork(
              request
                  .toBuilder()
                  .setJobId(options.getJobId())
                  .setProjectId(options.getProject())
                  .build());
        });
  }

  /**
   * Opens a streaming GetWork call. The request is copied with the job, project and worker ids
   * from the options filled in; work items are delivered to the given receiver.
   */
  @Override
  public GetWorkStream getWorkStream(GetWorkRequest request, WorkItemReceiver receiver) {
    final GetWorkRequest identifiedRequest =
        request
            .toBuilder()
            .setJobId(options.getJobId())
            .setProjectId(options.getProject())
            .setWorkerId(options.getWorkerId())
            .build();
    return new GrpcGetWorkStream(identifiedRequest, receiver);
  }

  /** Opens a new streaming GetData call. */
  @Override
  public GetDataStream getDataStream() {
    final GetDataStream stream = new GrpcGetDataStream();
    return stream;
  }

  /** Opens a new streaming CommitWork call. */
  @Override
  public CommitWorkStream commitWorkStream() {
    final CommitWorkStream stream = new GrpcCommitWorkStream();
    return stream;
  }

  /**
   * Issues a unary GetConfig call with retries against the appliance stub.
   *
   * @throws RpcException wrapping an {@link UnsupportedOperationException} when no appliance stub
   *     is configured
   */
  @Override
  public GetConfigResponse getConfig(GetConfigRequest request) {
    if (syncApplianceStub != null) {
      return callWithBackoff(
          () ->
              syncApplianceStub
                  .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                  .getConfig(request));
    }
    throw new RpcException(
        new UnsupportedOperationException("GetConfig not supported with windmill service."));
  }

  /**
   * Issues a unary ReportStats call with retries against the appliance stub.
   *
   * @throws RpcException wrapping an {@link UnsupportedOperationException} when no appliance stub
   *     is configured
   */
  @Override
  public ReportStatsResponse reportStats(ReportStatsRequest request) {
    if (syncApplianceStub != null) {
      return callWithBackoff(
          () ->
              syncApplianceStub
                  .withDeadlineAfter(unaryDeadlineSeconds, TimeUnit.SECONDS)
                  .reportStats(request));
    }
    throw new RpcException(
        new UnsupportedOperationException("ReportStats not supported with windmill service."));
  }

  /** Returns the combined throttle time of the three throttle timers and resets each of them. */
  @Override
  public long getAndResetThrottleTime() {
    final long getWorkThrottled = getWorkThrottleTimer.getAndResetThrottleTime();
    final long getDataThrottled = getDataThrottleTimer.getAndResetThrottleTime();
    final long commitWorkThrottled = commitWorkThrottleTimer.getAndResetThrottleTime();
    return getWorkThrottled + getDataThrottled + commitWorkThrottled;
  }

  /** Builds a job header carrying the job, project and worker ids from the options. */
  private JobHeader makeHeader() {
    final JobHeader.Builder header = JobHeader.newBuilder();
    header.setJobId(options.getJobId());
    header.setProjectId(options.getProject());
    header.setWorkerId(options.getWorkerId());
    return header.build();
  }

  /** Returns a long that is unique to this process, taken from a shared atomic counter. */
  private static long uniqueId() {
    final long id = nextId.addAndGet(1);
    return id;
  }

  /**
   * Base class for persistent streams connecting to Windmill.
   *
   * <p>This class handles the underlying gRPC StreamObservers, and automatically reconnects the
   * stream if it is broken. Subclasses are responsible for retrying requests that have been lost on
   * a broken stream.
   *
   * <p>Subclasses should override onResponse to handle responses from the server, and onNewStream
   * to perform any work that must be done when a new stream is created, such as sending headers or
   * retrying requests.
   *
   * <p>send and startStream should not be called from onResponse; use executor() instead.
   *
   * <p>Synchronization on this is used to synchronize the gRPC stream state and internal data
   * structures. Since gRPC channel operations may block, synchronization on this stream may also
   * block. This is generally not a problem since streams are used in a single-threaded manner.
   * However, some accessors used for status page and other debugging need to take care not to
   * require synchronizing on this.
   */
  private abstract class AbstractWindmillStream<RequestT, ResponseT> implements WindmillStream {
    // Creates the request observer for each new underlying stream from clientFactory.
    private final StreamObserverFactory streamObserverFactory = StreamObserverFactory.direct();
    // Opens the gRPC call: given the response observer, returns the request observer.
    private final Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory;
    // Single-threaded executor on which stream restarts are scheduled; subclasses reach it via
    // executor().
    private final Executor executor = Executors.newSingleThreadExecutor();

    // The following should be protected by synchronizing on this, except for
    // the atomics which may be read atomically for status pages.
    private StreamObserver<RequestT> requestObserver;
    // Wall-clock time, in millis, at which the current underlying stream was started.
    private final AtomicLong startTimeMs = new AtomicLong();
    // Number of times the underlying stream has finished with an error.
    private final AtomicInteger errorCount = new AtomicInteger();
    // Backoff applied between reconnects after stream errors; reset on every response.
    private final BackOff backoff = grpcBackoff();
    // Wall-clock time, in millis, until which the most recent backoff sleep lasts.
    private final AtomicLong sleepUntil = new AtomicLong();
    // Set once close() has been called.
    protected final AtomicBoolean clientClosed = new AtomicBoolean();
    // Released when the stream finishes after a client close with no pending requests.
    private final CountDownLatch finishLatch = new CountDownLatch(1);

    /**
     * Creates a stream that opens its underlying gRPC call through the given factory. The stream
     * is not started here; subclasses call {@code startStream()} once they are fully constructed.
     *
     * @param clientFactory given a response observer, starts the call and returns the request
     *     observer
     */
    protected AbstractWindmillStream(
        final Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory) {
      this.clientFactory = clientFactory;
    }

    /** Called on each response from the server, after the reconnect backoff has been reset. */
    protected abstract void onResponse(ResponseT response);
    /**
     * Called when a new underlying stream to the server has been opened. Invoked from startStream()
     * while holding this stream's lock.
     */
    protected abstract void onNewStream();
    /** Returns whether there are any pending requests that should be retried on a stream break. */
    protected abstract boolean hasPendingRequests();
    /**
     * Called when the stream is throttled due to resource exhausted errors. Will be called for each
     * resource exhausted error, not just the first. onResponse() must stop throttling on receipt of
     * the first good message.
     */
    protected abstract void startThrottleTimer();
    /** Sends a request to the server on the current underlying stream, under this stream's lock. */
    protected final synchronized void send(RequestT request) {
      final StreamObserver<RequestT> currentObserver = requestObserver;
      currentObserver.onNext(request);
    }

    /**
     * Starts the underlying stream, retrying with backoff until a stream has been created and
     * {@code onNewStream()} has completed without throwing.
     *
     * <p>If the client has already been closed by the time the new stream is up, the new stream is
     * closed again immediately.
     *
     * <p>Interruption does not abort the retry loop, but it is no longer lost: if the thread is
     * interrupted while backing off, the interrupt status is restored before this method returns
     * so that callers can observe it.
     */
    protected final void startStream() {
      // Add the stream to the registry after it has been fully constructed.
      streamRegistry.add(this);
      // A fresh backoff for this start attempt, independent of the per-stream error backoff.
      final BackOff startBackoff = grpcBackoff();
      boolean interrupted = false;
      try {
        while (true) {
          try {
            synchronized (this) {
              startTimeMs.set(Instant.now().getMillis());
              requestObserver = streamObserverFactory.from(clientFactory, new ResponseObserver());
              onNewStream();
              if (clientClosed.get()) {
                close();
              }
              return;
            }
          } catch (Exception e) {
            LOG.error("Failed to create new stream, retrying: ", e);
            try {
              long sleep = startBackoff.nextBackOffMillis();
              sleepUntil.set(Instant.now().getMillis() + sleep);
              Thread.sleep(sleep);
            } catch (InterruptedException i) {
              // Keep trying to create the stream. The flag is restored on exit rather than here,
              // because re-interrupting now would make every subsequent sleep return immediately.
              interrupted = true;
            } catch (IOException i) {
              // Ignore.
            }
          }
        }
      } finally {
        if (interrupted) {
          Thread.currentThread().interrupt();
        }
      }
    }

    /**
     * Returns this stream's executor. Subclasses use it for work that must not run inside
     * onResponse, such as send and startStream.
     */
    protected final Executor executor() {
      return this.executor;
    }

    /**
     * Writes a one-line status summary: the subclass-specific part, then the error count (if any),
     * whether the client closed the stream, any remaining backoff, and the current stream's age.
     *
     * <p>Care is taken that synchronization on this is unnecessary for all status page
     * information. Blocking sends are made beneath this stream object's lock, which could
     * otherwise block status page rendering.
     */
    public final void appendSummaryHtml(PrintWriter writer) {
      appendSpecificHtml(writer);
      if (errorCount.get() > 0) {
        writer.format(", %d errors", errorCount.get());
      }
      if (clientClosed.get()) {
        writer.write(", client closed");
      }
      final long backoffRemainingMs = sleepUntil.get() - Instant.now().getMillis();
      if (backoffRemainingMs > 0) {
        writer.format(", %dms backoff remaining", backoffRemainingMs);
      }
      final long streamAgeMs = Instant.now().getMillis() - startTimeMs.get();
      writer.format(", current stream is %dms old", streamAgeMs);
    }

    // Writes the subclass-specific part of the status summary; called first by
    // appendSummaryHtml. Implementations must not require synchronization on the stream, see the
    // appendSummaryHtml comment.
    protected abstract void appendSpecificHtml(PrintWriter writer);

    /**
     * Observer for responses on one underlying gRPC stream. startStream() creates a new instance
     * for every stream it opens.
     */
    private class ResponseObserver implements StreamObserver<ResponseT> {
      /** Resets the reconnect backoff and hands the response to the subclass. */
      @Override
      public void onNext(ResponseT response) {
        try {
          backoff.reset();
        } catch (IOException e) {
          // Ignore.
        }
        onResponse(response);
      }

      /** The stream broke with an error. */
      @Override
      public void onError(Throwable t) {
        onStreamFinished(t);
      }

      /** The server completed the stream without an error. */
      @Override
      public void onCompleted() {
        onStreamFinished(null);
      }

      /**
       * Handles the end of the underlying stream. If the client has closed the stream and nothing
       * is pending, the stream is unregistered and waiters are released. Otherwise a restart is
       * scheduled on the stream's executor, after a backoff sleep when the stream ended in error.
       *
       * @param t the error that ended the stream, or null if it completed normally
       */
      private void onStreamFinished(@Nullable Throwable t) {
        // NOTE(review): "this" here is the ResponseObserver, not the enclosing stream, so this
        // block does not exclude close(), send() or startStream(), which lock the stream object.
        // Confirm whether AbstractWindmillStream.this was intended.
        synchronized (this) {
          if (clientClosed.get() && !hasPendingRequests()) {
            streamRegistry.remove(AbstractWindmillStream.this);
            finishLatch.countDown();
            return;
          }
        }
        if (t != null) {
          Status status = null;
          if (t instanceof StatusRuntimeException) {
            status = ((StatusRuntimeException) t).getStatus();
          }
          // Only every Nth error is logged to avoid flooding the log.
          if (errorCount.incrementAndGet() % logEveryNStreamFailures == 0) {
            LOG.warn(
                "{} streaming Windmill RPC errors for a stream, last was: {} with status {}",
                errorCount.get(),
                t.toString(),
                status);
          }
          // If the stream was stopped due to a resource exhausted error then we are throttled.
          if (status != null && status.getCode() == Status.Code.RESOURCE_EXHAUSTED) {
            startThrottleTimer();
          }

          // Back off on the calling thread before scheduling the restart; sleepUntil is published
          // so the status page can show the remaining backoff.
          try {
            long sleep = backoff.nextBackOffMillis();
            sleepUntil.set(Instant.now().getMillis() + sleep);
            Thread.sleep(sleep);
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
          } catch (IOException e) {
            // Ignore.
          }
        }
        // Restart on the stream's executor rather than on the thread delivering this callback.
        executor.execute(AbstractWindmillStream.this::startStream);
      }
    }

    /**
     * Marks the stream as closed by the client and half-closes the current underlying stream.
     *
     * <p>Synchronization of close and onCompleted is necessary for correct retry logic in
     * onNewStream.
     */
    @Override
    public final synchronized void close() {
      // The flag must be set before completing the request side.
      clientClosed.set(true);
      final StreamObserver<RequestT> currentObserver = requestObserver;
      currentObserver.onCompleted();
    }

    /**
     * Waits up to the given time for the stream to finish.
     *
     * @return true if the stream finished within the timeout
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public final boolean awaitTermination(int time, TimeUnit unit) throws InterruptedException {
      final boolean finished = finishLatch.await(time, unit);
      return finished;
    }

    /**
     * Waits up to the default clean-close timeout for the stream to finish, and half-closes it
     * cleanly if it has not finished by then.
     *
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public final void closeAfterDefaultTimeout() throws InterruptedException {
      final boolean finished =
          finishLatch.await(DEFAULT_STREAM_CLEAN_CLOSE_SECONDS, TimeUnit.SECONDS);
      if (finished) {
        return;
      }
      // The stream did not close due to error in the specified amount of time, so half-close
      // the stream cleanly.
      close();
    }

    /** Returns the time at which the current underlying stream was started. */
    @Override
    public final Instant startTime() {
      final long startedAtMs = startTimeMs.get();
      return new Instant(startedAtMs);
    }
  }

  private class GrpcGetWorkStream
      extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk>
      implements GetWorkStream {
    // The initial request, re-sent whenever a new underlying stream is opened.
    private final GetWorkRequest request;
    private final WorkItemReceiver receiver;
    // Partially received work items, keyed by the stream id carried in each response chunk.
    private final Map<Long, WorkItemBuffer> buffers = new ConcurrentHashMap<>();
    // Remaining item and byte budgets for this stream; reset to the request's maximums on every
    // new stream and reduced as work items complete.
    private final AtomicLong inflightMessages = new AtomicLong();
    private final AtomicLong inflightBytes = new AtomicLong();

    /**
     * Creates and immediately starts a GetWork stream.
     *
     * @param request the request sent at the start of every underlying stream
     * @param receiver the receiver for work items
     */
    private GrpcGetWorkStream(GetWorkRequest request, WorkItemReceiver receiver) {
      super(
          responseObserver -> {
            // A stub is picked each time the underlying stream is (re)opened.
            return stub()
                .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                .getWorkStream(responseObserver);
          });
      this.request = request;
      this.receiver = receiver;
      startStream();
    }

    /**
     * Resets per-stream state for a freshly opened stream: drops any partially buffered work
     * items, restores the item and byte budgets to the request's maximums, and sends the initial
     * request.
     */
    @Override
    protected synchronized void onNewStream() {
      buffers.clear();
      inflightMessages.set(request.getMaxItems());
      inflightBytes.set(request.getMaxBytes());
      final StreamingGetWorkRequest initialRequest =
          StreamingGetWorkRequest.newBuilder().setRequest(request).build();
      send(initialRequest);
    }

    /** Always false: this stream reports no pending requests to retry on a stream break. */
    @Override
    protected boolean hasPendingRequests() {
      final boolean pending = false;
      return pending;
    }

    @Override
    public void appendSpecificHtml(PrintWriter writer) {
      // Number of buffers is same as distict workers that sent work on this stream.
      writer.format(
          "GetWorkStream: %d buffers, %d inflight messages allowed, %d inflight bytes allowed",
          buffers.size(), inflightMessages.intValue(), inflightBytes.intValue());
    }

    @Override
    protected void onResponse(StreamingGetWorkResponseChunk chunk) {
      getWorkThrottleTimer.stop();
      long id = chunk.getStreamId();

      WorkItemBuffer buffer = buffers.computeIfAbsent(id, (Long l) -> new WorkItemBuffer());
      buffer.append(chunk);

      if (chunk.getRemainingBytesForWorkItem() == 0) {
        long size = buffer.bufferedSize();
        buffer.runAndReset();

        // Record the fact that there are now fewer outstanding messages and bytes on the stream.
        long numInflight = inflightMessages.decrementAndGet();
        long bytesInflight = inflightBytes.addAndGet(-size);

        // If the outstanding items or bytes limit has gotten too low, top both off with a
        // GetWorkExtension.  The goal is to keep the limits relatively close to their maximum
        // values without sending too many extension requests.
        if (numInflight < request.getMaxItems() / 2 || bytesInflight < request.getMaxBytes() / 2) {
          long moreItems = request.getMaxItems() - numInflight;
          long moreBytes = request.getMaxBytes() - bytesInflight;
          inflightMessages.getAndAdd(moreItems);
          inflightBytes.getAndAdd(moreBytes);
          final StreamingGetWorkRequest extension =
              StreamingGetWorkRequest.newBuilder()
                  .setRequestExtension(
                      StreamingGetWorkRequestExtension.newBuilder()
                          .setMaxItems(moreItems)
                          .setMaxBytes(moreBytes))
                  .build();
          executor()
              .execute(
                  () -> {
                    try {
                      send(extension);
                    } catch (IllegalStateException e) {
                      // Stream was closed.
                    }
                  });
        }
      }
    }

    @Override
    protected void startThrottleTimer() {
      getWorkThrottleTimer.start();
    }

    private class WorkItemBuffer {
      private String computation;
      private Instant inputDataWatermark;
      private Instant synchronizedProcessingTime;
      private ByteString data = ByteString.EMPTY;
      private long bufferedSize = 0;

      private void setMetadata(Windmill.ComputationWorkItemMetadata metadata) {
        this.computation = metadata.getComputationId();
        this.inputDataWatermark =
            WindmillTimeUtils.windmillToHarnessWatermark(metadata.getInputDataWatermark());
        this.synchronizedProcessingTime =
            WindmillTimeUtils.windmillToHarnessWatermark(
                metadata.getDependentRealtimeInputWatermark());
      }

      public void append(StreamingGetWorkResponseChunk chunk) {
        if (chunk.hasComputationMetadata()) {
          setMetadata(chunk.getComputationMetadata());
        }

        this.data = data.concat(chunk.getSerializedWorkItem());
        this.bufferedSize += chunk.getSerializedWorkItem().size();
      }

      public long bufferedSize() {
        return bufferedSize;
      }

      public void runAndReset() {
        try {
          receiver.receiveWork(
              computation,
              inputDataWatermark,
              synchronizedProcessingTime,
              Windmill.WorkItem.parseFrom(data.newInput()));
        } catch (IOException e) {
          LOG.error("Failed to parse work item from stream: ", e);
        }
        data = ByteString.EMPTY;
        bufferedSize = 0;
      }
    }
  }

  /**
   * GetData stream: batches keyed and global data requests from concurrent callers onto a single
   * stream and routes each response back to the blocked caller through a per-request {@link
   * AppendableInputStream}.
   */
  private class GrpcGetDataStream
      extends AbstractWindmillStream<StreamingGetDataRequest, StreamingGetDataResponse>
      implements GetDataStream {
    /** One keyed or global data request, together with the stream its response is written to. */
    private class QueuedRequest {
      /** Wraps a keyed state request for {@code computation}. */
      public QueuedRequest(String computation, KeyedGetDataRequest request) {
        this.id = uniqueId();
        this.globalDataRequest = null;
        this.dataRequest =
            ComputationGetDataRequest.newBuilder()
                .setComputationId(computation)
                .addRequests(request)
                .build();
        this.byteSize = this.dataRequest.getSerializedSize();
      }

      /** Wraps a global data request. */
      public QueuedRequest(GlobalDataRequest request) {
        this.id = uniqueId();
        this.globalDataRequest = request;
        this.dataRequest = null;
        this.byteSize = this.globalDataRequest.getSerializedSize();
      }

      final long id;
      final long byteSize;
      // Exactly one of globalDataRequest and dataRequest is non-null, depending on the
      // constructor used.
      final GlobalDataRequest globalDataRequest;
      final ComputationGetDataRequest dataRequest;
      // Replaced with a fresh stream by issueRequest before every attempt.
      AppendableInputStream responseStream = null;
    }

    /** A group of requests that is sent as a single StreamingGetDataRequest. */
    private class QueuedBatch {
      public QueuedBatch() {}

      final List<QueuedRequest> requests = new ArrayList<>();
      long byteSize = 0;
      // Set, under the batches lock, just before the batch is sent; a finalized batch accepts no
      // further requests.
      boolean finalized = false;
      // Released once the batch has been passed to sendBatch and removed from the queue.
      final CountDownLatch sent = new CountDownLatch(1);
    }

    // Batches that have been created but not yet removed after sending, oldest first. Also used
    // as the lock guarding batch membership.
    private final Deque<QueuedBatch> batches = new ConcurrentLinkedDeque<>();
    // Response streams of requests that have been passed to send, keyed by request id.
    private final Map<Long, AppendableInputStream> pending = new ConcurrentHashMap<>();

    /** Writes the number of pending requests and queued batches to {@code writer}. */
    @Override
    public void appendSpecificHtml(PrintWriter writer) {
      writer.format(
          "GetDataStream: %d pending on-wire, %d queued batches", pending.size(), batches.size());
    }

    /** Creates the stream and starts it immediately. */
    GrpcGetDataStream() {
      super(
          responseObserver ->
              stub()
                  .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                  .getDataStream(responseObserver));
      startStream();
    }

    /**
     * Sends the stream header and, unless the client already closed the stream, cancels every
     * pending response stream so that the blocked callers retry their requests.
     */
    @Override
    protected synchronized void onNewStream() {
      send(StreamingGetDataRequest.newBuilder().setHeader(makeHeader()).build());

      if (clientClosed.get()) {
        // We rely on close only occurring after all methods on the stream have returned.
        // Since the requestKeyedData and requestGlobalData methods are blocking this
        // means there should be no pending requests.
        Verify.verify(!hasPendingRequests());
      } else {
        for (AppendableInputStream responseStream : pending.values()) {
          responseStream.cancel();
        }
      }
    }

    /** Returns whether any request is pending on the wire or still queued in a batch. */
    @Override
    protected boolean hasPendingRequests() {
      return !pending.isEmpty() || !batches.isEmpty();
    }

    /**
     * Appends each serialized response in {@code chunk} to the pending stream of the matching
     * request id, completing that stream when the chunk reports no remaining bytes.
     *
     * <p>A chunk must carry as many serialized responses as request ids, and a chunk with
     * remaining bytes must refer to exactly one request.
     */
    @Override
    protected void onResponse(StreamingGetDataResponse chunk) {
      Preconditions.checkArgument(chunk.getRequestIdCount() == chunk.getSerializedResponseCount());
      Preconditions.checkArgument(
          chunk.getRemainingBytesForResponse() == 0 || chunk.getRequestIdCount() == 1);
      getDataThrottleTimer.stop();

      for (int i = 0; i < chunk.getRequestIdCount(); ++i) {
        AppendableInputStream responseStream = pending.get(chunk.getRequestId(i));
        Verify.verify(responseStream != null, "No pending response stream");
        responseStream.append(chunk.getSerializedResponse(i).newInput());
        if (chunk.getRemainingBytesForResponse() == 0) {
          responseStream.complete();
        }
      }
    }

    /** Starts the throttle timer dedicated to GetData. */
    @Override
    protected void startThrottleTimer() {
      getDataThrottleTimer.start();
    }

    /** Issues a keyed data request and blocks until its response has been parsed. */
    @Override
    public KeyedGetDataResponse requestKeyedData(String computation, KeyedGetDataRequest request) {
      return issueRequest(new QueuedRequest(computation, request), KeyedGetDataResponse::parseFrom);
    }

    /** Issues a global data request and blocks until its response has been parsed. */
    @Override
    public GlobalData requestGlobalData(GlobalDataRequest request) {
      return issueRequest(new QueuedRequest(request), GlobalData::parseFrom);
    }

    /**
     * Sends the given active-work requests, grouped by computation id, split into messages that
     * stay within {@code GET_DATA_STREAM_CHUNK_SIZE} estimated bytes and {@code
     * streamingRpcBatchLimit} state requests. No request ids are attached and nothing is added to
     * {@code pending}.
     */
    @Override
    public void refreshActiveWork(Map<String, List<KeyedGetDataRequest>> active) {
      long builderBytes = 0;
      StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
      for (Map.Entry<String, List<KeyedGetDataRequest>> entry : active.entrySet()) {
        for (KeyedGetDataRequest request : entry.getValue()) {
          // Calculate the bytes with some overhead for proto encoding.
          long bytes = (long) entry.getKey().length() + request.getSerializedSize() + 10;
          // The batch limit is checked against the state requests accumulated so far: this
          // builder never receives request ids, so counting those would never trigger a flush.
          if (builderBytes > 0
              && (builderBytes + bytes > GET_DATA_STREAM_CHUNK_SIZE
                  || builder.getStateRequestCount() >= streamingRpcBatchLimit)) {
            send(builder.build());
            builderBytes = 0;
            builder.clear();
          }
          builderBytes += bytes;
          builder.addStateRequest(
              ComputationGetDataRequest.newBuilder()
                  .setComputationId(entry.getKey())
                  .addRequests(request));
        }
      }
      if (builderBytes > 0) {
        send(builder.build());
      }
    }

    /**
     * Queues {@code request}, waits for it to be sent, and parses the response. The request is
     * re-issued with a fresh response stream whenever the response stream is cancelled or parsing
     * fails with an {@link IOException}.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
     *     interrupted while waiting for the batch to be sent; the interrupt flag is restored first
     */
    private <ResponseT> ResponseT issueRequest(QueuedRequest request, ParseFn<ResponseT> parseFn) {
      while (true) {
        request.responseStream = new AppendableInputStream();
        try {
          queueRequestAndWait(request);
          return parseFn.parse(request.responseStream);
        } catch (CancellationException e) {
          // Retry issuing the request since the response stream was cancelled.
          continue;
        } catch (IOException e) {
          LOG.error("Parsing GetData response failed: ", e);
          continue;
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          throw new RuntimeException(e);
        } finally {
          // Runs on every attempt, so a retried request is re-registered by the next sendBatch.
          pending.remove(request.id);
        }
      }
    }

    /**
     * Adds {@code request} to the newest open batch, or opens a new batch if there is none that
     * can take it, and returns once that batch has been sent. The thread that opens a batch is
     * the one that sends it; all other threads in the batch wait on the batch's latch.
     */
    private void queueRequestAndWait(QueuedRequest request) throws InterruptedException {
      QueuedBatch batch;
      boolean responsibleForSend = false;
      CountDownLatch waitForSendLatch = null;
      synchronized (batches) {
        batch = batches.isEmpty() ? null : batches.getLast();
        if (batch == null
            || batch.finalized
            || batch.requests.size() >= streamingRpcBatchLimit
            || batch.byteSize + request.byteSize > GET_DATA_STREAM_CHUNK_SIZE) {
          if (batch != null) {
            // The new batch must not be sent before the batch queued ahead of it.
            waitForSendLatch = batch.sent;
          }
          batch = new QueuedBatch();
          batches.addLast(batch);
          responsibleForSend = true;
        }
        batch.requests.add(request);
        batch.byteSize += request.byteSize;
      }
      if (responsibleForSend) {
        if (waitForSendLatch == null) {
          // If there was not a previous batch wait a little while to improve
          // batching.
          Thread.sleep(1);
        } else {
          waitForSendLatch.await();
        }
        // Finalize the batch so that no additional requests will be added.  Leave the batch in the
        // queue so that a subsequent batch will wait for its completion.
        synchronized (batches) {
          Verify.verify(batch == batches.peekFirst());
          batch.finalized = true;
        }
        sendBatch(batch.requests);
        synchronized (batches) {
          Verify.verify(batch == batches.pollFirst());
        }
        // Notify all waiters with requests in this batch as well as the sender
        // of the next batch (if one exists).
        batch.sent.countDown();
      } else {
        // Wait for this batch to be sent before parsing the response.
        batch.sent.await();
      }
    }

    /**
     * Registers the response stream of every request in {@code pending} and sends the requests as
     * one message, all while holding this stream's monitor.
     */
    private void sendBatch(List<QueuedRequest> requests) {
      StreamingGetDataRequest batchedRequest = flushToBatch(requests);
      synchronized (this) {
        // Synchronization of pending inserts is necessary with send to ensure duplicates are not
        // sent on stream reconnect.
        for (QueuedRequest request : requests) {
          Verify.verify(pending.put(request.id, request.responseStream) == null);
        }
        try {
          send(batchedRequest);
        } catch (IllegalStateException e) {
          // The stream broke before this call went through; onNewStream will retry the fetch.
        }
      }
    }

    /**
     * Builds a single request message from {@code requests}. The list is sorted in place so that
     * global data requests come first.
     */
    private StreamingGetDataRequest flushToBatch(List<QueuedRequest> requests) {
      // Put all global data requests first because there is only a single repeated field for
      // request ids and the initial ids correspond to global data requests if they are present.
      requests.sort(
          (QueuedRequest r1, QueuedRequest r2) -> {
            boolean r1gd = r1.globalDataRequest != null;
            boolean r2gd = r2.globalDataRequest != null;
            return r1gd == r2gd ? 0 : (r1gd ? -1 : 1);
          });
      StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder();
      for (QueuedRequest request : requests) {
        builder.addRequestId(request.id);
        if (request.globalDataRequest == null) {
          builder.addStateRequest(request.dataRequest);
        } else {
          builder.addGlobalDataRequest(request.globalDataRequest);
        }
      }
      return builder.build();
    }
  }

  /**
   * CommitWork stream: queues work item commits in a {@link Batcher}, sends them on flush as a
   * single request, a batched request, or a sequence of chunks, and invokes each commit's
   * callback when the matching response arrives. Commits that are still unacknowledged are sent
   * again whenever a new physical stream starts.
   */
  private class GrpcCommitWorkStream
      extends AbstractWindmillStream<StreamingCommitWorkRequest, StreamingCommitResponse>
      implements CommitWorkStream {
    /** A commit request together with its computation id and completion callback. */
    private class PendingRequest {
      private final String computation;
      private final WorkItemCommitRequest request;
      private final Consumer<CommitStatus> onDone;

      PendingRequest(
          String computation, WorkItemCommitRequest request, Consumer<CommitStatus> onDone) {
        this.computation = computation;
        this.request = request;
        this.onDone = onDone;
      }

      /** Size used for batching: the serialized request size plus the computation id length. */
      long getBytes() {
        return (long) request.getSerializedSize() + computation.length();
      }
    }

    // Commits that have been passed to send and not yet acknowledged, keyed by request id.
    // Entries are added just before sending, removed in onResponse, and re-sent by onNewStream.
    private final Map<Long, PendingRequest> pending = new ConcurrentHashMap<>();

    /**
     * Accumulates requests until they are flushed. Uses a plain HashMap and unsynchronized
     * counters; NOTE(review): callers presumably confine a batcher to one thread at a time -
     * confirm.
     */
    private class Batcher {
      long queuedBytes = 0;
      Map<Long, PendingRequest> queue = new HashMap<>();

      /**
       * Returns whether {@code request} may be added: always when the queue is empty, otherwise
       * only while the queue holds fewer than {@code streamingRpcBatchLimit} requests and the
       * combined size stays below {@code COMMIT_STREAM_CHUNK_SIZE}.
       */
      boolean canAccept(PendingRequest request) {
        return queue.isEmpty()
            || (queue.size() < streamingRpcBatchLimit
                && (request.getBytes() + queuedBytes) < COMMIT_STREAM_CHUNK_SIZE);
      }

      /** Queues {@code request} under {@code id}; the caller must have checked canAccept. */
      void add(long id, PendingRequest request) {
        assert (canAccept(request));
        queuedBytes += request.getBytes();
        queue.put(id, request);
      }

      /** Sends everything queued (flushInternal also empties the queue) and resets the size. */
      void flush() {
        flushInternal(queue);
        queuedBytes = 0;
      }
    }

    private final Batcher batcher = new Batcher();

    /** Creates the stream and starts it immediately. */
    GrpcCommitWorkStream() {
      super(
          responseObserver ->
              stub()
                  .withDeadlineAfter(DEFAULT_STREAM_RPC_DEADLINE_SECONDS, TimeUnit.SECONDS)
                  .commitWorkStream(responseObserver));
      startStream();
    }

    /** Writes the number of unacknowledged commits to {@code writer}. */
    @Override
    public void appendSpecificHtml(PrintWriter writer) {
      writer.format("CommitWorkStream: %d pending", pending.size());
    }

    /**
     * Sends the stream header and then re-sends every commit still in {@code pending}, regrouped
     * through a temporary batcher so the resent messages respect the same batching limits.
     */
    @Override
    protected synchronized void onNewStream() {
      send(StreamingCommitWorkRequest.newBuilder().setHeader(makeHeader()).build());
      Batcher resendBatcher = new Batcher();
      for (Map.Entry<Long, PendingRequest> entry : pending.entrySet()) {
        if (!resendBatcher.canAccept(entry.getValue())) {
          resendBatcher.flush();
        }
        resendBatcher.add(entry.getKey(), entry.getValue());
      }
      resendBatcher.flush();
    }

    /** Returns whether any commit is still unacknowledged. */
    @Override
    protected boolean hasPendingRequests() {
      return !pending.isEmpty();
    }

    /**
     * Removes each acknowledged request from {@code pending} and invokes its callback. Unknown
     * request ids are logged and skipped.
     */
    @Override
    protected void onResponse(StreamingCommitResponse response) {
      commitWorkThrottleTimer.stop();

      for (int i = 0; i < response.getRequestIdCount(); ++i) {
        long requestId = response.getRequestId(i);
        PendingRequest done = pending.remove(requestId);
        if (done == null) {
          LOG.error("Got unknown commit request ID: {}", requestId);
        } else {
          // A response may carry fewer statuses than request ids; a missing status counts as OK.
          done.onDone.accept(
              (i < response.getStatusCount()) ? response.getStatus(i) : CommitStatus.OK);
        }
      }
    }

    /** Starts the throttle timer dedicated to CommitWork. */
    @Override
    protected void startThrottleTimer() {
      commitWorkThrottleTimer.start();
    }

    /**
     * Queues a commit in the batcher under a fresh request id. Nothing is sent until {@link
     * #flush()} is called.
     *
     * @return {@code false}, without queuing anything, if the current batch cannot accept the
     *     request; {@code true} otherwise
     */
    @Override
    public boolean commitWorkItem(
        String computation, WorkItemCommitRequest commitRequest, Consumer<CommitStatus> onDone) {
      PendingRequest request = new PendingRequest(computation, commitRequest, onDone);
      if (!batcher.canAccept(request)) {
        return false;
      }
      batcher.add(uniqueId(), request);
      return true;
    }

    /** Sends all commits queued so far. */
    @Override
    public void flush() {
      batcher.flush();
    }

    /**
     * Sends {@code requests} and then empties the map. A lone request goes out either as one
     * message or, when its serialized size exceeds {@code COMMIT_STREAM_CHUNK_SIZE}, as a series
     * of chunks; several requests are sent together as one batched message.
     */
    private final void flushInternal(Map<Long, PendingRequest> requests) {
      if (requests.isEmpty()) {
        return;
      }
      if (requests.size() == 1) {
        Map.Entry<Long, PendingRequest> elem = requests.entrySet().iterator().next();
        if (elem.getValue().request.getSerializedSize() > COMMIT_STREAM_CHUNK_SIZE) {
          issueMultiChunkRequest(elem.getKey(), elem.getValue());
        } else {
          issueSingleRequest(elem.getKey(), elem.getValue());
        }
      } else {
        issueBatchedRequest(requests);
      }
      requests.clear();
    }

    /** Registers the request in {@code pending} and sends it as a single-chunk message. */
    private void issueSingleRequest(final long id, PendingRequest pendingRequest) {
      StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
      requestBuilder
          .addCommitChunkBuilder()
          .setComputationId(pendingRequest.computation)
          .setRequestId(id)
          .setShardingKey(pendingRequest.request.getShardingKey())
          .setSerializedWorkItemCommit(pendingRequest.request.toByteString());
      StreamingCommitWorkRequest chunk = requestBuilder.build();
      try {
        // The pending insert and the send happen under the same monitor that onNewStream holds
        // while it re-sends pending requests.
        synchronized (this) {
          pending.put(id, pendingRequest);
          send(chunk);
        }
      } catch (IllegalStateException e) {
        // Stream was broken, request will be retried when stream is reopened.
      }
    }

    /** Registers all requests in {@code pending} and sends them in one message. */
    private void issueBatchedRequest(Map<Long, PendingRequest> requests) {
      StreamingCommitWorkRequest.Builder requestBuilder = StreamingCommitWorkRequest.newBuilder();
      String lastComputation = null;
      for (Map.Entry<Long, PendingRequest> entry : requests.entrySet()) {
        PendingRequest request = entry.getValue();
        StreamingCommitRequestChunk.Builder chunkBuilder = requestBuilder.addCommitChunkBuilder();
        // The computation id is written only when it differs from the previous chunk's;
        // presumably the receiver carries the last seen id forward - confirm against the service.
        if (lastComputation == null || !lastComputation.equals(request.computation)) {
          chunkBuilder.setComputationId(request.computation);
          lastComputation = request.computation;
        }
        chunkBuilder.setRequestId(entry.getKey());
        chunkBuilder.setShardingKey(request.request.getShardingKey());
        chunkBuilder.setSerializedWorkItemCommit(request.request.toByteString());
      }
      StreamingCommitWorkRequest request = requestBuilder.build();
      try {
        synchronized (this) {
          pending.putAll(requests);
          send(request);
        }
      } catch (IllegalStateException e) {
        // Stream was broken, request will be retried when stream is reopened.
      }
    }

    /**
     * Registers the request in {@code pending} and sends its serialized form as consecutive
     * slices of at most {@code COMMIT_STREAM_CHUNK_SIZE} bytes, each in its own message. All
     * slices are sent while holding this stream's monitor.
     */
    private void issueMultiChunkRequest(final long id, PendingRequest pendingRequest) {
      Preconditions.checkNotNull(pendingRequest.computation);
      final ByteString serializedCommit = pendingRequest.request.toByteString();

      synchronized (this) {
        pending.put(id, pendingRequest);
        for (int i = 0; i < serializedCommit.size(); i += COMMIT_STREAM_CHUNK_SIZE) {
          int end = i + COMMIT_STREAM_CHUNK_SIZE;
          ByteString chunk = serializedCommit.substring(i, Math.min(end, serializedCommit.size()));

          StreamingCommitRequestChunk.Builder chunkBuilder =
              StreamingCommitRequestChunk.newBuilder()
                  .setRequestId(id)
                  .setSerializedWorkItemCommit(chunk)
                  .setComputationId(pendingRequest.computation)
                  .setShardingKey(pendingRequest.request.getShardingKey());
          // Bytes still to come after this slice; the field is left unset on the final slice.
          int remaining = serializedCommit.size() - end;
          if (remaining > 0) {
            chunkBuilder.setRemainingBytesForWorkItem(remaining);
          }

          StreamingCommitWorkRequest requestChunk =
              StreamingCommitWorkRequest.newBuilder().addCommitChunk(chunkBuilder).build();
          try {
            send(requestChunk);
          } catch (IllegalStateException e) {
            // Stream was broken, request will be retried when stream is reopened.
            break;
          }
        }
      }
    }
  }

  /**
   * Parses a response of type {@code ResponseT} from an input stream. Unlike {@code
   * java.util.function.Function}, the parse method may throw {@link IOException}.
   */
  @FunctionalInterface
  private interface ParseFn<ResponseT> {
    /**
     * Reads and parses one response from {@code input}.
     *
     * @throws IOException if the response cannot be read or parsed
     */
    ResponseT parse(InputStream input) throws IOException;
  }

  /**
   * An InputStream whose content is supplied incrementally: chunks are appended by one party
   * while another reads them back as a single continuous stream. Reads block until a chunk is
   * available, the stream is completed, or the stream is cancelled.
   */
  @SuppressWarnings("JdkObsolete")
  private static class AppendableInputStream extends InputStream {
    // Sentinel queued by cancel() and complete(); recognised by identity, never by content.
    private static final InputStream POISON_PILL = ByteString.EMPTY.newInput();
    private final AtomicBoolean cancelled = new AtomicBoolean(false);
    private final AtomicBoolean complete = new AtomicBoolean(false);
    // Bounded, so producers block once ten chunks are waiting to be read.
    private final BlockingDeque<InputStream> queue = new LinkedBlockingDeque<>(10);
    private final InputStream stream = new SequenceInputStream(new ChunkEnumeration());

    /** Feeds the queued chunks to the SequenceInputStream, blocking while the queue is empty. */
    private final class ChunkEnumeration implements Enumeration<InputStream> {
      // Starts out as an empty stream rather than null, so the first element is available without
      // touching the queue (presumably because SequenceInputStream asks for it eagerly).
      private InputStream current = ByteString.EMPTY.newInput();

      @Override
      public boolean hasMoreElements() {
        if (current == null) {
          try {
            current = queue.take();
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new CancellationException();
          }
          if (current == POISON_PILL) {
            // The sentinel means either cancellation or normal completion.
            if (cancelled.get()) {
              throw new CancellationException();
            }
            if (!complete.get()) {
              throw new IllegalStateException("Got poison pill but stream is not done.");
            }
            return false;
          }
        }
        return true;
      }

      @Override
      public InputStream nextElement() {
        if (!hasMoreElements()) {
          throw new NoSuchElementException();
        }
        final InputStream handedOut = current;
        current = null;
        return handedOut;
      }
    }

    /** Queues {@code element} at the tail, restoring the interrupt flag if the wait is cut short. */
    private void enqueue(InputStream element) {
      try {
        queue.put(element);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }

    /** Throws CancellationException if {@link #cancel()} has been called. */
    private void throwIfCancelled() {
      if (cancelled.get()) {
        throw new CancellationException();
      }
    }

    /** Appends a new InputStream to the tail of this stream. */
    public synchronized void append(InputStream chunk) {
      enqueue(chunk);
    }

    /** Cancels the stream. Future calls to InputStream methods will throw CancellationException. */
    public synchronized void cancel() {
      cancelled.set(true);
      // Drop whatever is queued and put the poison pill at the head of the queue to cancel as
      // quickly as possible.
      queue.clear();
      try {
        queue.putFirst(POISON_PILL);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
    }

    /** Signals that no new InputStreams will be added to this stream. */
    public synchronized void complete() {
      complete.set(true);
      enqueue(POISON_PILL);
    }

    @Override
    public int read() throws IOException {
      throwIfCancelled();
      return stream.read();
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
      throwIfCancelled();
      return stream.read(b, off, len);
    }

    @Override
    public int available() throws IOException {
      throwIfCancelled();
      return stream.available();
    }

    @Override
    public void close() throws IOException {
      stream.close();
    }
  }

  /**
   * A stopwatch used to track the amount of time spent throttled due to Resource Exhausted errors.
   * Throttle time is cumulative for all three rpcs types but not for all streams. So if GetWork and
   * CommitWork are both blocked for x, totalTime will be 2x. However, if 2 GetWork streams are both
   * blocked for x totalTime will be x. All methods are thread safe.
   */
  private static class ThrottleTimer {

    // This is -1 if not currently being throttled or the time in
    // milliseconds when throttling for this type started.
    private long startTime = -1;
    // This is the collected total throttle times since the last poll.  Throttle times are
    // reported as a delta so this is cleared whenever it gets reported.
    private long totalTime = 0;

    /**
     * Starts the timer if it has not been started and does nothing if it has already been started.
     */
    public synchronized void start() {
      if (!throttled()) { // This timer is not started yet so start it now.
        startTime = Instant.now().getMillis();
      }
    }

    /** Stops the timer if it has been started and does nothing if it has not been started. */
    public synchronized void stop() {
      if (throttled()) { // This timer has been started already so stop it now.
        totalTime += Instant.now().getMillis() - startTime;
        startTime = -1;
      }
    }

    /** Returns whether this timer is currently running, i.e. throttling is in progress. */
    public synchronized boolean throttled() {
      return startTime != -1;
    }

    /**
     * Returns the throttle time accumulated since the previous call and resets the total to 0. If
     * the timer is running, the interval in progress is counted up to now and the timer keeps
     * running from that same instant.
     */
    public synchronized long getAndResetThrottleTime() {
      if (throttled()) {
        // Use a single clock reading both to close the current interval and to begin the next
        // one, so no time is dropped between stopping and restarting the timer.
        long now = Instant.now().getMillis();
        totalTime += now - startTime;
        startTime = now;
      }
      long toReturn = totalTime;
      totalTime = 0;
      return toReturn;
    }
  }
}
