Fix new flaky test (#1306)

* SAMZA-2476: Fix more flaky tests in TestAzureBlobOutputStream

* Add JIRA reference in code comments

* Trigger Travis Build
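
The core of the change is a seam extraction: instead of stubbing BlockBlobAsyncClient.stageBlock and commitBlockListWithResponse directly under PowerMock (which proved flaky), AzureBlobOutputStream now exposes @VisibleForTesting methods (stageBlock, commitBlob, clearAndMarkClosed) that the tests stub on a Mockito spy. A minimal sketch of that test pattern follows; the class name and helper methods are illustrative only, while the seam signatures and matchers are taken from the diff below:

    // Sketch of the seam-stubbing pattern; not part of the patch itself.
    // Placed in the same package so the package-private seams are visible.
    package org.apache.samza.system.azureblob.avro;

    import static org.mockito.Mockito.any;
    import static org.mockito.Mockito.anyInt;
    import static org.mockito.Mockito.anyMap;
    import static org.mockito.Mockito.anyString;
    import static org.mockito.Mockito.doNothing;
    import static org.mockito.Mockito.doThrow;
    import static org.mockito.Mockito.spy;

    import java.nio.ByteBuffer;
    import java.util.ArrayList;

    public class SeamStubbingSketch {

      AzureBlobOutputStream stubHappyPath(AzureBlobOutputStream stream) {
        AzureBlobOutputStream spyStream = spy(stream);
        // No Azure calls are made: the seams swallow staging, commit and cleanup.
        doNothing().when(spyStream).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
        doNothing().when(spyStream).commitBlob(any(ArrayList.class), anyMap());
        doNothing().when(spyStream).clearAndMarkClosed();
        return spyStream;
      }

      void stubCommitFailure(AzureBlobOutputStream spyStream) {
        // Failure tests throw from the seam instead of returning Mono.error from a client mock.
        doThrow(new IllegalArgumentException("Test Failed"))
            .when(spyStream).commitBlob(any(ArrayList.class), anyMap());
      }
    }

Verifications then target the spy as well (e.g. verify(spyStream).stageBlock(...) with a ByteBuffer captor) rather than the BlockBlobAsyncClient mock, which is what removes the flakiness.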
diff --git a/samza-azure/src/main/java/org/apache/samza/system/azureblob/avro/AzureBlobOutputStream.java b/samza-azure/src/main/java/org/apache/samza/system/azureblob/avro/AzureBlobOutputStream.java
index 21dc2a9..9db15a3 100644
--- a/samza-azure/src/main/java/org/apache/samza/system/azureblob/avro/AzureBlobOutputStream.java
+++ b/samza-azure/src/main/java/org/apache/samza/system/azureblob/avro/AzureBlobOutputStream.java
@@ -179,10 +179,7 @@
           blobAsyncClient.getBlobUrl().toString(), pendingUpload.size());
       throw new AzureException(msg, e);
     } finally {
-      blockList.clear();
-      pendingUpload.stream().forEach(future -> future.cancel(true));
-      pendingUpload.clear();
-      isClosed = true;
+      clearAndMarkClosed();
     }
   }
 
@@ -233,6 +230,21 @@
     blobAsyncClient.commitBlockListWithResponse(blockList, null, blobMetadata, null, null).block();
   }
 
+  // SAMZA-2476: stubbing BlockBlobAsyncClient.stageBlock directly was causing flaky tests; tests stub this wrapper instead.
+  @VisibleForTesting
+  void stageBlock(String blockIdEncoded, ByteBuffer outputStream, int blockSize) {
+    blobAsyncClient.stageBlock(blockIdEncoded, Flux.just(outputStream), blockSize).block();
+  }
+
+  // blockList is cleared here, which makes close() hard to verify in tests; extracted so tests can stub it.
+  @VisibleForTesting
+  void clearAndMarkClosed() {
+    blockList.clear();
+    pendingUpload.stream().forEach(future -> future.cancel(true));
+    pendingUpload.clear();
+    isClosed = true;
+  }
+
   /**
    * This api will async upload the outputstream into block using stageBlocks,
    * reint outputstream
@@ -275,7 +287,7 @@
             LOG.info("{} Upload block start for blob: {} for block size:{}.", blobAsyncClient.getBlobUrl().toString(), blockId, blockSize);
             metrics.updateAzureUploadMetrics();
             // StageBlock generates exception on Failure.
-            blobAsyncClient.stageBlock(blockIdEncoded, Flux.just(outputStream), blockSize).block();
+            stageBlock(blockIdEncoded, outputStream, blockSize);
             break;
           } catch (Exception e) {
             attemptCount += 1;
diff --git a/samza-azure/src/test/java/org/apache/samza/system/azureblob/avro/TestAzureBlobOutputStream.java b/samza-azure/src/test/java/org/apache/samza/system/azureblob/avro/TestAzureBlobOutputStream.java
index 34c5a4b..d635693 100644
--- a/samza-azure/src/test/java/org/apache/samza/system/azureblob/avro/TestAzureBlobOutputStream.java
+++ b/samza-azure/src/test/java/org/apache/samza/system/azureblob/avro/TestAzureBlobOutputStream.java
@@ -19,8 +19,7 @@
 
 package org.apache.samza.system.azureblob.avro;
 
-import com.azure.core.http.rest.SimpleResponse;
-import com.azure.core.implementation.util.FluxUtil;
+import java.util.Arrays;
 import org.apache.samza.AzureException;
 import org.apache.samza.system.azureblob.compression.Compression;
 import org.apache.samza.system.azureblob.producer.AzureBlobWriterMetrics;
@@ -42,16 +41,16 @@
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
-import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
 import static org.mockito.Mockito.any;
-import static org.mockito.Mockito.anyList;
+import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.anyLong;
 import static org.mockito.Mockito.anyMap;
 import static org.mockito.Mockito.anyString;
-import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -83,10 +82,6 @@
     mockByteArrayOutputStream = spy(new ByteArrayOutputStream(THRESHOLD));
 
     mockBlobAsyncClient = PowerMockito.mock(BlockBlobAsyncClient.class);
-    when(mockBlobAsyncClient.stageBlock(anyString(), any(), anyLong())).thenReturn(
-        Mono.just(new SimpleResponse(null, 200, null, null)).flatMap(FluxUtil::toMono));
-    when(mockBlobAsyncClient.commitBlockListWithResponse(any(), any(), any(), any(), any())).thenReturn(
-        Mono.just(new SimpleResponse(null, 200, null, null)));
 
     when(mockBlobAsyncClient.getBlobUrl()).thenReturn("https://samza.blob.core.windows.net/fake-blob-url");
 
@@ -97,13 +92,17 @@
 
     azureBlobOutputStream = spy(new AzureBlobOutputStream(mockBlobAsyncClient, threadPool, mockMetrics,
         60000, THRESHOLD, mockByteArrayOutputStream, mockCompression));
+
+    doNothing().when(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
+    doNothing().when(azureBlobOutputStream).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
+    doNothing().when(azureBlobOutputStream).clearAndMarkClosed();
   }
 
   @Test
   public void testWrite() {
     byte[] b = new byte[THRESHOLD - 10];
     azureBlobOutputStream.write(b, 0, THRESHOLD - 10);
-    verify(mockBlobAsyncClient, never()).stageBlock(any(), any(), anyLong()); // since size of byte[] written is less than threshold
+    verify(azureBlobOutputStream, never()).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
     verify(mockMetrics).updateWriteByteMetrics(THRESHOLD - 10);
     verify(mockMetrics, never()).updateAzureUploadMetrics();
   }
@@ -127,12 +126,12 @@
     // invoked 2 times for the data which is 2*threshold
     verify(mockCompression).compress(largeRecordFirstHalf);
     verify(mockCompression).compress(largeRecordSecondHalf);
-    ArgumentCaptor<Flux> argument0 = ArgumentCaptor.forClass(Flux.class);
-    ArgumentCaptor<Flux> argument1 = ArgumentCaptor.forClass(Flux.class);
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded(0)), argument0.capture(), eq((long) compressB1.length));
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded(1)), argument1.capture(), eq((long) compressB2.length));
-    Assert.assertEquals(ByteBuffer.wrap(compressB1), argument0.getAllValues().get(0).blockFirst());
-    Assert.assertEquals(ByteBuffer.wrap(compressB2), argument1.getAllValues().get(0).blockFirst());
+    ArgumentCaptor<ByteBuffer> argument0 = ArgumentCaptor.forClass(ByteBuffer.class);
+    ArgumentCaptor<ByteBuffer> argument1 = ArgumentCaptor.forClass(ByteBuffer.class);
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(0)), argument0.capture(), eq((int) compressB1.length));
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(1)), argument1.capture(), eq((int) compressB2.length));
+    Assert.assertEquals(ByteBuffer.wrap(compressB1), argument0.getAllValues().get(0));
+    Assert.assertEquals(ByteBuffer.wrap(compressB2), argument1.getAllValues().get(0));
     verify(mockMetrics).updateWriteByteMetrics(2 * THRESHOLD);
     verify(mockMetrics, times(2)).updateAzureUploadMetrics();
   }
@@ -163,15 +162,15 @@
     verify(mockCompression, times(2)).compress(fullBlock);
     verify(mockCompression).compress(halfBlock);
 
-    ArgumentCaptor<Flux> argument = ArgumentCaptor.forClass(Flux.class);
-    ArgumentCaptor<Flux> argument2 = ArgumentCaptor.forClass(Flux.class);
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded(0)), argument.capture(), eq((long) fullBlockCompressedByte.length));
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded(1)), argument.capture(), eq((long) fullBlockCompressedByte.length));
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded(2)), argument2.capture(), eq((long) halfBlockCompressedByte.length));
-    argument.getAllValues().forEach(flux -> {
-        Assert.assertEquals(ByteBuffer.wrap(fullBlockCompressedByte), flux.blockFirst());
+    ArgumentCaptor<ByteBuffer> argument = ArgumentCaptor.forClass(ByteBuffer.class);
+    ArgumentCaptor<ByteBuffer> argument2 = ArgumentCaptor.forClass(ByteBuffer.class);
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(0)), argument.capture(), eq((int) fullBlockCompressedByte.length));
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(1)), argument.capture(), eq((int) fullBlockCompressedByte.length));
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(2)), argument2.capture(), eq((int) halfBlockCompressedByte.length));
+    argument.getAllValues().forEach(byteBuffer -> {
+        Assert.assertEquals(ByteBuffer.wrap(fullBlockCompressedByte), byteBuffer);
       });
-    Assert.assertEquals(ByteBuffer.wrap(halfBlockCompressedByte), ((Flux) argument2.getValue()).blockFirst());
+    Assert.assertEquals(ByteBuffer.wrap(halfBlockCompressedByte), argument2.getAllValues().get(0));
     verify(mockMetrics, times(3)).updateAzureUploadMetrics();
   }
 
@@ -184,9 +183,9 @@
     azureBlobOutputStream.close();
 
     verify(mockCompression).compress(BYTES);
-    ArgumentCaptor<Flux> argument = ArgumentCaptor.forClass(Flux.class);
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded(0)), argument.capture(), eq((long) COMPRESSED_BYTES.length)); // since size of byte[] written is less than threshold
-    Assert.assertEquals(ByteBuffer.wrap(COMPRESSED_BYTES), ((Flux) argument.getValue()).blockFirst());
+    ArgumentCaptor<ByteBuffer> argument = ArgumentCaptor.forClass(ByteBuffer.class);
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(0)), argument.capture(), eq((int) COMPRESSED_BYTES.length)); // since size of byte[] written is less than threshold
+    Assert.assertEquals(ByteBuffer.wrap(COMPRESSED_BYTES), argument.getAllValues().get(0));
     verify(mockMetrics, times(2)).updateWriteByteMetrics(THRESHOLD / 2);
     verify(mockMetrics, times(1)).updateAzureUploadMetrics();
   }
@@ -209,17 +208,15 @@
     String blockId = String.format("%05d", blockNum);
     String blockIdEncoded = Base64.getEncoder().encodeToString(blockId.getBytes());
 
-    doAnswer(invocation -> {
-        ArrayList<String> blockListArg = (ArrayList<String>) invocation.getArguments()[0];
-        String blockIdArg = (String) blockListArg.toArray()[0];
-        Assert.assertEquals(blockIdEncoded, blockIdArg);
-        Map<String, String> blobMetadata = (Map<String, String>) invocation.getArguments()[1];
-        Assert.assertEquals(blobMetadata.get(AzureBlobOutputStream.BLOB_RAW_SIZE_BYTES_METADATA), Long.toString(THRESHOLD));
-        return null;
-      }).when(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
-
     azureBlobOutputStream.close();
     verify(mockMetrics).updateAzureCommitMetrics();
+
+    ArgumentCaptor<ArrayList> blockListArgument = ArgumentCaptor.forClass(ArrayList.class);
+    ArgumentCaptor<Map> blobMetadataArg = ArgumentCaptor.forClass(Map.class);
+    verify(azureBlobOutputStream).commitBlob(blockListArgument.capture(), blobMetadataArg.capture());
+    Assert.assertEquals(Arrays.asList(blockIdEncoded), blockListArgument.getAllValues().get(0));
+    Map<String, String> blobMetadata = (Map<String, String>) blobMetadataArg.getAllValues().get(0);
+    Assert.assertEquals(blobMetadata.get(AzureBlobOutputStream.BLOB_RAW_SIZE_BYTES_METADATA), Long.toString(THRESHOLD));
   }
 
   @Test
@@ -234,24 +231,27 @@
     int blockNum1 = 1;
     String blockId1 = String.format("%05d", blockNum1);
     String blockIdEncoded1 = Base64.getEncoder().encodeToString(blockId1.getBytes());
-
-    doAnswer(invocation -> {
-        ArrayList<String> blockListArg = (ArrayList<String>) invocation.getArguments()[0];
-        String blockIdArg = (String) blockListArg.toArray()[0];
-        Assert.assertEquals(blockIdEncoded, blockIdArg);
-        Map<String, String> blobMetadata = (Map<String, String>) invocation.getArguments()[1];
-        Assert.assertEquals(blobMetadata.get(AzureBlobOutputStream.BLOB_RAW_SIZE_BYTES_METADATA), Long.toString(2 * THRESHOLD));
-        return null;
-      }).when(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
     azureBlobOutputStream.close();
     verify(mockMetrics).updateAzureCommitMetrics();
+    ArgumentCaptor<ArrayList> blockListArgument = ArgumentCaptor.forClass(ArrayList.class);
+    ArgumentCaptor<Map> blobMetadataArg = ArgumentCaptor.forClass(Map.class);
+    verify(azureBlobOutputStream).commitBlob(blockListArgument.capture(), blobMetadataArg.capture());
+    Assert.assertEquals(blockIdEncoded, blockListArgument.getAllValues().get(0).toArray()[0]);
+    Assert.assertEquals(blockIdEncoded1, blockListArgument.getAllValues().get(0).toArray()[1]);
+    Map<String, String> blobMetadata = (Map<String, String>) blobMetadataArg.getAllValues().get(0);
+    Assert.assertEquals(blobMetadata.get(AzureBlobOutputStream.BLOB_RAW_SIZE_BYTES_METADATA), Long.toString(2 * THRESHOLD));
   }
 
   @Test(expected = AzureException.class)
   public void testCloseFailed() {
-    when(mockBlobAsyncClient.commitBlockListWithResponse(anyList(), any(), any(), any(), any()))
-        .thenReturn(Mono.error(new Exception("Test Failed")));
 
+    azureBlobOutputStream = spy(new AzureBlobOutputStream(mockBlobAsyncClient, threadPool, mockMetrics,
+        60000, THRESHOLD, mockByteArrayOutputStream, mockCompression));
+
+    //doNothing().when(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
+    doNothing().when(azureBlobOutputStream).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
+    doThrow(new IllegalArgumentException("Test Failed")).when(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
+    doNothing().when(azureBlobOutputStream).clearAndMarkClosed();
     byte[] b = new byte[100];
     azureBlobOutputStream.write(b, 0, THRESHOLD);
     azureBlobOutputStream.close();
@@ -276,17 +276,24 @@
     String blockIdEncoded = Base64.getEncoder().encodeToString(blockId.getBytes());
 
     verify(mockCompression).compress(BYTES);
-    ArgumentCaptor<Flux> argument = ArgumentCaptor.forClass(Flux.class);
-    verify(mockBlobAsyncClient).stageBlock(eq(blockIdEncoded), argument.capture(), eq((long) COMPRESSED_BYTES.length)); // since size of byte[] written is less than threshold
-    Assert.assertEquals(ByteBuffer.wrap(COMPRESSED_BYTES), ((Flux) argument.getValue()).blockFirst());
+    ArgumentCaptor<ByteBuffer> argument = ArgumentCaptor.forClass(ByteBuffer.class);
+    // since size of byte[] written is less than threshold
+    verify(azureBlobOutputStream).stageBlock(eq(blockIdEncoded(0)), argument.capture(), eq((int) COMPRESSED_BYTES.length));
+    Assert.assertEquals(ByteBuffer.wrap(COMPRESSED_BYTES), argument.getAllValues().get(0));
     verify(mockMetrics).updateAzureUploadMetrics();
   }
 
   @Test (expected = AzureException.class)
   public void testFlushFailed() throws IOException {
+    azureBlobOutputStream = spy(new AzureBlobOutputStream(mockBlobAsyncClient, threadPool, mockMetrics,
+        60000, THRESHOLD, mockByteArrayOutputStream, mockCompression));
+
+    doNothing().when(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
+    //doNothing().when(azureBlobOutputStream).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
+    doThrow(new IllegalArgumentException("Test Failed")).when(azureBlobOutputStream).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
+    doNothing().when(azureBlobOutputStream).clearAndMarkClosed();
+
     azureBlobOutputStream.write(BYTES);
-    when(mockBlobAsyncClient.stageBlock(anyString(), any(), anyLong()))
-           .thenReturn(Mono.error(new Exception("Test Failed")));
 
     azureBlobOutputStream.flush();
     // azureBlobOutputStream.close waits on the CompletableFuture which does the actual stageBlock in uploadBlockAsync
@@ -315,14 +322,14 @@
     // mockByteArrayOutputStream.close called only once during releaseBuffer and not during azureBlobOutputStream.close
     verify(mockByteArrayOutputStream).close();
     // azureBlobOutputStream.close still commits the list of blocks.
-    verify(mockBlobAsyncClient).commitBlockListWithResponse(any(), any(), any(), any(), any());
+    verify(azureBlobOutputStream).commitBlob(any(ArrayList.class), anyMap());
   }
 
   @Test
   public void testFlushAfterReleaseBuffer() throws Exception {
     azureBlobOutputStream.releaseBuffer();
     azureBlobOutputStream.flush(); // becomes no-op after release buffer
-    verify(mockBlobAsyncClient, never()).stageBlock(anyString(), any(), anyLong());
+    verify(azureBlobOutputStream, never()).stageBlock(anyString(), any(ByteBuffer.class), anyInt());
   }
 
   @Test