Revise Filter in elasticjob-restful (#1762)

diff --git a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java
index ed618ce..65fb24b 100644
--- a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java
+++ b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java
@@ -32,7 +32,6 @@
      * @param httpRequest  HTTP request
      * @param httpResponse HTTP response
      * @param filterChain  filter chain
-     * @return pass through the filter if true, else do response
      */
-    boolean doFilter(FullHttpRequest httpRequest, FullHttpResponse httpResponse, FilterChain filterChain);
+    void doFilter(FullHttpRequest httpRequest, FullHttpResponse httpResponse, FilterChain filterChain);
 }
diff --git a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java
index 927ca4c..13f493a 100644
--- a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java
+++ b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java
@@ -39,7 +39,9 @@
     
     private int current;
     
-    private boolean finished;
+    private boolean passedThrough;
+    
+    private boolean replied;
     
     public DefaultFilterChain(final List<Filter> filterInstances, final ChannelHandlerContext ctx, final HandleContext<?> handleContext) {
         filters = filterInstances.toArray(new Filter[0]);
@@ -49,23 +51,22 @@
     
     @Override
     public void next(final FullHttpRequest httpRequest) {
-        Preconditions.checkState(!finished, "FilterChain has already finished.");
+        Preconditions.checkState(!passedThrough && !replied, "FilterChain has already finished.");
         if (current < filters.length) {
-            Filter currentFilter = filters[current++];
-            boolean passThrough = currentFilter.doFilter(httpRequest, handleContext.getHttpResponse(), this);
-            if (!passThrough) {
-                finished = true;
+            filters[current++].doFilter(httpRequest, handleContext.getHttpResponse(), this);
+            if (!passedThrough && !replied) {
                 doResponse();
             }
             return;
         }
-        finished = true;
+        passedThrough = true;
         ctx.fireChannelRead(handleContext);
     }
     
     private void doResponse() {
         try {
             ctx.writeAndFlush(handleContext.getHttpResponse());
+            replied = true;
         } finally {
             ReferenceCountUtil.release(handleContext.getHttpRequest());
         }
diff --git a/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java b/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java
index 067fdd1..85a2586 100644
--- a/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java
+++ b/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java
@@ -37,9 +37,10 @@
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 
 @RunWith(MockitoJUnitRunner.class)
 public final class DefaultFilterChainTest {
@@ -53,15 +54,6 @@
     @Mock
     private FullHttpResponse httpResponse;
     
-    @Mock
-    private Filter firstFilter;
-    
-    @Mock
-    private Filter secondFilter;
-    
-    @Mock
-    private Filter thirdFilter;
-    
     private HandleContext<Handler> handleContext;
     
     @Before
@@ -75,56 +67,61 @@
         filterChain.next(httpRequest);
         verify(ctx, never()).writeAndFlush(httpResponse);
         verify(ctx).fireChannelRead(handleContext);
+        assertTrue(isPassedThrough(filterChain));
+        assertFalse(isReplied(filterChain));
     }
     
     @Test
     public void assertWithSingleFilterPassed() {
-        DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(firstFilter), ctx, handleContext);
-        when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
+        Filter passableFilter = spy(new PassableFilter());
+        DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(passableFilter), ctx, handleContext);
         filterChain.next(httpRequest);
-        verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
-        filterChain.next(httpRequest);
+        verify(passableFilter).doFilter(httpRequest, httpResponse, filterChain);
         verify(ctx).fireChannelRead(handleContext);
         verify(ctx, never()).writeAndFlush(httpResponse);
+        assertTrue(isPassedThrough(filterChain));
+        assertFalse(isReplied(filterChain));
     }
     
     @Test
     public void assertWithSingleFilterDoResponse() {
-        DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(firstFilter), ctx, handleContext);
+        Filter impassableFilter = mock(Filter.class);
+        DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(impassableFilter), ctx, handleContext);
         filterChain.next(httpRequest);
-        verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
+        verify(impassableFilter).doFilter(httpRequest, httpResponse, filterChain);
         verify(ctx, never()).fireChannelRead(any(HandleContext.class));
         verify(ctx).writeAndFlush(httpResponse);
+        assertTrue(isReplied(filterChain));
+        assertFalse(isPassedThrough(filterChain));
     }
     
     @Test
     public void assertWithThreeFiltersPassed() {
+        Filter firstFilter = spy(new PassableFilter());
+        Filter secondFilter = spy(new PassableFilter());
+        Filter thirdFilter = spy(new PassableFilter());
         DefaultFilterChain filterChain = new DefaultFilterChain(Arrays.asList(firstFilter, secondFilter, thirdFilter), ctx, handleContext);
-        when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
         filterChain.next(httpRequest);
         verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
-        when(secondFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
-        filterChain.next(httpRequest);
         verify(secondFilter).doFilter(httpRequest, httpResponse, filterChain);
-        when(thirdFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
-        filterChain.next(httpRequest);
         verify(thirdFilter).doFilter(httpRequest, httpResponse, filterChain);
-        filterChain.next(httpRequest);
+        assertTrue(isPassedThrough(filterChain));
+        assertFalse(isReplied(filterChain));
         verify(ctx).fireChannelRead(handleContext);
         verify(ctx, never()).writeAndFlush(any(FullHttpResponse.class));
     }
     
     @Test
     public void assertWithThreeFiltersDoResponseByTheSecond() {
+        Filter firstFilter = spy(new PassableFilter());
+        Filter secondFilter = mock(Filter.class);
+        Filter thirdFilter = spy(new PassableFilter());
         DefaultFilterChain filterChain = new DefaultFilterChain(Arrays.asList(firstFilter, secondFilter, thirdFilter), ctx, handleContext);
-        when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
         filterChain.next(httpRequest);
         verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
-        when(secondFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(false);
-        assertFalse(isFinished(filterChain));
-        filterChain.next(httpRequest);
         verify(secondFilter).doFilter(httpRequest, httpResponse, filterChain);
-        assertTrue(isFinished(filterChain));
+        assertFalse(isPassedThrough(filterChain));
+        assertTrue(isReplied(filterChain));
         verify(thirdFilter, never()).doFilter(httpRequest, httpResponse, filterChain);
         verify(ctx, never()).fireChannelRead(any(HandleContext.class));
         verify(ctx).writeAndFlush(httpResponse);
@@ -138,23 +135,37 @@
     }
     
     @Test(expected = IllegalStateException.class)
-    public void assertInvokeFinishedFilterChainWithTwoFilters() {
+    public void assertInvokePassedThroughFilterChainWithTwoFilters() {
+        Filter firstFilter = spy(new PassableFilter());
+        Filter secondFilter = spy(new PassableFilter());
         DefaultFilterChain filterChain = new DefaultFilterChain(Arrays.asList(firstFilter, secondFilter), ctx, handleContext);
-        when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
         filterChain.next(httpRequest);
         verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
-        when(secondFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
-        filterChain.next(httpRequest);
         verify(secondFilter).doFilter(httpRequest, httpResponse, filterChain);
-        filterChain.next(httpRequest);
         verify(ctx).fireChannelRead(handleContext);
         filterChain.next(httpRequest);
     }
     
+    private boolean isPassedThrough(final DefaultFilterChain filterChain) {
+        return getBoolean(filterChain, "passedThrough");
+    }
+    
+    private boolean isReplied(final DefaultFilterChain filterChain) {
+        return getBoolean(filterChain, "replied");
+    }
+    
     @SneakyThrows
-    private boolean isFinished(final DefaultFilterChain filterChain) {
-        Field field = DefaultFilterChain.class.getDeclaredField("finished");
+    private boolean getBoolean(final DefaultFilterChain filterChain, final String fieldName) {
+        Field field = DefaultFilterChain.class.getDeclaredField(fieldName);
         field.setAccessible(true);
         return (boolean) field.get(filterChain);
     }
+    
+    private static class PassableFilter implements Filter {
+        
+        @Override
+        public void doFilter(final FullHttpRequest httpRequest, final FullHttpResponse httpResponse, final FilterChain filterChain) {
+            filterChain.next(httpRequest);
+        }
+    }
 }