[BEAM-12634] JmsIO auto scaling feature (#15464)

* [BEAM-12634]: Provide generic Autoscaler interface,
- Provide DefaultAutoscaler implementation
- Add new autoscaler field to JmsIO

* [BEAM-12634]: Add Unit Tests for DefaultAutoScaler and mock custom Autoscaler

* [BEAM-12634] => Apply suggestions from code review

- Delete getSplitBacklogBytes method
- Add new comments

Co-authored-by: Lukasz Cwik <lcwik@google.com>

* [BEAM-12634]: Fix tests and some JavaDoc comments consecutively to code review

* [BEAM-12634]: Add comment on Autoscaler#getTotalBacklogBytes

* Update sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java

Co-authored-by: Lukasz Cwik <lcwik@google.com>
diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle
index 69cd88f..3f9d0c7 100644
--- a/sdks/java/io/jms/build.gradle
+++ b/sdks/java/io/jms/build.gradle
@@ -36,6 +36,7 @@
   testCompile library.java.activemq_kahadb_store
   testCompile library.java.activemq_client
   testCompile library.java.junit
+  testCompile library.java.mockito_core
   testRuntimeOnly library.java.slf4j_jdk14
   testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
 }
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java
new file mode 100644
index 0000000..0e023d1
--- /dev/null
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jms;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.io.UnboundedSource;
+
+/**
+ * Enables users to specify their own `JMS` backlog reporters enabling {@link JmsIO} to report
+ * {@link UnboundedSource.UnboundedReader#getTotalBacklogBytes()}.
+ */
+public interface AutoScaler extends Serializable {
+
+  /** The {@link AutoScaler} is started when the {@link JmsIO.UnboundedJmsReader} is started. */
+  void start();
+
+  /**
+   * Returns the size of the backlog of unread data in the underlying data source represented by all
+   * splits of this source.
+   */
+  long getTotalBacklogBytes();
+
+  /** The {@link AutoScaler} is stopped when the {@link JmsIO.UnboundedJmsReader} is closed. */
+  void stop();
+}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java
new file mode 100644
index 0000000..2b05cf6
--- /dev/null
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jms;
+
+import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
+
+/**
+ * Default implementation of {@link AutoScaler}. Returns {@link
+ * org.apache.beam.sdk.io.UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} as the default value.
+ */
+public class DefaultAutoscaler implements AutoScaler {
+  @Override
+  public void start() {}
+
+  @Override
+  public long getTotalBacklogBytes() {
+    return BACKLOG_UNKNOWN;
+  }
+
+  @Override
+  public void stop() {}
+}
diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
index 4999e10..9fa4492 100644
--- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
+++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java
@@ -196,6 +196,8 @@
 
     abstract @Nullable Coder<T> getCoder();
 
+    abstract @Nullable AutoScaler getAutoScaler();
+
     abstract Builder<T> builder();
 
     @AutoValue.Builder
@@ -218,6 +220,8 @@
 
       abstract Builder<T> setCoder(Coder<T> coder);
 
+      abstract Builder<T> setAutoScaler(AutoScaler autoScaler);
+
       abstract Read<T> build();
     }
 
@@ -344,6 +348,14 @@
       return builder().setCoder(coder).build();
     }
 
+    /**
+     * Sets the {@link AutoScaler} to use for reporting backlog during the execution of this source.
+     */
+    public Read<T> withAutoScaler(AutoScaler autoScaler) {
+      checkArgument(autoScaler != null, "autoScaler can not be null");
+      return builder().setAutoScaler(autoScaler).build();
+    }
+
     @Override
     public PCollection<T> expand(PBegin input) {
       checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required");
@@ -447,6 +459,7 @@
     private Connection connection;
     private Session session;
     private MessageConsumer consumer;
+    private AutoScaler autoScaler;
 
     private T currentMessage;
     private Instant currentTimestamp;
@@ -474,6 +487,12 @@
         }
         connection.start();
         this.connection = connection;
+        if (spec.getAutoScaler() == null) {
+          this.autoScaler = new DefaultAutoscaler();
+        } else {
+          this.autoScaler = spec.getAutoScaler();
+        }
+        this.autoScaler.start();
       } catch (Exception e) {
         throw new IOException("Error connecting to JMS", e);
       }
@@ -545,6 +564,11 @@
     }
 
     @Override
+    public long getTotalBacklogBytes() {
+      return this.autoScaler.getTotalBacklogBytes();
+    }
+
+    @Override
     public UnboundedSource<T, ?> getCurrentSource() {
       return source;
     }
@@ -565,6 +589,10 @@
           connection.close();
           connection = null;
         }
+        if (autoScaler != null) {
+          autoScaler.stop();
+          autoScaler = null;
+        }
       } catch (Exception e) {
         throw new IOException(e);
       }
diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
index c335f8a..a9f3c3f 100644
--- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
+++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java
@@ -17,12 +17,17 @@
  */
 package org.apache.beam.sdk.io.jms;
 
+import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsString;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import java.io.IOException;
 import java.lang.reflect.Proxy;
@@ -421,6 +426,50 @@
     CoderProperties.coderDecodeEncodeEqual(coder, jmsCheckpointMark);
   }
 
+  @Test
+  public void testDefaultAutoscaler() throws IOException {
+    JmsIO.Read spec =
+        JmsIO.read()
+            .withConnectionFactory(connectionFactory)
+            .withUsername(USERNAME)
+            .withPassword(PASSWORD)
+            .withQueue(QUEUE);
+    JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+
+    // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
+    reader.start();
+    assertEquals(BACKLOG_UNKNOWN, reader.getSplitBacklogBytes());
+    assertEquals(BACKLOG_UNKNOWN, reader.getTotalBacklogBytes());
+    reader.close();
+  }
+
+  @Test
+  public void testCustomAutoscaler() throws IOException {
+    long excpectedTotalBacklogBytes = 1111L;
+
+    AutoScaler autoScaler = mock(DefaultAutoscaler.class);
+    when(autoScaler.getTotalBacklogBytes()).thenReturn(excpectedTotalBacklogBytes);
+    JmsIO.Read spec =
+        JmsIO.read()
+            .withConnectionFactory(connectionFactory)
+            .withUsername(USERNAME)
+            .withPassword(PASSWORD)
+            .withQueue(QUEUE)
+            .withAutoScaler(autoScaler);
+
+    JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
+    JmsIO.UnboundedJmsReader reader = source.createReader(null, null);
+
+    // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
+    reader.start();
+    verify(autoScaler, times(1)).start();
+    assertEquals(excpectedTotalBacklogBytes, reader.getTotalBacklogBytes());
+    verify(autoScaler, times(1)).getTotalBacklogBytes();
+    reader.close();
+    verify(autoScaler, times(1)).stop();
+  }
+
   private int count(String queue) throws Exception {
     Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD);
     connection.start();