| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.tomcat.websocket.server; |
| |
| import java.io.IOException; |
| import java.util.HashSet; |
| import java.util.Set; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.TimeUnit; |
| |
| import javax.servlet.ServletContextEvent; |
| import javax.websocket.CloseReason; |
| import javax.websocket.CloseReason.CloseCode; |
| import javax.websocket.CloseReason.CloseCodes; |
| import javax.websocket.DeploymentException; |
| import javax.websocket.OnClose; |
| import javax.websocket.OnError; |
| import javax.websocket.OnMessage; |
| import javax.websocket.OnOpen; |
| import javax.websocket.Session; |
| import javax.websocket.server.ServerContainer; |
| import javax.websocket.server.ServerEndpointConfig; |
| |
| import org.junit.Assert; |
| import org.junit.Before; |
| import org.junit.Test; |
| |
| import org.apache.catalina.Context; |
| import org.apache.catalina.LifecycleException; |
| import org.apache.catalina.servlets.DefaultServlet; |
| import org.apache.catalina.startup.Tomcat; |
| import org.apache.juli.logging.Log; |
| import org.apache.juli.logging.LogFactory; |
| import org.apache.tomcat.websocket.WebSocketBaseTest; |
| |
| /** |
| * Test the behavior of closing websockets under various conditions. |
| */ |
| public class TestClose extends WebSocketBaseTest { |
| |
| private static Log log = LogFactory.getLog(TestClose.class); |
| |
| // These are static because it is simpler than trying to inject them into |
| // the endpoint |
| private static volatile Events events; |
| |
| |
| public static class Events { |
| // Used to block in the @OnMessage |
| public final CountDownLatch onMessageWait = new CountDownLatch(1); |
| |
| // Used to check which methods of a server endpoint were called |
| public final CountDownLatch onErrorCalled = new CountDownLatch(1); |
| public final CountDownLatch onMessageCalled = new CountDownLatch(1); |
| public final CountDownLatch onCloseCalled = new CountDownLatch(1); |
| |
| // Parameter of an @OnClose call |
| public volatile CloseReason closeReason = null; |
| // Parameter of an @OnError call |
| public volatile Throwable onErrorThrowable = null; |
| |
| //This is set to true for tests where the @OnMessage should send a message |
| public volatile boolean onMessageSends = false; |
| } |
| |
| |
| private static void awaitLatch(CountDownLatch latch, String failMessage) { |
| try { |
| if (!latch.await(30000, TimeUnit.MILLISECONDS)) { |
| Assert.fail(failMessage); |
| } |
| } catch (InterruptedException e) { |
| // Won't happen |
| throw new RuntimeException(e); |
| } |
| } |
| |
| |
| public static void awaitOnClose(CloseCode... codes) { |
| Set<CloseCode> set = new HashSet<>(); |
| for (CloseCode code : codes) { |
| set.add(code); |
| } |
| awaitOnClose(set); |
| } |
| |
| |
| public static void awaitOnClose(Set<CloseCode> codes) { |
| awaitLatch(events.onCloseCalled, "onClose not called"); |
| CloseCode received = events.closeReason.getCloseCode(); |
| Assert.assertTrue("Rx: " + received, codes.contains(received)); |
| } |
| |
| |
| public static void awaitOnError(Class<? extends Throwable> exceptionClazz) { |
| awaitLatch(events.onErrorCalled, "onError not called"); |
| Assert.assertTrue(events.onErrorThrowable.getClass().getName(), |
| exceptionClazz.isAssignableFrom(events.onErrorThrowable.getClass())); |
| } |
| |
| |
| @Override |
| @Before |
| public void setUp() throws Exception { |
| super.setUp(); |
| events = new Events(); |
| } |
| |
| |
| @Test |
| public void testTcpClose() throws Exception { |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.closeSocket(); |
| |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY); |
| } |
| |
| |
| @Test |
| public void testTcpReset() throws Exception { |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.forceCloseSocket(); |
| |
| // TODO: I'm not entirely sure when onError should be called |
| awaitOnError(IOException.class); |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY); |
| } |
| |
| |
| @Test |
| public void testWsCloseThenTcpClose() throws Exception { |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.sendCloseFrame(CloseCodes.GOING_AWAY); |
| client.closeSocket(); |
| // The close code seen by the server endpoint depends on how quickly the |
| // server processes the socket close. If it processes it before it |
| // finishes processing the GOING_AWAY (NIO2 can do this) then the server |
| // endpoint will see CLOSED_ABNORMALLY rather than GOING_AWAY. |
| awaitOnClose(CloseCodes.GOING_AWAY, CloseCodes.CLOSED_ABNORMALLY); |
| } |
| |
| |
| @Test |
| public void testWsCloseThenTcpReset() throws Exception { |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.sendCloseFrame(CloseCodes.GOING_AWAY); |
| client.forceCloseSocket(); |
| |
| // WebSocket 1.1, section 2.1.5 requires this to be CLOSED_ABNORMALLY if |
| // the container initiates the close and the close code from the client |
| // if the client initiates it. When the client resets the TCP connection |
| // after sending the close, different operating systems react different |
| // ways. Some present the close message then drop the connection, some |
| // just drop the connection. Therefore, this test has to handle both |
| // close codes. |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY, CloseCodes.GOING_AWAY); |
| } |
| |
| |
| @Test |
| public void testTcpCloseInOnMessage() throws Exception { |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.sendMessage("Test"); |
| awaitLatch(events.onMessageCalled, "onMessage not called"); |
| |
| client.closeSocket(); |
| events.onMessageWait.countDown(); |
| |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY); |
| } |
| |
| |
| @Test |
| public void testTcpResetInOnMessage() throws Exception { |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.sendMessage("Test"); |
| awaitLatch(events.onMessageCalled, "onMessage not called"); |
| |
| client.forceCloseSocket(); |
| events.onMessageWait.countDown(); |
| |
| awaitOnError(IOException.class); |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY); |
| } |
| |
| |
| @Test |
| public void testTcpCloseWhenOnMessageSends() throws Exception { |
| events.onMessageSends = true; |
| testTcpCloseInOnMessage(); |
| } |
| |
| |
| @Test |
| public void testTcpResetWhenOnMessageSends() throws Exception { |
| events.onMessageSends = true; |
| testTcpResetInOnMessage(); |
| } |
| |
| |
| @Test |
| public void testWsCloseThenTcpCloseWhenOnMessageSends() throws Exception { |
| events.onMessageSends = true; |
| |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.sendMessage("Test"); |
| awaitLatch(events.onMessageCalled, "onMessage not called"); |
| |
| client.sendCloseFrame(CloseCodes.NORMAL_CLOSURE); |
| client.closeSocket(); |
| events.onMessageWait.countDown(); |
| |
| // BIO will see close from client before it sees the TCP close |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY, CloseCodes.NORMAL_CLOSURE); |
| } |
| |
| |
| @Test |
| public void testWsCloseThenTcpResetWhenOnMessageSends() throws Exception { |
| events.onMessageSends = true; |
| |
| startServer(TestEndpointConfig.class); |
| |
| TesterWsCloseClient client = new TesterWsCloseClient("localhost", getPort()); |
| client.httpUpgrade(BaseEndpointConfig.PATH); |
| client.sendMessage("Test"); |
| awaitLatch(events.onMessageCalled, "onMessage not called"); |
| |
| client.sendCloseFrame(CloseCodes.NORMAL_CLOSURE); |
| client.forceCloseSocket(); |
| events.onMessageWait.countDown(); |
| |
| awaitOnClose(CloseCodes.CLOSED_ABNORMALLY); |
| } |
| |
| |
| public static class TestEndpoint { |
| |
| @OnOpen |
| public void onOpen(Session session) { |
| session.getUserProperties().put( |
| "org.apache.tomcat.websocket.BLOCKING_SEND_TIMEOUT", Long.valueOf(3000)); |
| log.info("Session opened"); |
| } |
| |
| @OnMessage |
| public void onMessage(Session session, String message) { |
| log.info("Message received: " + message); |
| events.onMessageCalled.countDown(); |
| awaitLatch(events.onMessageWait, "onMessageWait not triggered"); |
| |
| if (events.onMessageSends) { |
| try { |
| int count = 0; |
| // The latches above are meant to ensure the correct |
| // sequence of events but in some cases, particularly with |
| // APR, there is a short delay between the client closing / |
| // resetting the connection and the server recognising that |
| // fact. This loop tries to ensure that it lasts much longer |
| // than that delay so any close / reset from the client |
| // triggers an error here. |
| while (count < 10) { |
| count++; |
| session.getBasicRemote().sendText("Test reply"); |
| Thread.sleep(500); |
| } |
| } catch (IOException | InterruptedException e) { |
| // Expected to fail |
| } |
| } |
| } |
| |
| @OnError |
| public void onError(Throwable t) { |
| log.info("onError", t); |
| events.onErrorThrowable = t; |
| events.onErrorCalled.countDown(); |
| } |
| |
| @OnClose |
| public void onClose(CloseReason cr) { |
| log.info("onClose: " + cr); |
| events.closeReason = cr; |
| events.onCloseCalled.countDown(); |
| } |
| } |
| |
| |
| public static class TestEndpointConfig extends BaseEndpointConfig { |
| |
| @Override |
| protected Class<?> getEndpointClass() { |
| return TestEndpoint.class; |
| } |
| |
| } |
| |
| |
| private Tomcat startServer( |
| final Class<? extends WsContextListener> configClass) |
| throws LifecycleException { |
| |
| Tomcat tomcat = getTomcatInstance(); |
| // No file system docBase required |
| Context ctx = tomcat.addContext("", null); |
| ctx.addApplicationListener(configClass.getName()); |
| Tomcat.addServlet(ctx, "default", new DefaultServlet()); |
| ctx.addServletMappingDecoded("/", "default"); |
| |
| tomcat.start(); |
| return tomcat; |
| } |
| |
| |
| public abstract static class BaseEndpointConfig extends WsContextListener { |
| |
| public static final String PATH = "/test"; |
| |
| protected abstract Class<?> getEndpointClass(); |
| |
| @Override |
| public void contextInitialized(ServletContextEvent sce) { |
| super.contextInitialized(sce); |
| |
| ServerContainer sc = (ServerContainer) sce |
| .getServletContext() |
| .getAttribute( |
| Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); |
| |
| ServerEndpointConfig sec = ServerEndpointConfig.Builder.create( |
| getEndpointClass(), PATH).build(); |
| |
| try { |
| sc.addEndpoint(sec); |
| } catch (DeploymentException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| } |
| } |