Merge pull request #10580: [BEAM-9116] Limit the number of past invocations stored in JobServer

diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java
index 80fbd65..abda968 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobService.java
@@ -21,7 +21,7 @@
 import java.util.List;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ConcurrentLinkedDeque;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import org.apache.beam.model.jobmanagement.v1.JobApi;
@@ -57,6 +57,7 @@
 import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusException;
 import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.StatusRuntimeException;
 import org.apache.beam.vendor.grpc.v1p26p0.io.grpc.stub.StreamObserver;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -72,12 +73,16 @@
 public class InMemoryJobService extends JobServiceGrpc.JobServiceImplBase implements FnService {
   private static final Logger LOG = LoggerFactory.getLogger(InMemoryJobService.class);
 
+  /** The default maximum number of completed invocations to keep. */
+  public static final int DEFAULT_MAX_INVOCATION_HISTORY = 10;
+
   /**
    * Creates an InMemoryJobService.
    *
    * @param stagingServiceDescriptor Endpoint for the staging service.
    * @param stagingServiceTokenProvider Function mapping a preparationId to a staging service token.
-   * @param invoker A JobInvoker that will actually create the jobs.
+   * @param cleanupJobFn A cleanup function to run, parameterized with the staging token of a job.
+   * @param invoker A JobInvoker which creates the jobs.
    * @return A new InMemoryJobService.
    */
   public static InMemoryJobService create(
@@ -86,22 +91,60 @@
       ThrowingConsumer<Exception, String> cleanupJobFn,
       JobInvoker invoker) {
     return new InMemoryJobService(
-        stagingServiceDescriptor, stagingServiceTokenProvider, cleanupJobFn, invoker);
+        stagingServiceDescriptor,
+        stagingServiceTokenProvider,
+        cleanupJobFn,
+        invoker,
+        DEFAULT_MAX_INVOCATION_HISTORY);
   }
 
-  private final ConcurrentMap<String, JobPreparation> preparations;
-  private final ConcurrentMap<String, JobInvocation> invocations;
-  private final ConcurrentMap<String, String> stagingSessionTokens;
+  /**
+   * Creates an InMemoryJobService which retains a bounded number of completed invocations.
+   *
+   * @param stagingServiceDescriptor The endpoint for the staging service.
+   * @param stagingServiceTokenProvider Function mapping a preparationId to a staging service token.
+   * @param cleanupJobFn Cleanup function called with a job's staging token; may be null.
+   * @param invoker A JobInvoker which creates the jobs.
+   * @param maxInvocationHistory Maximum number of completed invocations to keep; must be >= 0.
+   * @return A new InMemoryJobService.
+   */
+  public static InMemoryJobService create(
+      Endpoints.ApiServiceDescriptor stagingServiceDescriptor,
+      Function<String, String> stagingServiceTokenProvider,
+      ThrowingConsumer<Exception, String> cleanupJobFn,
+      JobInvoker invoker,
+      int maxInvocationHistory) {
+    return new InMemoryJobService(
+        stagingServiceDescriptor,
+        stagingServiceTokenProvider,
+        cleanupJobFn,
+        invoker,
+        maxInvocationHistory);
+  }
+
+  /** Map of preparationId to preparation. */
+  private final ConcurrentHashMap<String, JobPreparation> preparations;
+  /** Map of preparationId to staging token. */
+  private final ConcurrentHashMap<String, String> stagingSessionTokens;
+  /** Map of invocationId to invocation. */
+  private final ConcurrentHashMap<String, JobInvocation> invocations;
+  /** InvocationIds of completed invocations in least-recently-completed order. */
+  private final ConcurrentLinkedDeque<String> completedInvocationsIds;
+
   private final Endpoints.ApiServiceDescriptor stagingServiceDescriptor;
   private final Function<String, String> stagingServiceTokenProvider;
   private final ThrowingConsumer<Exception, String> cleanupJobFn;
   private final JobInvoker invoker;
 
+  /** The maximum number of past invocations to keep. */
+  private final int maxInvocationHistory;
+
   private InMemoryJobService(
       Endpoints.ApiServiceDescriptor stagingServiceDescriptor,
       Function<String, String> stagingServiceTokenProvider,
       ThrowingConsumer<Exception, String> cleanupJobFn,
-      JobInvoker invoker) {
+      JobInvoker invoker,
+      int maxInvocationHistory) {
     this.stagingServiceDescriptor = stagingServiceDescriptor;
     this.stagingServiceTokenProvider = stagingServiceTokenProvider;
     this.cleanupJobFn = cleanupJobFn;
@@ -109,9 +152,10 @@
 
     this.preparations = new ConcurrentHashMap<>();
     this.invocations = new ConcurrentHashMap<>();
-
-    // Map "preparation ID" to staging token
     this.stagingSessionTokens = new ConcurrentHashMap<>();
+    this.completedInvocationsIds = new ConcurrentLinkedDeque<>();
+    Preconditions.checkArgument(maxInvocationHistory >= 0);
+    this.maxInvocationHistory = maxInvocationHistory;
   }
 
   @Override
@@ -196,20 +240,25 @@
             }
             String stagingSessionToken = stagingSessionTokens.get(preparationId);
             stagingSessionTokens.remove(preparationId);
-            if (cleanupJobFn != null) {
-              try {
+            try {
+              if (cleanupJobFn != null) {
                 cleanupJobFn.accept(stagingSessionToken);
-              } catch (Exception e) {
-                LOG.warn(
-                    "Failed to remove job staging directory for token {}: {}",
-                    stagingSessionToken,
-                    e);
               }
+            } catch (Exception e) {
+              LOG.warn(
+                  "Failed to remove job staging directory for token {}: {}",
+                  stagingSessionToken,
+                  e);
+            } finally {
+              onFinishedInvocationCleanup(invocationId);
             }
           });
 
       invocation.start();
       invocations.put(invocationId, invocation);
+      // Cleanup this preparation because we are running it now.
+      // If we fail, we need to prepare again.
+      preparations.remove(preparationId);
       RunJobResponse response = RunJobResponse.newBuilder().setJobId(invocationId).build();
       responseObserver.onNext(response);
       responseObserver.onCompleted();
@@ -426,4 +475,27 @@
     }
     return invocation;
   }
+
+  /**
+   * Records a finished invocation and evicts the oldest completed invocations once more than
+   * {@code maxInvocationHistory} of them are retained.
+   *
+   * <p>Only "invocations" is trimmed here: "preparations" is cleaned up when a job is run and
+   * "stagingSessionTokens" is cleaned up when the invocation finishes.
+   *
+   * @param invocationId The id of the invocation which has just finished.
+   */
+  private void onFinishedInvocationCleanup(String invocationId) {
+    completedInvocationsIds.addLast(invocationId);
+    while (completedInvocationsIds.size() > maxInvocationHistory) {
+      // If several invocations finish at the same time, another thread can drain the deque
+      // between the size check and the removal. pollFirst() returns null in that case, whereas
+      // removeFirst() would throw NoSuchElementException.
+      String evictedId = completedInvocationsIds.pollFirst();
+      if (evictedId == null) {
+        break;
+      }
+      invocations.remove(evictedId);
+    }
+  }
 }
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/JobServerDriver.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/JobServerDriver.java
index d0061d2..d61162f 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/JobServerDriver.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/jobsubmission/JobServerDriver.java
@@ -63,7 +63,8 @@
             artifactStagingServer.getService().removeArtifacts(stagingSessionToken);
           }
         },
-        invoker);
+        invoker,
+        configuration.getMaxInvocationHistory());
   }
 
   /** Configuration for the jobServer. */
@@ -97,6 +98,9 @@
         handler = ExplicitBooleanOptionHandler.class)
     private boolean cleanArtifactsPerJob = true;
 
+    @Option(name = "--history-size", usage = "The maximum number of completed jobs to keep.")
+    private int maxInvocationHistory = InMemoryJobService.DEFAULT_MAX_INVOCATION_HISTORY;
+
     public String getHost() {
       return host;
     }
@@ -120,6 +124,10 @@
     public boolean isCleanArtifactsPerJob() {
       return cleanArtifactsPerJob;
     }
+
+    public int getMaxInvocationHistory() {
+      return maxInvocationHistory;
+    }
   }
 
   protected static ServerFactory createJobServerFactory(ServerConfiguration configuration) {
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobServiceTest.java
index a0a2bef..2a9b366 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobServiceTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/jobsubmission/InMemoryJobServiceTest.java
@@ -17,18 +17,21 @@
  */
 package org.apache.beam.runners.fnexecution.jobsubmission;
 
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.collection.IsCollectionWithSize.hasSize;
 import static org.hamcrest.core.Is.is;
 import static org.hamcrest.core.Is.isA;
 import static org.hamcrest.core.IsNull.notNullValue;
-import static org.junit.Assert.assertThat;
 import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.UUID;
+import java.util.function.Consumer;
 import org.apache.beam.model.jobmanagement.v1.JobApi;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -39,12 +42,14 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 /** Tests for {@link InMemoryJobService}. */
 @RunWith(JUnit4.class)
 public class InMemoryJobServiceTest {
+
   private static final String TEST_JOB_NAME = "test-job";
   private static final String TEST_JOB_ID = "test-job-id";
   private static final String TEST_RETRIEVAL_TOKEN = "test-staging-token";
@@ -57,6 +62,8 @@
           .setPipelineOptions(TEST_OPTIONS)
           .build();
 
+  private final int maxInvocationHistory = 3;
+
   Endpoints.ApiServiceDescriptor stagingServiceDescriptor;
   @Mock JobInvoker invoker;
   @Mock JobInvocation invocation;
@@ -186,7 +193,79 @@
     verify(invocation, times(1)).start();
   }
 
+  @Test
+  public void testInvocationCleanup() {
+    final int maxInvocationHistory = 3; // NOTE(review): shadows the field of the same name.
+    service =
+        InMemoryJobService.create(
+            stagingServiceDescriptor, session -> "token", null, invoker, maxInvocationHistory);
+
+    assertThat(getNumberOfInvocations(), is(0));
+
+    Job job1 = runJob();
+    assertThat(getNumberOfInvocations(), is(1));
+    Job job2 = runJob();
+    assertThat(getNumberOfInvocations(), is(2));
+    Job job3 = runJob();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory));
+
+    // Running invocations must stay available and are never discarded,
+    // even when there are more of them than the max history size.
+    Job job4 = runJob();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory + 1));
+
+    // Cleanup only triggers once more than maxInvocationHistory jobs have completed.
+    job1.finish();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory + 1));
+    job2.finish();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory + 1));
+    job3.finish();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory + 1));
+
+    // The fourth completed job exceeds maxInvocationHistory, so one invocation is evicted.
+    job4.finish();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory));
+
+    // A job started after the cleanup is listed while running; finishing it evicts one again.
+    Job job5 = runJob();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory + 1));
+    job5.finish();
+    assertThat(getNumberOfInvocations(), is(maxInvocationHistory));
+  }
+
+  private Job runJob() {
+    when(invocation.getId()).thenReturn(UUID.randomUUID().toString());
+    prepareAndRunJob();
+    ArgumentCaptor<Consumer<JobApi.JobStateEvent>> listenerCaptor =
+        ArgumentCaptor.forClass(Consumer.class);
+    verify(invocation, atLeastOnce()).addStateListener(listenerCaptor.capture());
+    // After several runs, getValue() returns the most recently captured state listener.
+    return new Job(listenerCaptor.getValue());
+  }
+
+  private int getNumberOfInvocations() {
+    // Ask the service for its jobs via getJobs and count the entries in the single response.
+    RecordingObserver<JobApi.GetJobsResponse> jobsObserver = new RecordingObserver<>();
+    service.getJobs(JobApi.GetJobsRequest.newBuilder().build(), jobsObserver);
+    return jobsObserver.getValue().getJobInfoCount();
+  }
+
+  /** Handle for a started job; {@link #finish()} sends a DONE event to its state listener. */
+  private static class Job {
+    private final Consumer<JobApi.JobStateEvent> stateListener;
+
+    private Job(Consumer<JobApi.JobStateEvent> stateListener) {
+      this.stateListener = stateListener;
+    }
+
+    void finish() {
+      stateListener.accept(
+          JobApi.JobStateEvent.newBuilder().setState(JobApi.JobState.Enum.DONE).build());
+    }
+  }
+
   private static class RecordingObserver<T> implements StreamObserver<T> {
+
     ArrayList<T> values = new ArrayList<>();
     Throwable error = null;
     boolean isCompleted = false;
@@ -206,6 +285,11 @@
       isCompleted = true;
     }
 
+    T getValue() {
+      assertThat(values, hasSize(1)); // Hamcrest check: enforced even when JVM asserts are off.
+      return values.get(0);
+    }
+
     boolean isSuccessful() {
       return isCompleted && error == null;
     }
diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaJobServerDriver.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaJobServerDriver.java
index 5788ad5..b820687 100644
--- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaJobServerDriver.java
+++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaJobServerDriver.java
@@ -101,7 +101,8 @@
           }
         },
         stagingSessionToken -> {},
-        jobInvoker);
+        jobInvoker,
+        InMemoryJobService.DEFAULT_MAX_INVOCATION_HISTORY);
   }
 
   public void run() throws Exception {