Revise Filter in elasticjob-restful (#1762)
diff --git a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java
index ed618ce..65fb24b 100644
--- a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java
+++ b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/Filter.java
@@ -32,7 +32,6 @@
* @param httpRequest HTTP request
* @param httpResponse HTTP response
* @param filterChain filter chain
- * @return pass through the filter if true, else do response
*/
- boolean doFilter(FullHttpRequest httpRequest, FullHttpResponse httpResponse, FilterChain filterChain);
+ void doFilter(FullHttpRequest httpRequest, FullHttpResponse httpResponse, FilterChain filterChain);
}
diff --git a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java
index 927ca4c..13f493a 100644
--- a/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java
+++ b/elasticjob-infra/elasticjob-restful/src/main/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChain.java
@@ -39,7 +39,9 @@
private int current;
- private boolean finished;
+ private boolean passedThrough;
+
+ private boolean replied;
public DefaultFilterChain(final List<Filter> filterInstances, final ChannelHandlerContext ctx, final HandleContext<?> handleContext) {
filters = filterInstances.toArray(new Filter[0]);
@@ -49,23 +51,22 @@
@Override
public void next(final FullHttpRequest httpRequest) {
- Preconditions.checkState(!finished, "FilterChain has already finished.");
+ Preconditions.checkState(!passedThrough && !replied, "FilterChain has already finished.");
if (current < filters.length) {
- Filter currentFilter = filters[current++];
- boolean passThrough = currentFilter.doFilter(httpRequest, handleContext.getHttpResponse(), this);
- if (!passThrough) {
- finished = true;
+ filters[current++].doFilter(httpRequest, handleContext.getHttpResponse(), this);
+ if (!passedThrough && !replied) {
doResponse();
}
return;
}
- finished = true;
+ passedThrough = true;
ctx.fireChannelRead(handleContext);
}
private void doResponse() {
try {
ctx.writeAndFlush(handleContext.getHttpResponse());
+ replied = true;
} finally {
ReferenceCountUtil.release(handleContext.getHttpRequest());
}
diff --git a/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java b/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java
index 067fdd1..85a2586 100644
--- a/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java
+++ b/elasticjob-infra/elasticjob-restful/src/test/java/org/apache/shardingsphere/elasticjob/restful/filter/DefaultFilterChainTest.java
@@ -37,9 +37,10 @@
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public final class DefaultFilterChainTest {
@@ -53,15 +54,6 @@
@Mock
private FullHttpResponse httpResponse;
- @Mock
- private Filter firstFilter;
-
- @Mock
- private Filter secondFilter;
-
- @Mock
- private Filter thirdFilter;
-
private HandleContext<Handler> handleContext;
@Before
@@ -75,56 +67,61 @@
filterChain.next(httpRequest);
verify(ctx, never()).writeAndFlush(httpResponse);
verify(ctx).fireChannelRead(handleContext);
+ assertTrue(isPassedThrough(filterChain));
+ assertFalse(isReplied(filterChain));
}
@Test
public void assertWithSingleFilterPassed() {
- DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(firstFilter), ctx, handleContext);
- when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
+ Filter passableFilter = spy(new PassableFilter());
+ DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(passableFilter), ctx, handleContext);
filterChain.next(httpRequest);
- verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
- filterChain.next(httpRequest);
+ verify(passableFilter).doFilter(httpRequest, httpResponse, filterChain);
verify(ctx).fireChannelRead(handleContext);
verify(ctx, never()).writeAndFlush(httpResponse);
+ assertTrue(isPassedThrough(filterChain));
+ assertFalse(isReplied(filterChain));
}
@Test
public void assertWithSingleFilterDoResponse() {
- DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(firstFilter), ctx, handleContext);
+ Filter impassableFilter = mock(Filter.class);
+ DefaultFilterChain filterChain = new DefaultFilterChain(Collections.singletonList(impassableFilter), ctx, handleContext);
filterChain.next(httpRequest);
- verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
+ verify(impassableFilter).doFilter(httpRequest, httpResponse, filterChain);
verify(ctx, never()).fireChannelRead(any(HandleContext.class));
verify(ctx).writeAndFlush(httpResponse);
+ assertTrue(isReplied(filterChain));
+ assertFalse(isPassedThrough(filterChain));
}
@Test
public void assertWithThreeFiltersPassed() {
+ Filter firstFilter = spy(new PassableFilter());
+ Filter secondFilter = spy(new PassableFilter());
+ Filter thirdFilter = spy(new PassableFilter());
DefaultFilterChain filterChain = new DefaultFilterChain(Arrays.asList(firstFilter, secondFilter, thirdFilter), ctx, handleContext);
- when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
filterChain.next(httpRequest);
verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
- when(secondFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
- filterChain.next(httpRequest);
verify(secondFilter).doFilter(httpRequest, httpResponse, filterChain);
- when(thirdFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
- filterChain.next(httpRequest);
verify(thirdFilter).doFilter(httpRequest, httpResponse, filterChain);
- filterChain.next(httpRequest);
+ assertTrue(isPassedThrough(filterChain));
+ assertFalse(isReplied(filterChain));
verify(ctx).fireChannelRead(handleContext);
verify(ctx, never()).writeAndFlush(any(FullHttpResponse.class));
}
@Test
public void assertWithThreeFiltersDoResponseByTheSecond() {
+ Filter firstFilter = spy(new PassableFilter());
+ Filter secondFilter = mock(Filter.class);
+ Filter thirdFilter = spy(new PassableFilter());
DefaultFilterChain filterChain = new DefaultFilterChain(Arrays.asList(firstFilter, secondFilter, thirdFilter), ctx, handleContext);
- when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
filterChain.next(httpRequest);
verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
- when(secondFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(false);
- assertFalse(isFinished(filterChain));
- filterChain.next(httpRequest);
verify(secondFilter).doFilter(httpRequest, httpResponse, filterChain);
- assertTrue(isFinished(filterChain));
+ assertFalse(isPassedThrough(filterChain));
+ assertTrue(isReplied(filterChain));
verify(thirdFilter, never()).doFilter(httpRequest, httpResponse, filterChain);
verify(ctx, never()).fireChannelRead(any(HandleContext.class));
verify(ctx).writeAndFlush(httpResponse);
@@ -138,23 +135,37 @@
}
@Test(expected = IllegalStateException.class)
- public void assertInvokeFinishedFilterChainWithTwoFilters() {
+ public void assertInvokePassedThroughFilterChainWithTwoFilters() {
+ Filter firstFilter = spy(new PassableFilter());
+ Filter secondFilter = spy(new PassableFilter());
DefaultFilterChain filterChain = new DefaultFilterChain(Arrays.asList(firstFilter, secondFilter), ctx, handleContext);
- when(firstFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
filterChain.next(httpRequest);
verify(firstFilter).doFilter(httpRequest, httpResponse, filterChain);
- when(secondFilter.doFilter(httpRequest, httpResponse, filterChain)).thenReturn(true);
- filterChain.next(httpRequest);
verify(secondFilter).doFilter(httpRequest, httpResponse, filterChain);
- filterChain.next(httpRequest);
verify(ctx).fireChannelRead(handleContext);
filterChain.next(httpRequest);
}
+ private boolean isPassedThrough(final DefaultFilterChain filterChain) {
+ return getBoolean(filterChain, "passedThrough");
+ }
+
+ private boolean isReplied(final DefaultFilterChain filterChain) {
+ return getBoolean(filterChain, "replied");
+ }
+
@SneakyThrows
- private boolean isFinished(final DefaultFilterChain filterChain) {
- Field field = DefaultFilterChain.class.getDeclaredField("finished");
+ private boolean getBoolean(final DefaultFilterChain filterChain, final String fieldName) {
+ Field field = DefaultFilterChain.class.getDeclaredField(fieldName);
field.setAccessible(true);
return (boolean) field.get(filterChain);
}
+
+ private static class PassableFilter implements Filter {
+
+ @Override
+ public void doFilter(final FullHttpRequest httpRequest, final FullHttpResponse httpResponse, final FilterChain filterChain) {
+ filterChain.next(httpRequest);
+ }
+ }
}