KAFKA-14006: Parameterize WorkerConnectorTest suite (#12307)

Reviewers: Sagar Rao <sagarmeansocean@gmail.com>, Christo Lolov <lolovc@amazon.com>, Kvicii <kvicii.yu@gmail.com>, Mickael Maison <mickael.maison@gmail.com>
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java
index cf07d5d..9ced632 100644
--- a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java
@@ -18,6 +18,7 @@
 
 import org.apache.kafka.connect.connector.Connector;
 import org.apache.kafka.connect.errors.ConnectException;
+import org.apache.kafka.connect.health.ConnectorType;
 import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup;
 import org.apache.kafka.connect.runtime.isolation.Plugins;
 import org.apache.kafka.connect.sink.SinkConnector;
@@ -29,15 +30,22 @@
 import org.apache.kafka.connect.util.Callback;
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 
+
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 import org.mockito.ArgumentCaptor;
 import org.mockito.InOrder;
 import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+import org.mockito.quality.Strictness;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -54,7 +62,7 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
-@RunWith(MockitoJUnitRunner.StrictStubs.class)
+@RunWith(Parameterized.class)
 public class WorkerConnectorTest {
 
     private static final String VERSION = "1.1";
@@ -68,15 +76,41 @@
     public ConnectorConfig connectorConfig;
     public MockConnectMetrics metrics;
 
+    @Rule
+    public MockitoRule rule = MockitoJUnit.rule().strictness(Strictness.STRICT_STUBS);
+
     @Mock private Plugins plugins;
-    @Mock private SourceConnector sourceConnector;
-    @Mock private SinkConnector sinkConnector;
     @Mock private CloseableConnectorContext ctx;
     @Mock private ConnectorStatus.Listener listener;
-    @Mock private CloseableOffsetStorageReader offsetStorageReader;
-    @Mock private ConnectorOffsetBackingStore offsetStore;
     @Mock private ClassLoader classLoader;
-    private Connector connector;
+
+    private final ConnectorType connectorType;
+    private final Connector connector;
+    private final CloseableOffsetStorageReader offsetStorageReader;
+    private final ConnectorOffsetBackingStore offsetStore;
+
+    @Parameterized.Parameters
+    public static Collection<ConnectorType> parameters() {
+        return Arrays.asList(ConnectorType.SOURCE, ConnectorType.SINK);
+    }
+
+    public WorkerConnectorTest(ConnectorType connectorType) {
+        this.connectorType = connectorType;
+        switch (connectorType) {
+            case SINK:
+                this.connector = mock(SinkConnector.class);
+                this.offsetStorageReader = null;
+                this.offsetStore = null;
+                break;
+            case SOURCE:
+                this.connector = mock(SourceConnector.class);
+                this.offsetStorageReader = mock(CloseableOffsetStorageReader.class);
+                this.offsetStore = mock(ConnectorOffsetBackingStore.class);
+                break;
+            default:
+                throw new IllegalStateException("Unexpected connector type: " + connectorType);
+        }
+    }
 
     @Before
     public void setup() {
@@ -92,7 +126,6 @@
     @Test
     public void testInitializeFailure() {
         RuntimeException exception = new RuntimeException();
-        connector = sourceConnector;
 
         when(connector.version()).thenReturn(VERSION);
         doThrow(exception).when(connector).initialize(any());
@@ -113,13 +146,12 @@
     @Test
     public void testFailureIsFinalState() {
         RuntimeException exception = new RuntimeException();
-        connector = sinkConnector;
 
         when(connector.version()).thenReturn(VERSION);
         doThrow(exception).when(connector).initialize(any());
 
         Callback<TargetState> onStateChange = mockCallback();
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, null, null, classLoader);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
         assertFailedMetric(workerConnector);
@@ -140,15 +172,13 @@
 
     @Test
     public void testStartupAndShutdown() {
-        connector = sourceConnector;
-
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
         workerConnector.shutdown();
@@ -166,14 +196,12 @@
 
     @Test
     public void testStartupAndPause() {
-        connector = sinkConnector;
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, null, null, classLoader);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSinkMetric(workerConnector);
 
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
@@ -197,14 +225,13 @@
 
     @Test
     public void testStartupAndStop() {
-        connector = sinkConnector;
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, null, null, classLoader);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSinkMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
 
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
@@ -228,8 +255,6 @@
 
     @Test
     public void testOnResume() {
-        connector = sourceConnector;
-
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
@@ -237,7 +262,7 @@
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange);
         assertPausedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
@@ -260,14 +285,13 @@
 
     @Test
     public void testStartupPaused() {
-        connector = sinkConnector;
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, null, null, classLoader);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSinkMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange);
         assertPausedMetric(workerConnector);
         workerConnector.shutdown();
@@ -285,14 +309,13 @@
 
     @Test
     public void testStartupStopped() {
-        connector = sinkConnector;
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, null, null, classLoader);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSinkMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STOPPED, onStateChange);
         assertStoppedMetric(workerConnector);
         workerConnector.shutdown();
@@ -311,16 +334,15 @@
     @Test
     public void testStartupFailure() {
         RuntimeException exception = new RuntimeException();
-        connector = sinkConnector;
 
         when(connector.version()).thenReturn(VERSION);
         doThrow(exception).when(connector).start(CONFIG);
 
         Callback<TargetState> onStateChange = mockCallback();
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, null, null, classLoader);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSinkMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertFailedMetric(workerConnector);
         workerConnector.shutdown();
@@ -339,7 +361,6 @@
     @Test
     public void testStopFailure() {
         RuntimeException exception = new RuntimeException();
-        connector = sourceConnector;
 
         when(connector.version()).thenReturn(VERSION);
 
@@ -352,7 +373,7 @@
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onFirstStateChange);
         assertRunningMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STOPPED, onSecondStateChange);
@@ -382,7 +403,6 @@
     @Test
     public void testShutdownFailure() {
         RuntimeException exception = new RuntimeException();
-        connector = sourceConnector;
 
         when(connector.version()).thenReturn(VERSION);
 
@@ -392,7 +412,7 @@
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
         workerConnector.shutdown();
@@ -410,8 +430,6 @@
 
     @Test
     public void testTransitionStartedToStarted() {
-        connector = sourceConnector;
-
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
@@ -419,7 +437,7 @@
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
@@ -439,14 +457,13 @@
 
     @Test
     public void testTransitionPausedToPaused() {
-        connector = sourceConnector;
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange);
@@ -471,14 +488,13 @@
 
     @Test
     public void testTransitionStoppedToStopped() {
-        connector = sourceConnector;
         when(connector.version()).thenReturn(VERSION);
 
         Callback<TargetState> onStateChange = mockCallback();
         WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
-        assertInitializedSourceMetric(workerConnector);
+        assertInitializedMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STARTED, onStateChange);
         assertRunningMetric(workerConnector);
         workerConnector.doTransitionTo(TargetState.STOPPED, onStateChange);
@@ -503,13 +519,13 @@
 
     @Test
     public void testFailConnectorThatIsNeitherSourceNorSink() {
-        connector = mock(Connector.class);
-        when(connector.version()).thenReturn(VERSION);
-        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
+        Connector badConnector = mock(Connector.class);
+        when(badConnector.version()).thenReturn(VERSION);
+        WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, badConnector, connectorConfig, ctx, metrics, listener, offsetStorageReader, offsetStore, classLoader);
 
         workerConnector.initialize();
 
-        verify(connector).version();
+        verify(badConnector).version();
         ArgumentCaptor<Throwable> exceptionCapture = ArgumentCaptor.forClass(Throwable.class);
         verify(listener).onFailure(eq(CONNECTOR), exceptionCapture.capture());
         Throwable e = exceptionCapture.getValue();
@@ -557,12 +573,19 @@
         assertFalse(workerConnector.metrics().isRunning());
     }
 
-    protected void assertInitializedSinkMetric(WorkerConnector workerConnector) {
-        assertInitializedMetric(workerConnector, "sink");
-    }
-
-    protected void assertInitializedSourceMetric(WorkerConnector workerConnector) {
-        assertInitializedMetric(workerConnector, "source");
+    protected void assertInitializedMetric(WorkerConnector workerConnector) {
+        String expectedType;
+        switch (connectorType) {
+            case SINK:
+                expectedType = "sink";
+                break;
+            case SOURCE:
+                expectedType = "source";
+                break;
+            default:
+                throw new IllegalStateException("Unexpected connector type: " + connectorType);
+        }
+        assertInitializedMetric(workerConnector, expectedType);
     }
 
     protected void assertInitializedMetric(WorkerConnector workerConnector, String expectedType) {
@@ -588,10 +611,10 @@
 
     private void verifyInitialize() {
         verify(connector).version();
-        if (connector instanceof SourceConnector) {
+        if (connectorType == ConnectorType.SOURCE) {
             verify(offsetStore).start();
             verify(connector).initialize(any(SourceConnectorContext.class));
-        } else {
+        } else if (connectorType == ConnectorType.SINK) {
             verify(connector).initialize(any(SinkConnectorContext.class));
         }
     }
@@ -606,7 +629,7 @@
 
     private void verifyShutdown(int connectorStops, boolean clean, boolean started) {
         verify(ctx).close();
-        if (connector instanceof SourceConnector) {
+        if (connectorType == ConnectorType.SOURCE) {
             verify(offsetStorageReader).close();
             verify(offsetStore).stop();
         }