| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.catalina.nonblocking; |
| |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.io.Serializable; |
| import java.net.HttpURLConnection; |
| import java.net.Socket; |
| import java.net.URL; |
| import java.nio.charset.StandardCharsets; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import javax.net.SocketFactory; |
| |
| import jakarta.servlet.AsyncContext; |
| import jakarta.servlet.AsyncEvent; |
| import jakarta.servlet.AsyncListener; |
| import jakarta.servlet.DispatcherType; |
| import jakarta.servlet.ReadListener; |
| import jakarta.servlet.ServletException; |
| import jakarta.servlet.ServletInputStream; |
| import jakarta.servlet.ServletOutputStream; |
| import jakarta.servlet.WriteListener; |
| import jakarta.servlet.annotation.WebServlet; |
| import jakarta.servlet.http.HttpServlet; |
| import jakarta.servlet.http.HttpServletRequest; |
| import jakarta.servlet.http.HttpServletResponse; |
| |
| import org.junit.Assert; |
| import org.junit.Ignore; |
| import org.junit.Test; |
| |
| import org.apache.catalina.Context; |
| import org.apache.catalina.Wrapper; |
| import org.apache.catalina.startup.BytesStreamer; |
| import org.apache.catalina.startup.SimpleHttpClient; |
| import org.apache.catalina.startup.TesterServlet; |
| import org.apache.catalina.startup.Tomcat; |
| import org.apache.catalina.startup.TomcatBaseTest; |
| import org.apache.catalina.valves.TesterAccessLogValve; |
| import org.apache.juli.logging.Log; |
| import org.apache.juli.logging.LogFactory; |
| import org.apache.tomcat.util.buf.ByteChunk; |
| import org.apache.tomcat.util.net.ContainerThreadMarker; |
| |
| public class TestNonBlockingAPI extends TomcatBaseTest { |
| |
| private static final Log log = LogFactory.getLog(TestNonBlockingAPI.class); |
| |
| private static final int CHUNK_SIZE = 1024 * 1024; |
| private static final int WRITE_SIZE = CHUNK_SIZE * 10; |
| private static final byte[] DATA = new byte[WRITE_SIZE]; |
| private static final int WRITE_PAUSE_MS = 500; |
| |
| |
| static { |
| // Use this sequence for padding to make it easier to spot errors |
| byte[] padding = new byte[] {'z', 'y', 'x', 'w', 'v', 'u', 't', 's', |
| 'r', 'q', 'p', 'o', 'n', 'm', 'l', 'k'}; |
| int blockSize = padding.length; |
| |
| for (int i = 0; i < WRITE_SIZE / blockSize; i++) { |
| String hex = String.format("%01X", Integer.valueOf(i)); |
| int hexSize = hex.length(); |
| int padSize = blockSize - hexSize; |
| |
| System.arraycopy(padding, 0, DATA, i * blockSize, padSize); |
| System.arraycopy( |
| hex.getBytes(), 0, DATA, i * blockSize + padSize, hexSize); |
| } |
| } |
| |
| |
| @Test |
| public void testNonBlockingRead() throws Exception { |
| doTestNonBlockingRead(false, false); |
| } |
| |
| |
| @Test |
| public void testNonBlockingReadAsync() throws Exception { |
| doTestNonBlockingRead(false, true); |
| } |
| |
| |
| @Test(expected=IOException.class) |
| public void testNonBlockingReadIgnoreIsReady() throws Exception { |
| doTestNonBlockingRead(true, false); |
| } |
| |
| |
| private void doTestNonBlockingRead(boolean ignoreIsReady, boolean async) throws Exception { |
| Tomcat tomcat = getTomcatInstance(); |
| |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| NBReadServlet servlet = new NBReadServlet(ignoreIsReady, async); |
| String servletName = NBReadServlet.class.getName(); |
| Tomcat.addServlet(ctx, servletName, servlet); |
| ctx.addServletMappingDecoded("/", servletName); |
| |
| tomcat.start(); |
| |
| Map<String, List<String>> resHeaders = new HashMap<>(); |
| int rc = postUrl(true, new DataWriter(async ? 0 : 500, async ? 2000000 : 5), |
| "http://localhost:" + getPort() + "/", new ByteChunk(), resHeaders, null); |
| |
| Assert.assertEquals(HttpServletResponse.SC_OK, rc); |
| if (async) { |
| Assert.assertEquals(2000000 * 8, servlet.listener.body.length()); |
| TestAsyncReadListener listener = (TestAsyncReadListener) servlet.listener; |
| Assert.assertTrue(Math.abs(listener.containerThreadCount.get() - listener.notReadyCount.get()) <= 1); |
| Assert.assertEquals(listener.isReadyCount.get(), listener.nonContainerThreadCount.get()); |
| } else { |
| Assert.assertEquals(5 * 8, servlet.listener.body.length()); |
| } |
| } |
| |
| |
| @Test |
| public void testNonBlockingWrite() throws Exception { |
| testNonBlockingWriteInternal(false); |
| } |
| |
| @Test |
| public void testNonBlockingWriteWithKeepAlive() throws Exception { |
| testNonBlockingWriteInternal(true); |
| } |
| |
| private void testNonBlockingWriteInternal(boolean keepAlive) throws Exception { |
| AtomicBoolean asyncContextIsComplete = new AtomicBoolean(false); |
| |
| Tomcat tomcat = getTomcatInstance(); |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| NBWriteServlet servlet = new NBWriteServlet(asyncContextIsComplete); |
| String servletName = NBWriteServlet.class.getName(); |
| Tomcat.addServlet(ctx, servletName, servlet); |
| ctx.addServletMappingDecoded("/", servletName); |
| // Note: Low values of socket.txBufSize can trigger very poor |
| // performance. Set it just low enough to ensure that the |
| // non-blocking write servlet will see isReady() == false |
| Assert.assertTrue(tomcat.getConnector().setProperty("socket.txBufSize", "1048576")); |
| tomcat.start(); |
| |
| SocketFactory factory = SocketFactory.getDefault(); |
| Socket s = factory.createSocket("localhost", getPort()); |
| |
| InputStream is = s.getInputStream(); |
| byte[] buffer = new byte[8192]; |
| |
| ByteChunk result = new ByteChunk(); |
| |
| OutputStream os = s.getOutputStream(); |
| if (keepAlive) { |
| os.write(("OPTIONS * HTTP/1.1\r\n" + |
| "Host: localhost:" + getPort() + "\r\n" + |
| "\r\n").getBytes(StandardCharsets.ISO_8859_1)); |
| os.flush(); |
| // Make sure the entire response has been read. |
| int read = is.read(buffer); |
| // The response should end with CRLFCRLF |
| Assert.assertEquals(buffer[read - 4], '\r'); |
| Assert.assertEquals(buffer[read - 3], '\n'); |
| Assert.assertEquals(buffer[read - 2], '\r'); |
| Assert.assertEquals(buffer[read - 1], '\n'); |
| } |
| os.write(("GET / HTTP/1.1\r\n" + |
| "Host: localhost:" + getPort() + "\r\n" + |
| "Connection: close\r\n" + |
| "\r\n").getBytes(StandardCharsets.ISO_8859_1)); |
| os.flush(); |
| |
| int read = 0; |
| int readSinceLastPause = 0; |
| while (read != -1) { |
| read = is.read(buffer); |
| if (readSinceLastPause == 0) { |
| log.info("Reading data"); |
| } |
| if (read > 0) { |
| result.append(buffer, 0, read); |
| } |
| readSinceLastPause += read; |
| if (readSinceLastPause > WRITE_SIZE / 16) { |
| log.info("Read " + readSinceLastPause + " bytes, pause 500ms"); |
| readSinceLastPause = 0; |
| Thread.sleep(500); |
| } |
| } |
| |
| os.close(); |
| is.close(); |
| s.close(); |
| |
| // Validate the result. |
| // Response line |
| String resultString = result.toString(); |
| log.info("Client read " + resultString.length() + " bytes"); |
| int lineStart = 0; |
| int lineEnd = resultString.indexOf('\n', 0); |
| String line = resultString.substring(lineStart, lineEnd + 1); |
| Assert.assertEquals("HTTP/1.1 200 \r\n", line); |
| |
| // Check headers - looking to see if response is chunked (it should be) |
| boolean chunked = false; |
| while (line.length() > 2) { |
| lineStart = lineEnd + 1; |
| lineEnd = resultString.indexOf('\n', lineStart); |
| line = resultString.substring(lineStart, lineEnd + 1); |
| if (line.startsWith("Transfer-Encoding:")) { |
| Assert.assertEquals("Transfer-Encoding: chunked\r\n", line); |
| chunked = true; |
| } |
| } |
| Assert.assertTrue(chunked); |
| |
| // Now check body size |
| int totalBodyRead = 0; |
| int chunkSize = -1; |
| |
| while (chunkSize != 0) { |
| // Chunk size in hex |
| lineStart = lineEnd + 1; |
| lineEnd = resultString.indexOf('\n', lineStart); |
| line = resultString.substring(lineStart, lineEnd + 1); |
| Assert.assertTrue(line.endsWith("\r\n")); |
| line = line.substring(0, line.length() - 2); |
| log.info("[" + line + "]"); |
| chunkSize = Integer.parseInt(line, 16); |
| |
| // Read the chunk |
| lineStart = lineEnd + 1; |
| lineEnd = resultString.indexOf('\n', lineStart); |
| log.info("Start : " + lineStart + ", End: " + lineEnd); |
| if (lineEnd > lineStart) { |
| line = resultString.substring(lineStart, lineEnd + 1); |
| } else { |
| line = resultString.substring(lineStart); |
| } |
| if (line.length() > 40) { |
| log.info(line.substring(0, 32)); |
| } else { |
| log.info(line); |
| } |
| if (chunkSize + 2 != line.length()) { |
| log.error("Chunk wrong length. Was " + line.length() + |
| " Expected " + (chunkSize + 2)); |
| |
| byte[] resultBytes = resultString.getBytes(); |
| |
| // Find error |
| boolean found = false; |
| for (int i = totalBodyRead; i < (totalBodyRead + line.length()); i++) { |
| if (DATA[i] != resultBytes[lineStart + i - totalBodyRead]) { |
| int dataStart = i - 64; |
| if (dataStart < 0) { |
| dataStart = 0; |
| } |
| int dataEnd = i + 64; |
| if (dataEnd > DATA.length) { |
| dataEnd = DATA.length; |
| } |
| int resultStart = lineStart + i - totalBodyRead - 64; |
| if (resultStart < 0) { |
| resultStart = 0; |
| } |
| int resultEnd = lineStart + i - totalBodyRead + 64; |
| if (resultEnd > resultString.length()) { |
| resultEnd = resultString.length(); |
| } |
| log.error("Mis-match tx: " + new String( |
| DATA, dataStart, dataEnd - dataStart)); |
| log.error("Mis-match rx: " + |
| resultString.substring(resultStart, resultEnd)); |
| found = true; |
| break; |
| } |
| } |
| if (!found) { |
| log.error("No mismatch. Data truncated"); |
| } |
| } |
| |
| Assert.assertTrue(line, line.endsWith("\r\n")); |
| Assert.assertEquals(chunkSize + 2, line.length()); |
| |
| totalBodyRead += chunkSize; |
| } |
| |
| Assert.assertEquals(WRITE_SIZE, totalBodyRead); |
| Assert.assertTrue("AsyncContext should have been completed.", asyncContextIsComplete.get()); |
| } |
| |
| |
| @Test |
| public void testNonBlockingWriteError01ListenerComplete() throws Exception { |
| doTestNonBlockingWriteError01NoListenerComplete(true); |
| } |
| |
| |
| @Test |
| public void testNonBlockingWriteError01NoListenerComplete() throws Exception { |
| doTestNonBlockingWriteError01NoListenerComplete(false); |
| } |
| |
| |
| private void doTestNonBlockingWriteError01NoListenerComplete(boolean listenerCompletesOnError) throws Exception { |
| AtomicBoolean asyncContextIsComplete = new AtomicBoolean(false); |
| |
| Tomcat tomcat = getTomcatInstance(); |
| |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| TesterAccessLogValve alv = new TesterAccessLogValve(); |
| ctx.getPipeline().addValve(alv); |
| |
| // Some CI platforms appear to have particularly large write buffers |
| // and appear to ignore the socket.txBufSize below. Therefore, configure |
| // configure the Servlet to keep writing until an error is encountered. |
| NBWriteServlet servlet = new NBWriteServlet(asyncContextIsComplete, true, listenerCompletesOnError); |
| String servletName = NBWriteServlet.class.getName(); |
| Tomcat.addServlet(ctx, servletName, servlet); |
| ctx.addServletMappingDecoded("/", servletName); |
| // Note: Low values of socket.txBufSize can trigger very poor |
| // performance. Set it just low enough to ensure that the |
| // non-blocking write servlet will see isReady() == false |
| Assert.assertTrue(tomcat.getConnector().setProperty("socket.txBufSize", "524228")); |
| tomcat.start(); |
| |
| SocketFactory factory = SocketFactory.getDefault(); |
| Socket s = factory.createSocket("localhost", getPort()); |
| |
| ByteChunk result = new ByteChunk(); |
| OutputStream os = s.getOutputStream(); |
| os.write(("GET / HTTP/1.1\r\n" + |
| "Host: localhost:" + getPort() + "\r\n" + |
| "Connection: close\r\n" + |
| "\r\n").getBytes(StandardCharsets.ISO_8859_1)); |
| os.flush(); |
| |
| InputStream is = s.getInputStream(); |
| byte[] buffer = new byte[8192]; |
| |
| int read = 0; |
| int readSinceLastPause = 0; |
| int readTotal = 0; |
| while (read != -1 && readTotal < WRITE_SIZE / 32) { |
| long start = System.currentTimeMillis(); |
| read = is.read(buffer); |
| long end = System.currentTimeMillis(); |
| log.info("Client read [" + read + "] bytes in [" + (end - start) + |
| "] ms"); |
| if (read > 0) { |
| result.append(buffer, 0, read); |
| } |
| readSinceLastPause += read; |
| readTotal += read; |
| if (readSinceLastPause > WRITE_SIZE / 64) { |
| readSinceLastPause = 0; |
| Thread.sleep(WRITE_PAUSE_MS); |
| } |
| } |
| |
| os.close(); |
| is.close(); |
| s.close(); |
| |
| String resultString = result.toString(); |
| log.info("Client read " + resultString.length() + " bytes"); |
| int lineStart = 0; |
| int lineEnd = resultString.indexOf('\n', 0); |
| String line = resultString.substring(lineStart, lineEnd + 1); |
| Assert.assertEquals("HTTP/1.1 200 \r\n", line); |
| |
| // Listeners are invoked and access valve entries created on a different |
| // thread so give that thread a chance to complete its work. |
| int count = 0; |
| while (count < 100 && !servlet.wlistener.onErrorInvoked) { |
| Thread.sleep(100); |
| count ++; |
| } |
| |
| while (count < 100 && !asyncContextIsComplete.get()) { |
| Thread.sleep(100); |
| count ++; |
| } |
| |
| while (count < 100 && alv.getEntryCount() < 1) { |
| Thread.sleep(100); |
| count ++; |
| } |
| |
| Assert.assertTrue("Error listener should have been invoked.", servlet.wlistener.onErrorInvoked); |
| Assert.assertTrue("Async context should have been completed.", asyncContextIsComplete.get()); |
| |
| // TODO Figure out why non-blocking writes with the NIO connector appear |
| // to be slower on Linux |
| alv.validateAccessLog(1, 500, WRITE_PAUSE_MS, |
| WRITE_PAUSE_MS + 30 * 1000); |
| } |
| |
| |
| @Test |
| public void testBug55438NonBlockingReadWriteEmptyRead() throws Exception { |
| Tomcat tomcat = getTomcatInstance(); |
| |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| NBReadWriteServlet servlet = new NBReadWriteServlet(); |
| String servletName = NBReadWriteServlet.class.getName(); |
| Tomcat.addServlet(ctx, servletName, servlet); |
| ctx.addServletMappingDecoded("/", servletName); |
| |
| tomcat.start(); |
| |
| Map<String, List<String>> resHeaders = new HashMap<>(); |
| int rc = postUrl(false, new BytesStreamer() { |
| @Override |
| public byte[] next() { |
| return new byte[] {}; |
| } |
| |
| @Override |
| public int getLength() { |
| return 0; |
| } |
| |
| @Override |
| public int available() { |
| return 0; |
| } |
| }, "http://localhost:" + |
| getPort() + "/", new ByteChunk(), resHeaders, null); |
| Assert.assertEquals(HttpServletResponse.SC_OK, rc); |
| } |
| |
| |
| public static class DataWriter implements BytesStreamer { |
| int max = 5; |
| int count = 0; |
| long delay = 0; |
| byte[] b = "WANTMORE".getBytes(StandardCharsets.ISO_8859_1); |
| byte[] f = "FINISHED".getBytes(StandardCharsets.ISO_8859_1); |
| |
| public DataWriter(long delay, int max) { |
| this.delay = delay; |
| this.max = max; |
| } |
| |
| @Override |
| public int getLength() { |
| return b.length * max; |
| } |
| |
| @Override |
| public int available() { |
| if (count < max) { |
| return b.length; |
| } else { |
| return 0; |
| } |
| } |
| |
| @Override |
| public byte[] next() { |
| if (count < max) { |
| if (count > 0) |
| try { |
| if (delay > 0) |
| Thread.sleep(delay); |
| } catch (Exception x) { |
| } |
| count++; |
| if (count < max) |
| return b; |
| else |
| return f; |
| } else { |
| return null; |
| } |
| } |
| |
| } |
| |
| @WebServlet(asyncSupported = true) |
| public static class NBReadServlet extends TesterServlet { |
| private static final long serialVersionUID = 1L; |
| private final boolean async; |
| private final boolean ignoreIsReady; |
| transient TestReadListener listener; |
| |
| public NBReadServlet(boolean ignoreIsReady, boolean async) { |
| this.async = async; |
| this.ignoreIsReady = ignoreIsReady; |
| } |
| |
| @Override |
| protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { |
| // step 1 - start async |
| AsyncContext actx = req.startAsync(); |
| actx.setTimeout(Long.MAX_VALUE); |
| actx.addListener(new AsyncListener() { |
| |
| @Override |
| public void onTimeout(AsyncEvent event) throws IOException { |
| log.info("onTimeout"); |
| |
| } |
| |
| @Override |
| public void onStartAsync(AsyncEvent event) throws IOException { |
| log.info("onStartAsync"); |
| |
| } |
| |
| @Override |
| public void onError(AsyncEvent event) throws IOException { |
| log.info("AsyncListener.onError"); |
| |
| } |
| |
| @Override |
| public void onComplete(AsyncEvent event) throws IOException { |
| log.info("onComplete"); |
| |
| } |
| }); |
| // step 2 - notify on read |
| ServletInputStream in = req.getInputStream(); |
| if (async) { |
| listener = new TestAsyncReadListener(actx, false, ignoreIsReady); |
| } else { |
| listener = new TestReadListener(actx, false, ignoreIsReady); |
| } |
| in.setReadListener(listener); |
| } |
| } |
| |
| @WebServlet(asyncSupported = true) |
| public static class NBWriteServlet extends TesterServlet { |
| private static final long serialVersionUID = 1L; |
| private final AtomicBoolean asyncContextIsComplete; |
| private final boolean unlimited; |
| private final boolean listenerCompletesOnError; |
| public transient volatile TestWriteListener wlistener; |
| |
| public NBWriteServlet(AtomicBoolean asyncContextIsComplete) { |
| this(asyncContextIsComplete, false, true); |
| } |
| |
| |
| public NBWriteServlet(AtomicBoolean asyncContextIsComplete, boolean unlimited, boolean listenerCompletesOnError) { |
| this.asyncContextIsComplete = asyncContextIsComplete; |
| this.unlimited = unlimited; |
| this.listenerCompletesOnError = listenerCompletesOnError; |
| } |
| |
| |
| @Override |
| protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { |
| // step 1 - start async |
| AsyncContext actx = req.startAsync(); |
| actx.setTimeout(Long.MAX_VALUE); |
| actx.addListener(new AsyncListener() { |
| |
| @Override |
| public void onTimeout(AsyncEvent event) throws IOException { |
| log.info("onTimeout"); |
| } |
| |
| @Override |
| public void onStartAsync(AsyncEvent event) throws IOException { |
| log.info("onStartAsync"); |
| } |
| |
| @Override |
| public void onError(AsyncEvent event) throws IOException { |
| log.info("AsyncListener.onError"); |
| if (listenerCompletesOnError) { |
| event.getAsyncContext().complete(); |
| } |
| } |
| |
| @Override |
| public void onComplete(AsyncEvent event) throws IOException { |
| log.info("onComplete"); |
| asyncContextIsComplete.set(true); |
| } |
| }); |
| // step 2 - notify on read |
| ServletOutputStream out = resp.getOutputStream(); |
| resp.setBufferSize(200 * 1024); |
| wlistener = new TestWriteListener(actx, unlimited); |
| out.setWriteListener(wlistener); |
| } |
| |
| |
| } |
| |
| @WebServlet(asyncSupported = true) |
| public static class NBReadWriteServlet extends TesterServlet { |
| private static final long serialVersionUID = 1L; |
| public transient volatile TestReadWriteListener rwlistener; |
| |
| @Override |
| protected void service(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { |
| // step 1 - start async |
| AsyncContext actx = req.startAsync(); |
| actx.setTimeout(Long.MAX_VALUE); |
| |
| // step 2 - notify on read |
| ServletInputStream in = req.getInputStream(); |
| rwlistener = new TestReadWriteListener(actx); |
| in.setReadListener(rwlistener); |
| } |
| } |
| |
| private static class TestReadListener implements ReadListener { |
| protected final AsyncContext ctx; |
| protected final boolean usingNonBlockingWrite; |
| protected final boolean ignoreIsReady; |
| protected final StringBuilder body = new StringBuilder(); |
| |
| |
| public TestReadListener(AsyncContext ctx, |
| boolean usingNonBlockingWrite, |
| boolean ignoreIsReady) { |
| this.ctx = ctx; |
| this.usingNonBlockingWrite = usingNonBlockingWrite; |
| this.ignoreIsReady = ignoreIsReady; |
| } |
| |
| @Override |
| public void onDataAvailable() throws IOException { |
| ServletInputStream in = ctx.getRequest().getInputStream(); |
| String s = ""; |
| byte[] b = new byte[8192]; |
| int read = 0; |
| do { |
| read = in.read(b); |
| if (read == -1) { |
| break; |
| } |
| s += new String(b, 0, read); |
| } while (ignoreIsReady || in.isReady()); |
| log.info(s); |
| body.append(s); |
| } |
| |
| @Override |
| public void onAllDataRead() { |
| log.info("onAllDataRead totalData=" + body.toString().length()); |
| // If non-blocking writes are being used, don't write here as it |
| // will inject unexpected data into the write output. |
| if (!usingNonBlockingWrite) { |
| String msg; |
| if (body.toString().endsWith("FINISHED")) { |
| msg = "OK"; |
| } else { |
| msg = "FAILED"; |
| } |
| try { |
| ctx.getResponse().getOutputStream().print(msg); |
| } catch (IOException ioe) { |
| // Ignore |
| } |
| ctx.complete(); |
| } |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| log.info("ReadListener.onError totalData=" + body.toString().length()); |
| throwable.printStackTrace(); |
| } |
| } |
| |
| private static class TestAsyncReadListener extends TestReadListener { |
| |
| AtomicInteger isReadyCount = new AtomicInteger(0); |
| AtomicInteger notReadyCount = new AtomicInteger(0); |
| AtomicInteger containerThreadCount = new AtomicInteger(0); |
| AtomicInteger nonContainerThreadCount = new AtomicInteger(0); |
| |
| public TestAsyncReadListener(AsyncContext ctx, |
| boolean usingNonBlockingWrite, boolean ignoreIsReady) { |
| super(ctx, usingNonBlockingWrite, ignoreIsReady); |
| } |
| |
| @Override |
| public void onDataAvailable() throws IOException { |
| if (ContainerThreadMarker.isContainerThread()) { |
| containerThreadCount.incrementAndGet(); |
| } else { |
| nonContainerThreadCount.incrementAndGet(); |
| } |
| new Thread() { |
| @Override |
| public void run() { |
| try { |
| ServletInputStream in = ctx.getRequest().getInputStream(); |
| byte[] b = new byte[1024]; |
| int read = in.read(b); |
| if (read == -1) { |
| return; |
| } |
| body.append(new String(b, 0, read)); |
| boolean isReady = ignoreIsReady || in.isReady(); |
| if (isReady) { |
| isReadyCount.incrementAndGet(); |
| } else { |
| notReadyCount.incrementAndGet(); |
| } |
| if (isReady) { |
| onDataAvailable(); |
| } |
| } catch (IOException e) { |
| onError(e); |
| } |
| } |
| }.start(); |
| } |
| |
| @Override |
| public void onAllDataRead() { |
| super.onAllDataRead(); |
| log.info("isReadyCount=" + isReadyCount + " notReadyCount=" + notReadyCount |
| + " containerThreadCount=" + containerThreadCount |
| + " nonContainerThreadCount=" + nonContainerThreadCount); |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| super.onError(throwable); |
| log.info("isReadyCount=" + isReadyCount + " notReadyCount=" + notReadyCount |
| + " containerThreadCount=" + containerThreadCount |
| + " nonContainerThreadCount=" + nonContainerThreadCount); |
| } |
| } |
| |
| private static class TestWriteListener implements WriteListener { |
| AsyncContext ctx; |
| private final boolean unlimited; |
| int written = 0; |
| public volatile boolean onErrorInvoked = false; |
| |
| public TestWriteListener(AsyncContext ctx, boolean unlimted) { |
| this.ctx = ctx; |
| this.unlimited = unlimted; |
| } |
| |
| @Override |
| public void onWritePossible() throws IOException { |
| long start = System.currentTimeMillis(); |
| int before = written; |
| while ((written < WRITE_SIZE || unlimited) && |
| ctx.getResponse().getOutputStream().isReady()) { |
| ctx.getResponse().getOutputStream().write( |
| DATA, written, CHUNK_SIZE); |
| written += CHUNK_SIZE; |
| } |
| if (written == WRITE_SIZE) { |
| // Clear the output buffer else data may be lost when |
| // calling complete |
| ctx.getResponse().flushBuffer(); |
| } |
| log.info("Write took: " + (System.currentTimeMillis() - start) + |
| " ms. Bytes before=" + before + " after=" + written); |
| // only call complete if we have emptied the buffer |
| if (ctx.getResponse().getOutputStream().isReady() && |
| written == WRITE_SIZE) { |
| // it is illegal to call complete |
| // if there is a write in progress |
| ctx.complete(); |
| } |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| log.info("WriteListener.onError"); |
| throwable.printStackTrace(); |
| onErrorInvoked = true; |
| } |
| |
| } |
| |
| private static class TestReadWriteListener implements ReadListener { |
| AsyncContext ctx; |
| private final StringBuilder body = new StringBuilder(); |
| |
| public TestReadWriteListener(AsyncContext ctx) { |
| this.ctx = ctx; |
| } |
| |
| @Override |
| public void onDataAvailable() throws IOException { |
| ServletInputStream in = ctx.getRequest().getInputStream(); |
| String s = ""; |
| byte[] b = new byte[8192]; |
| int read = 0; |
| do { |
| read = in.read(b); |
| if (read == -1) { |
| break; |
| } |
| s += new String(b, 0, read); |
| } while (in.isReady()); |
| log.info("Read [" + s + "]"); |
| body.append(s); |
| } |
| |
| @Override |
| public void onAllDataRead() throws IOException { |
| log.info("onAllDataRead"); |
| ServletOutputStream output = ctx.getResponse().getOutputStream(); |
| output.setWriteListener(new WriteListener() { |
| @Override |
| public void onWritePossible() throws IOException { |
| ServletOutputStream output = ctx.getResponse().getOutputStream(); |
| if (output.isReady()) { |
| log.info("Writing [" + body.toString() + "]"); |
| output.write(body.toString().getBytes("utf-8")); |
| } |
| ctx.complete(); |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| log.info("ReadWriteListener.onError"); |
| throwable.printStackTrace(); |
| } |
| }); |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| log.info("ReadListener.onError"); |
| throwable.printStackTrace(); |
| } |
| |
| } |
| |
| public static int postUrlWithDisconnect(boolean stream, BytesStreamer streamer, String path, |
| Map<String, List<String>> reqHead, Map<String, List<String>> resHead) throws IOException { |
| |
| URL url = new URL(path); |
| HttpURLConnection connection = (HttpURLConnection) url.openConnection(); |
| connection.setDoOutput(true); |
| connection.setReadTimeout(1000000); |
| if (reqHead != null) { |
| for (Map.Entry<String, List<String>> entry : reqHead.entrySet()) { |
| StringBuilder valueList = new StringBuilder(); |
| for (String value : entry.getValue()) { |
| if (valueList.length() > 0) { |
| valueList.append(','); |
| } |
| valueList.append(value); |
| } |
| connection.setRequestProperty(entry.getKey(), valueList.toString()); |
| } |
| } |
| if (streamer != null && stream) { |
| if (streamer.getLength() > 0) { |
| connection.setFixedLengthStreamingMode(streamer.getLength()); |
| } else { |
| connection.setChunkedStreamingMode(1024); |
| } |
| } |
| |
| connection.connect(); |
| |
| // Write the request body |
| try (OutputStream os = connection.getOutputStream()) { |
| while (streamer != null && streamer.available() > 0) { |
| byte[] next = streamer.next(); |
| os.write(next); |
| os.flush(); |
| } |
| } |
| |
| int rc = connection.getResponseCode(); |
| if (resHead != null) { |
| Map<String, List<String>> head = connection.getHeaderFields(); |
| resHead.putAll(head); |
| } |
| try { |
| Thread.sleep(1000); |
| } catch (InterruptedException e) { |
| |
| } |
| if (rc == HttpServletResponse.SC_OK) { |
| connection.getInputStream().close(); |
| connection.disconnect(); |
| } |
| return rc; |
| } |
| |
| |
| @Ignore |
| @Test |
| public void testDelayedNBWrite() throws Exception { |
| Tomcat tomcat = getTomcatInstance(); |
| |
| Context ctx = tomcat.addContext("", null); |
| CountDownLatch latch1 = new CountDownLatch(1); |
| DelayedNBWriteServlet servlet = new DelayedNBWriteServlet(latch1); |
| String servletName = DelayedNBWriteServlet.class.getName(); |
| Tomcat.addServlet(ctx, servletName, servlet); |
| ctx.addServletMappingDecoded("/", servletName); |
| |
| tomcat.start(); |
| |
| CountDownLatch latch2 = new CountDownLatch(2); |
| List<Throwable> exceptions = new ArrayList<>(); |
| |
| Thread t = new Thread( |
| new RequestExecutor("http://localhost:" + getPort() + "/", latch2, exceptions)); |
| t.start(); |
| |
| latch1.await(3000, TimeUnit.MILLISECONDS); |
| |
| Thread t1 = new Thread(new RequestExecutor( |
| "http://localhost:" + getPort() + "/?notify=true", latch2, exceptions)); |
| t1.start(); |
| |
| latch2.await(3000, TimeUnit.MILLISECONDS); |
| |
| if (exceptions.size() > 0) { |
| Assert.fail(); |
| } |
| } |
| |
| private static final class RequestExecutor implements Runnable { |
| private final String url; |
| private final CountDownLatch latch; |
| private final List<Throwable> exceptions; |
| |
| public RequestExecutor(String url, CountDownLatch latch, List<Throwable> exceptions) { |
| this.url = url; |
| this.latch = latch; |
| this.exceptions = exceptions; |
| } |
| |
| @Override |
| public void run() { |
| try { |
| ByteChunk result = new ByteChunk(); |
| int rc = getUrl(url, result, null); |
| Assert.assertTrue(rc == HttpServletResponse.SC_OK); |
| Assert.assertTrue(result.toString().contains("OK")); |
| } catch (Throwable e) { |
| e.printStackTrace(); |
| exceptions.add(e); |
| } finally { |
| latch.countDown(); |
| } |
| } |
| |
| } |
| |
| @WebServlet(asyncSupported = true) |
| private static final class DelayedNBWriteServlet extends TesterServlet { |
| private static final long serialVersionUID = 1L; |
| private final Set<Emitter> emitters = new HashSet<>(); |
| private final transient CountDownLatch latch; |
| |
| public DelayedNBWriteServlet(CountDownLatch latch) { |
| this.latch = latch; |
| } |
| |
| @Override |
| protected void doGet(HttpServletRequest request, HttpServletResponse response) |
| throws ServletException, IOException { |
| boolean notify = Boolean.parseBoolean(request.getParameter("notify")); |
| AsyncContext ctx = request.startAsync(); |
| ctx.setTimeout(1000); |
| if (!notify) { |
| emitters.add(new Emitter(ctx)); |
| latch.countDown(); |
| } else { |
| for (Emitter e : emitters) { |
| e.emit(); |
| } |
| response.getOutputStream().println("OK"); |
| response.getOutputStream().flush(); |
| ctx.complete(); |
| } |
| } |
| |
| } |
| |
| private static final class Emitter implements Serializable { |
| |
| private static final long serialVersionUID = 1L; |
| |
| private final transient AsyncContext ctx; |
| |
| Emitter(AsyncContext ctx) { |
| this.ctx = ctx; |
| } |
| |
| void emit() throws IOException { |
| ctx.getResponse().getOutputStream().setWriteListener(new WriteListener() { |
| private boolean written = false; |
| |
| @Override |
| public void onWritePossible() throws IOException { |
| ServletOutputStream out = ctx.getResponse().getOutputStream(); |
| if (out.isReady() && !written) { |
| out.println("OK"); |
| written = true; |
| } |
| if (out.isReady() && written) { |
| out.flush(); |
| if (out.isReady()) { |
| ctx.complete(); |
| } |
| } |
| } |
| |
| @Override |
| public void onError(Throwable t) { |
| t.printStackTrace(); |
| } |
| |
| }); |
| } |
| } |
| |
| |
| /* |
| * https://bz.apache.org/bugzilla/show_bug.cgi?id=61932 |
| */ |
| @Test |
| public void testNonBlockingReadWithDispatch() throws Exception { |
| Tomcat tomcat = getTomcatInstance(); |
| |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| NBReadWithDispatchServlet servlet = new NBReadWithDispatchServlet(); |
| String servletName = NBReadWithDispatchServlet.class.getName(); |
| Tomcat.addServlet(ctx, servletName, servlet); |
| ctx.addServletMappingDecoded("/", servletName); |
| |
| tomcat.start(); |
| |
| Map<String, List<String>> resHeaders = new HashMap<>(); |
| int rc = postUrl(true, new DataWriter(500, 5), "http://localhost:" + |
| getPort() + "/", new ByteChunk(), resHeaders, null); |
| |
| Assert.assertEquals(HttpServletResponse.SC_OK, rc); |
| } |
| |
| |
| @WebServlet(asyncSupported = true) |
| private static final class NBReadWithDispatchServlet extends TesterServlet { |
| |
| private static final long serialVersionUID = 1L; |
| |
| @Override |
| protected void doPost(HttpServletRequest req, HttpServletResponse resp) |
| throws ServletException, IOException { |
| |
| CountDownLatch latch = new CountDownLatch(1); |
| |
| // Dispatch to "/error" will end up here |
| if (req.getDispatcherType().equals(DispatcherType.ASYNC)) { |
| // Return without writing anything. This will generate the |
| // expected 200 response. |
| return; |
| } |
| |
| AsyncContext asyncCtx = req.startAsync(); |
| ServletInputStream is = req.getInputStream(); |
| is.setReadListener(new ReadListener() { |
| |
| @Override |
| public void onDataAvailable() { |
| |
| try { |
| byte buffer[] = new byte[1 * 1024]; |
| while (is.isReady() && !is.isFinished()) { |
| is.read(buffer); |
| } |
| |
| } catch (IOException ex) { |
| ex.printStackTrace(); |
| } |
| } |
| |
| @Override |
| public void onAllDataRead() { |
| latch.countDown(); |
| } |
| |
| @Override |
| public void onError(Throwable t) { |
| } |
| }); |
| |
| new Thread(() -> { |
| try { |
| latch.await(); |
| } catch (InterruptedException e) { |
| e.printStackTrace(); |
| } |
| asyncCtx.dispatch("/error"); |
| }).start(); |
| |
| } |
| } |
| |
| |
| @Test |
| public void testCanceledPostChunked() throws Exception { |
| doTestCanceledPost(new String[] { |
| "POST / HTTP/1.1" + SimpleHttpClient.CRLF + |
| "Host: localhost:" + SimpleHttpClient.CRLF + |
| "Transfer-Encoding: Chunked" + SimpleHttpClient.CRLF + |
| SimpleHttpClient.CRLF + |
| "10" + SimpleHttpClient.CRLF + |
| "This is 16 bytes" + SimpleHttpClient.CRLF |
| }); |
| } |
| |
| |
| @Test |
| public void testCanceledPostNoChunking() throws Exception { |
| doTestCanceledPost(new String[] { |
| "POST / HTTP/1.1" + SimpleHttpClient.CRLF + |
| "Host: localhost:" + SimpleHttpClient.CRLF + |
| "Content-Length: 100" + SimpleHttpClient.CRLF + |
| SimpleHttpClient.CRLF + |
| "This is 16 bytes" |
| }); |
| } |
| |
| |
| /* |
| * Tests an error on an non-blocking read when the client closes the |
| * connection before fully writing the request body. |
| * |
| * Required sequence is: |
| * - enter Servlet's service() method |
| * - startAsync() |
| * - configure non-blocking read |
| * - read partial body |
| * - close client connection |
| * - error is triggered |
| * - exit Servlet's service() method |
| * |
| * This test makes extensive use of instance fields in the Servlet that |
| * would normally be considered very poor practice. It is only safe in this |
| * test as the Servlet only processes a single request. |
| */ |
| private void doTestCanceledPost(String[] request) throws Exception { |
| |
| CountDownLatch partialReadLatch = new CountDownLatch(1); |
| CountDownLatch completeLatch = new CountDownLatch(1); |
| |
| AtomicBoolean testFailed = new AtomicBoolean(true); |
| |
| // Setup Tomcat instance |
| Tomcat tomcat = getTomcatInstance(); |
| |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| PostServlet postServlet = new PostServlet(partialReadLatch, completeLatch, testFailed); |
| Wrapper wrapper = Tomcat.addServlet(ctx, "postServlet", postServlet); |
| wrapper.setAsyncSupported(true); |
| ctx.addServletMappingDecoded("/*", "postServlet"); |
| |
| tomcat.start(); |
| |
| ResponseOKClient client = new ResponseOKClient(); |
| client.setPort(getPort()); |
| client.setRequest(request); |
| client.connect(); |
| client.sendRequest(); |
| |
| // Wait server to read partial request body |
| partialReadLatch.await(); |
| |
| client.disconnect(); |
| |
| completeLatch.await(); |
| |
| Assert.assertFalse(testFailed.get()); |
| } |
| |
| |
| private static final class ResponseOKClient extends SimpleHttpClient { |
| |
| @Override |
| public boolean isResponseBodyOK() { |
| return true; |
| } |
| } |
| |
| |
| private static final class PostServlet extends HttpServlet { |
| |
| private static final long serialVersionUID = 1L; |
| |
| private final transient CountDownLatch partialReadLatch; |
| private final transient CountDownLatch completeLatch; |
| private final AtomicBoolean testFailed; |
| |
| public PostServlet(CountDownLatch doPostLatch, CountDownLatch completeLatch, AtomicBoolean testFailed) { |
| this.partialReadLatch = doPostLatch; |
| this.completeLatch = completeLatch; |
| this.testFailed = testFailed; |
| } |
| |
| @Override |
| protected void doPost(HttpServletRequest req, HttpServletResponse resp) |
| throws ServletException, IOException { |
| |
| AsyncContext ac = req.startAsync(); |
| ac.setTimeout(-1); |
| CanceledPostAsyncListener asyncListener = new CanceledPostAsyncListener(completeLatch); |
| ac.addListener(asyncListener); |
| |
| CanceledPostReadListener readListener = new CanceledPostReadListener(ac, partialReadLatch, testFailed); |
| req.getInputStream().setReadListener(readListener); |
| } |
| } |
| |
| |
| private static final class CanceledPostAsyncListener implements AsyncListener { |
| |
| private final transient CountDownLatch completeLatch; |
| |
| public CanceledPostAsyncListener(CountDownLatch completeLatch) { |
| this.completeLatch = completeLatch; |
| } |
| |
| @Override |
| public void onComplete(AsyncEvent event) throws IOException { |
| System.out.println("complete"); |
| completeLatch.countDown(); |
| } |
| |
| @Override |
| public void onTimeout(AsyncEvent event) throws IOException { |
| System.out.println("onTimeout"); |
| } |
| |
| @Override |
| public void onError(AsyncEvent event) throws IOException { |
| System.out.println("onError-async"); |
| } |
| |
| @Override |
| public void onStartAsync(AsyncEvent event) throws IOException { |
| System.out.println("onStartAsync"); |
| } |
| } |
| |
| private static final class CanceledPostReadListener implements ReadListener { |
| |
| private final AsyncContext ac; |
| private final CountDownLatch partialReadLatch; |
| private final AtomicBoolean testFailed; |
| private int totalRead = 0; |
| |
| public CanceledPostReadListener(AsyncContext ac, CountDownLatch partialReadLatch, AtomicBoolean testFailed) { |
| this.ac = ac; |
| this.partialReadLatch = partialReadLatch; |
| this.testFailed = testFailed; |
| } |
| |
| @Override |
| public void onDataAvailable() throws IOException { |
| ServletInputStream sis = ac.getRequest().getInputStream(); |
| boolean isReady; |
| |
| byte[] buffer = new byte[32]; |
| do { |
| if (partialReadLatch.getCount() == 0) { |
| System.out.println("debug"); |
| } |
| int bytesRead = sis.read(buffer); |
| |
| if (bytesRead == -1) { |
| return; |
| } |
| totalRead += bytesRead; |
| isReady = sis.isReady(); |
| System.out.println("Read [" + bytesRead + |
| "], buffer [" + new String(buffer, 0, bytesRead, StandardCharsets.UTF_8) + |
| "], total read [" + totalRead + |
| "], isReady [" + isReady + "]"); |
| } while (isReady); |
| if (totalRead == 16) { |
| partialReadLatch.countDown(); |
| } |
| } |
| |
| @Override |
| public void onAllDataRead() throws IOException { |
| ac.complete(); |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| throwable.printStackTrace(); |
| // This is the expected behaviour so clear the failed flag. |
| testFailed.set(false); |
| ac.complete(); |
| } |
| } |
| |
| |
| @Test |
| public void testNonBlockingWriteError02NoSwallow() throws Exception { |
| doTestNonBlockingWriteError02(false); |
| } |
| |
| |
| @Test |
| public void testNonBlockingWriteError02Swallow() throws Exception { |
| doTestNonBlockingWriteError02(true); |
| } |
| |
| |
| /* |
| * Tests client disconnect in the following scenario: |
| * - async with non-blocking IO |
| * - response has been committed |
| * - no data in buffers |
| * - client disconnects |
| * - server attempts a write |
| */ |
| private void doTestNonBlockingWriteError02(boolean swallowIoException) throws Exception { |
| CountDownLatch responseCommitLatch = new CountDownLatch(1); |
| CountDownLatch clientCloseLatch = new CountDownLatch(1); |
| CountDownLatch asyncCompleteLatch = new CountDownLatch(1); |
| |
| // Setup Tomcat instance |
| Tomcat tomcat = getTomcatInstance(); |
| |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| |
| NBWriteServlet02 writeServlet = |
| new NBWriteServlet02(responseCommitLatch, clientCloseLatch, asyncCompleteLatch, swallowIoException); |
| Wrapper wrapper = Tomcat.addServlet(ctx, "writeServlet", writeServlet); |
| wrapper.setAsyncSupported(true); |
| ctx.addServletMappingDecoded("/*", "writeServlet"); |
| |
| tomcat.start(); |
| |
| ResponseOKClient client = new ResponseOKClient(); |
| client.setPort(getPort()); |
| client.setRequest(new String[] { |
| "GET / HTTP/1.1" + SimpleHttpClient.CRLF + |
| "Host: localhost:" + SimpleHttpClient.CRLF + |
| SimpleHttpClient.CRLF |
| }); |
| client.connect(); |
| client.sendRequest(); |
| |
| responseCommitLatch.await(); |
| |
| client.disconnect(); |
| clientCloseLatch.countDown(); |
| |
| Assert.assertTrue("Failed to complete async processing", asyncCompleteLatch.await(60, TimeUnit.SECONDS)); |
| } |
| |
| |
| private static class NBWriteServlet02 extends HttpServlet { |
| |
| private static final long serialVersionUID = 1L; |
| |
| private final transient CountDownLatch responseCommitLatch; |
| private final transient CountDownLatch clientCloseLatch; |
| private final transient CountDownLatch asyncCompleteLatch; |
| private final boolean swallowIoException; |
| |
| public NBWriteServlet02(CountDownLatch responseCommitLatch, CountDownLatch clientCloseLatch, |
| CountDownLatch asyncCompleteLatch, boolean swallowIoException) { |
| this.responseCommitLatch = responseCommitLatch; |
| this.clientCloseLatch = clientCloseLatch; |
| this.asyncCompleteLatch = asyncCompleteLatch; |
| this.swallowIoException = swallowIoException; |
| } |
| |
| @Override |
| protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { |
| resp.setContentType("text/plain"); |
| resp.setCharacterEncoding("UTF-8"); |
| |
| AsyncContext ac = req.startAsync(); |
| ac.addListener(new TestAsyncListener02(asyncCompleteLatch)); |
| ac.setTimeout(5000); |
| |
| WriteListener writeListener = |
| new TestWriteListener02(ac, responseCommitLatch, clientCloseLatch, swallowIoException); |
| resp.getOutputStream().setWriteListener(writeListener); |
| } |
| } |
| |
| |
| private static class TestAsyncListener02 implements AsyncListener { |
| |
| private final CountDownLatch asyncCompleteLatch; |
| |
| public TestAsyncListener02(CountDownLatch asyncCompleteLatch) { |
| this.asyncCompleteLatch = asyncCompleteLatch; |
| } |
| |
| @Override |
| public void onComplete(AsyncEvent event) throws IOException { |
| asyncCompleteLatch.countDown(); |
| } |
| |
| @Override |
| public void onTimeout(AsyncEvent event) throws IOException { |
| // NO-OP |
| } |
| |
| @Override |
| public void onError(AsyncEvent event) throws IOException { |
| // NO-OP |
| } |
| |
| @Override |
| public void onStartAsync(AsyncEvent event) throws IOException { |
| // NO-OP |
| } |
| |
| } |
| |
| private static class TestWriteListener02 implements WriteListener { |
| |
| private final AsyncContext ac; |
| private final CountDownLatch responseCommitLatch; |
| private final CountDownLatch clientCloseLatch; |
| private final boolean swallowIoException; |
| private volatile AtomicInteger stage = new AtomicInteger(0); |
| |
| public TestWriteListener02(AsyncContext ac, CountDownLatch responseCommitLatch, |
| CountDownLatch clientCloseLatch, boolean swallowIoException) { |
| this.ac = ac; |
| this.responseCommitLatch = responseCommitLatch; |
| this.clientCloseLatch = clientCloseLatch; |
| this.swallowIoException = swallowIoException; |
| } |
| |
| @Override |
| public void onWritePossible() throws IOException { |
| try { |
| ServletOutputStream sos = ac.getResponse().getOutputStream(); |
| do { |
| if (stage.get() == 0) { |
| // Commit the response |
| ac.getResponse().flushBuffer(); |
| responseCommitLatch.countDown(); |
| stage.incrementAndGet(); |
| } else if (stage.get() == 1) { |
| // Wait for the client to drop the connection |
| try { |
| clientCloseLatch.await(); |
| } catch (InterruptedException e) { |
| // Ignore |
| } |
| sos.print("TEST"); |
| stage.incrementAndGet(); |
| } else if (stage.get() == 2) { |
| sos.flush(); |
| } |
| } while (sos.isReady()); |
| } catch (IOException ioe) { |
| if (!swallowIoException) { |
| throw ioe; |
| } |
| } |
| } |
| |
| @Override |
| public void onError(Throwable throwable) { |
| // NO-OP |
| } |
| } |
| } |