SAMZA-2533: Fix DiagnosticsUtil to use configured MetricsReporterFactory (#1369)

diff --git a/samza-core/src/main/java/org/apache/samza/util/DiagnosticsUtil.java b/samza-core/src/main/java/org/apache/samza/util/DiagnosticsUtil.java
index 7c83466..0b88fa0 100644
--- a/samza-core/src/main/java/org/apache/samza/util/DiagnosticsUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/util/DiagnosticsUtil.java
@@ -36,6 +36,7 @@
 import org.apache.samza.diagnostics.DiagnosticsManager;
 import org.apache.samza.job.model.JobModel;
 import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.metrics.MetricsReporterFactory;
 import org.apache.samza.metrics.reporter.Metrics;
 import org.apache.samza.metrics.reporter.MetricsHeader;
 import org.apache.samza.metrics.reporter.MetricsSnapshot;
@@ -51,7 +52,6 @@
 import org.apache.samza.system.SystemStream;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.Option;
 
 
 public class DiagnosticsUtil {
@@ -100,54 +100,59 @@
       String jobId, JobModel jobModel, String containerId, Optional<String> execEnvContainerId, Config config) {
 
     JobConfig jobConfig = new JobConfig(config);
+    MetricsConfig metricsConfig = new MetricsConfig(config);
     Optional<Pair<DiagnosticsManager, MetricsSnapshotReporter>> diagnosticsManagerReporterPair = Optional.empty();
 
     if (jobConfig.getDiagnosticsEnabled()) {
 
+      // Diagnostics MetricReporter init
+      String diagnosticsReporterName = MetricsConfig.METRICS_SNAPSHOT_REPORTER_NAME_FOR_DIAGNOSTICS;
+      String diagnosticsFactoryClassName = metricsConfig.getMetricsFactoryClass(diagnosticsReporterName)
+          .orElseThrow(() -> new SamzaException(
+              String.format("Diagnostics reporter %s missing .class config", diagnosticsReporterName)));
+      MetricsReporterFactory metricsReporterFactory =
+          ReflectionUtil.getObj(diagnosticsFactoryClassName, MetricsReporterFactory.class);
+      MetricsSnapshotReporter diagnosticsReporter =
+          (MetricsSnapshotReporter) metricsReporterFactory.getMetricsReporter(diagnosticsReporterName,
+              "samza-container-" + containerId, config);
+
+      // DiagnosticsManager init
       ClusterManagerConfig clusterManagerConfig = new ClusterManagerConfig(config);
       int containerMemoryMb = clusterManagerConfig.getContainerMemoryMb();
       int containerNumCores = clusterManagerConfig.getNumCores();
       long maxHeapSizeBytes = Runtime.getRuntime().maxMemory();
       int containerThreadPoolSize = jobConfig.getThreadPoolSize();
-
-      // Diagnostic stream, producer, and reporter related parameters
-      String diagnosticsReporterName = MetricsConfig.METRICS_SNAPSHOT_REPORTER_NAME_FOR_DIAGNOSTICS;
-      MetricsConfig metricsConfig = new MetricsConfig(config);
-      int publishInterval = metricsConfig.getMetricsSnapshotReporterInterval(diagnosticsReporterName);
       String taskClassVersion = Util.getTaskClassVersion(config);
       String samzaVersion = Util.getSamzaVersion();
       String hostName = Util.getLocalHost().getHostName();
-      Optional<String> diagnosticsReporterStreamName = metricsConfig.getMetricsSnapshotReporterStream(diagnosticsReporterName);
+      Optional<String> diagnosticsReporterStreamName =
+          metricsConfig.getMetricsSnapshotReporterStream(diagnosticsReporterName);
 
       if (!diagnosticsReporterStreamName.isPresent()) {
-        throw new ConfigException("Missing required config: " + String.format(MetricsConfig.METRICS_SNAPSHOT_REPORTER_STREAM, diagnosticsReporterName));
+        throw new ConfigException(
+            "Missing required config: " + String.format(MetricsConfig.METRICS_SNAPSHOT_REPORTER_STREAM,
+                diagnosticsReporterName));
       }
-
       SystemStream diagnosticsSystemStream = StreamUtil.getSystemStreamFromNames(diagnosticsReporterStreamName.get());
 
+      // Create a SystemProducer for DiagnosticsManager. This producer is used by the DiagnosticsManager
+      // to write to the same stream as the MetricsSnapshotReporter called `diagnosticsreporter`.
       Optional<String> diagnosticsSystemFactoryName =
           new SystemConfig(config).getSystemFactory(diagnosticsSystemStream.getSystem());
       if (!diagnosticsSystemFactoryName.isPresent()) {
         throw new SamzaException("Missing factory in config for system " + diagnosticsSystemStream.getSystem());
       }
-
-      // Create a systemProducer for giving to diagnostic-reporter and diagnosticsManager
       SystemFactory systemFactory = ReflectionUtil.getObj(diagnosticsSystemFactoryName.get(), SystemFactory.class);
       SystemProducer systemProducer =
           systemFactory.getProducer(diagnosticsSystemStream.getSystem(), config, new MetricsRegistryMap());
+
       DiagnosticsManager diagnosticsManager =
           new DiagnosticsManager(jobName, jobId, jobModel.getContainers(), containerMemoryMb, containerNumCores,
-              new StorageConfig(config).getNumPersistentStores(), maxHeapSizeBytes, containerThreadPoolSize, containerId, execEnvContainerId.orElse(""),
-              taskClassVersion, samzaVersion, hostName, diagnosticsSystemStream, systemProducer,
+              new StorageConfig(config).getNumPersistentStores(), maxHeapSizeBytes, containerThreadPoolSize,
+              containerId, execEnvContainerId.orElse(""), taskClassVersion, samzaVersion, hostName,
+              diagnosticsSystemStream, systemProducer,
               Duration.ofMillis(new TaskConfig(config).getShutdownMs()), jobConfig.getAutosizingEnabled());
 
-      Option<String> blacklist = ScalaJavaUtil.JavaOptionals$.MODULE$.toRichOptional(
-          metricsConfig.getMetricsSnapshotReporterBlacklist(diagnosticsReporterName)).toOption();
-      MetricsSnapshotReporter diagnosticsReporter =
-          new MetricsSnapshotReporter(systemProducer, diagnosticsSystemStream, publishInterval, jobName, jobId,
-              "samza-container-" + containerId, taskClassVersion, samzaVersion, hostName, new MetricsSnapshotSerdeV2(),
-              blacklist, ScalaJavaUtil.toScalaFunction(() -> System.currentTimeMillis()));
-
       diagnosticsManagerReporterPair = Optional.of(new ImmutablePair<>(diagnosticsManager, diagnosticsReporter));
     }
 
diff --git a/samza-core/src/main/scala/org/apache/samza/diagnostics/DiagnosticsManager.java b/samza-core/src/main/scala/org/apache/samza/diagnostics/DiagnosticsManager.java
index 9131142..f77dab8 100644
--- a/samza-core/src/main/scala/org/apache/samza/diagnostics/DiagnosticsManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/diagnostics/DiagnosticsManager.java
@@ -142,6 +142,7 @@
     this.autosizingEnabled = autosizingEnabled;
 
     resetTime = Instant.now();
+    this.systemProducer.register(getClass().getSimpleName());
 
     try {
       ReflectionUtil.getObjWithArgs("org.apache.samza.logging.log4j.SimpleDiagnosticsAppender",
@@ -161,6 +162,7 @@
   }
 
   public void start() {
+    this.systemProducer.start();
     this.scheduler.scheduleWithFixedDelay(new DiagnosticsStreamPublisher(), 0, DEFAULT_PUBLISH_PERIOD.getSeconds(),
         TimeUnit.SECONDS);
   }
@@ -175,6 +177,7 @@
       LOG.warn("Unable to terminate scheduler");
       scheduler.shutdownNow();
     }
+    this.systemProducer.stop();
   }
 
   public void addExceptionEvent(DiagnosticsExceptionEvent diagnosticsExceptionEvent) {
diff --git a/samza-core/src/test/java/org/apache/samza/diagnostics/TestDiagnosticsManager.java b/samza-core/src/test/java/org/apache/samza/diagnostics/TestDiagnosticsManager.java
index d21bb4b..a7f022e 100644
--- a/samza-core/src/test/java/org/apache/samza/diagnostics/TestDiagnosticsManager.java
+++ b/samza-core/src/test/java/org/apache/samza/diagnostics/TestDiagnosticsManager.java
@@ -44,6 +44,7 @@
 public class TestDiagnosticsManager {
   private DiagnosticsManager diagnosticsManager;
   private MockSystemProducer mockSystemProducer;
+  private ScheduledExecutorService mockExecutorService;
   private SystemStream diagnosticsSystemStream = new SystemStream("kafka", "test stream");
 
   private String jobName = "Testjob";
@@ -68,7 +69,7 @@
     mockSystemProducer = new MockSystemProducer();
 
     // Mocked scheduled executor service which does a synchronous run() on scheduling
-    ScheduledExecutorService mockExecutorService = Mockito.mock(ScheduledExecutorService.class);
+    mockExecutorService = Mockito.mock(ScheduledExecutorService.class);
     Mockito.when(mockExecutorService.scheduleWithFixedDelay(Mockito.any(), Mockito.anyLong(), Mockito.anyLong(),
         Mockito.eq(TimeUnit.SECONDS))).thenAnswer(invocation -> {
             ((Runnable) invocation.getArguments()[0]).run();
@@ -88,6 +89,63 @@
   }
 
   @Test
+  public void testDiagnosticsManagerStart() {
+    SystemProducer mockSystemProducer = Mockito.mock(SystemProducer.class);
+    DiagnosticsManager diagnosticsManager =
+        new DiagnosticsManager(jobName, jobId, containerModels, containerMb, containerNumCores, numPersistentStores,
+            maxHeapSize, containerThreadPoolSize, "0", executionEnvContainerId, taskClassVersion, samzaVersion,
+            hostname, diagnosticsSystemStream, mockSystemProducer, Duration.ofSeconds(1), mockExecutorService,
+            autosizingEnabled);
+
+    diagnosticsManager.start();
+
+    Mockito.verify(mockSystemProducer, Mockito.times(1)).start();
+    Mockito.verify(mockExecutorService, Mockito.times(1))
+        .scheduleWithFixedDelay(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.anyLong(),
+            Mockito.any(TimeUnit.class));
+  }
+
+  @Test
+  public void testDiagnosticsManagerStop() throws InterruptedException {
+    SystemProducer mockSystemProducer = Mockito.mock(SystemProducer.class);
+    Mockito.when(mockExecutorService.isTerminated()).thenReturn(true);
+    Duration terminationDuration = Duration.ofSeconds(1);
+    DiagnosticsManager diagnosticsManager =
+        new DiagnosticsManager(jobName, jobId, containerModels, containerMb, containerNumCores, numPersistentStores,
+            maxHeapSize, containerThreadPoolSize, "0", executionEnvContainerId, taskClassVersion, samzaVersion,
+            hostname, diagnosticsSystemStream, mockSystemProducer, terminationDuration, mockExecutorService,
+            autosizingEnabled);
+
+    diagnosticsManager.stop();
+
+    Mockito.verify(mockExecutorService, Mockito.times(1)).shutdown();
+    Mockito.verify(mockExecutorService, Mockito.times(1))
+        .awaitTermination(terminationDuration.toMillis(), TimeUnit.MILLISECONDS);
+    Mockito.verify(mockExecutorService, Mockito.never()).shutdownNow();
+    Mockito.verify(mockSystemProducer, Mockito.times(1)).stop();
+  }
+
+  @Test
+  public void testDiagnosticsManagerForceStop() throws InterruptedException {
+    SystemProducer mockSystemProducer = Mockito.mock(SystemProducer.class);
+    Mockito.when(mockExecutorService.isTerminated()).thenReturn(false);
+    Duration terminationDuration = Duration.ofSeconds(1);
+    DiagnosticsManager diagnosticsManager =
+        new DiagnosticsManager(jobName, jobId, containerModels, containerMb, containerNumCores, numPersistentStores,
+            maxHeapSize, containerThreadPoolSize, "0", executionEnvContainerId, taskClassVersion, samzaVersion,
+            hostname, diagnosticsSystemStream, mockSystemProducer, terminationDuration, mockExecutorService,
+            autosizingEnabled);
+
+    diagnosticsManager.stop();
+
+    Mockito.verify(mockExecutorService, Mockito.times(1)).shutdown();
+    Mockito.verify(mockExecutorService, Mockito.times(1))
+        .awaitTermination(terminationDuration.toMillis(), TimeUnit.MILLISECONDS);
+    Mockito.verify(mockExecutorService, Mockito.times(1)).shutdownNow();
+    Mockito.verify(mockSystemProducer, Mockito.times(1)).stop();
+  }
+
+  @Test
   public void testDiagnosticsStreamFirstMessagePublish() {
     // invoking start will do a syncrhonous publish to the stream because of our mocked scheduled exec service
     this.diagnosticsManager.start();
diff --git a/samza-core/src/test/java/org/apache/samza/util/TestDiagnosticsUtil.java b/samza-core/src/test/java/org/apache/samza/util/TestDiagnosticsUtil.java
new file mode 100644
index 0000000..f817c47
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/util/TestDiagnosticsUtil.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.util;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.config.MetricsConfig;
+import org.apache.samza.config.SystemConfig;
+import org.apache.samza.diagnostics.DiagnosticsManager;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.MetricsReporterFactory;
+import org.apache.samza.metrics.reporter.MetricsSnapshotReporter;
+import org.apache.samza.system.SystemFactory;
+import org.apache.samza.system.SystemProducer;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.mockito.PowerMockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import static org.mockito.Mockito.*;
+
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({ReflectionUtil.class})
+public class TestDiagnosticsUtil {
+
+  private static final String STREAM_NAME = "someStreamName";
+  private static final String JOB_NAME = "someJob";
+  private static final String JOB_ID = "someId";
+  private static final String CONTAINER_ID = "someContainerId";
+  private static final String ENV_ID = "someEnvID";
+  public static final String REPORTER_FACTORY = "org.apache.samza.metrics.reporter.MetricsSnapshotReporterFactory";
+  public static final String SYSTEM_FACTORY = "com.foo.system.SomeSystemFactory";
+
+  @Test
+  public void testBuildDiagnosticsManagerReturnsConfiguredReporter() {
+    Config config = new MapConfig(buildTestConfigs());
+    JobModel mockJobModel = mock(JobModel.class);
+    SystemFactory systemFactory = mock(SystemFactory.class);
+    SystemProducer mockProducer = mock(SystemProducer.class);
+    MetricsReporterFactory metricsReporterFactory = mock(MetricsReporterFactory.class);
+    MetricsSnapshotReporter mockReporter = mock(MetricsSnapshotReporter.class);
+
+    when(systemFactory.getProducer(anyString(), any(Config.class), any(MetricsRegistry.class))).thenReturn(mockProducer);
+    when(metricsReporterFactory.getMetricsReporter(anyString(), anyString(), any(Config.class))).thenReturn(
+        mockReporter);
+    PowerMockito.mockStatic(ReflectionUtil.class);
+    when(ReflectionUtil.getObj(REPORTER_FACTORY, MetricsReporterFactory.class)).thenReturn(metricsReporterFactory);
+    when(ReflectionUtil.getObj(SYSTEM_FACTORY, SystemFactory.class)).thenReturn(systemFactory);
+
+    Optional<Pair<DiagnosticsManager, MetricsSnapshotReporter>> managerReporterPair =
+        DiagnosticsUtil.buildDiagnosticsManager(JOB_NAME, JOB_ID, mockJobModel, CONTAINER_ID, Optional.of(ENV_ID),
+            config);
+
+    Assert.assertTrue(managerReporterPair.isPresent());
+    Assert.assertEquals(mockReporter, managerReporterPair.get().getValue());
+  }
+
+  private Map<String, String> buildTestConfigs() {
+    Map<String, String> configs = new HashMap<>();
+    configs.put(JobConfig.JOB_DIAGNOSTICS_ENABLED, "true");
+    configs.put(String.format(MetricsConfig.METRICS_REPORTER_FACTORY,
+        MetricsConfig.METRICS_SNAPSHOT_REPORTER_NAME_FOR_DIAGNOSTICS), REPORTER_FACTORY);
+    configs.put(String.format(MetricsConfig.METRICS_SNAPSHOT_REPORTER_STREAM,
+        MetricsConfig.METRICS_SNAPSHOT_REPORTER_NAME_FOR_DIAGNOSTICS),
+        MetricsConfig.METRICS_SNAPSHOT_REPORTER_NAME_FOR_DIAGNOSTICS + "." + STREAM_NAME);
+    configs.put(String.format(SystemConfig.SYSTEM_FACTORY_FORMAT, MetricsConfig.METRICS_SNAPSHOT_REPORTER_NAME_FOR_DIAGNOSTICS),
+        SYSTEM_FACTORY);
+
+    return configs;
+  }
+}