| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.tomcat.websocket.server; |
| |
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
| |
| import javax.servlet.DispatcherType; |
| import javax.servlet.FilterRegistration; |
| import javax.servlet.ServletContext; |
| import javax.servlet.ServletException; |
| import javax.servlet.http.HttpServletRequest; |
| import javax.servlet.http.HttpServletResponse; |
| import javax.websocket.CloseReason; |
| import javax.websocket.CloseReason.CloseCodes; |
| import javax.websocket.DeploymentException; |
| import javax.websocket.Encoder; |
| import javax.websocket.Endpoint; |
| import javax.websocket.server.ServerContainer; |
| import javax.websocket.server.ServerEndpoint; |
| import javax.websocket.server.ServerEndpointConfig; |
| import javax.websocket.server.ServerEndpointConfig.Configurator; |
| |
| import org.apache.juli.logging.Log; |
| import org.apache.juli.logging.LogFactory; |
| import org.apache.tomcat.InstanceManager; |
| import org.apache.tomcat.util.res.StringManager; |
| import org.apache.tomcat.websocket.WsSession; |
| import org.apache.tomcat.websocket.WsWebSocketContainer; |
| import org.apache.tomcat.websocket.pojo.PojoEndpointServer; |
| import org.apache.tomcat.websocket.pojo.PojoMethodMapping; |
| |
| /** |
| * Provides a per class loader (i.e. per web application) instance of a |
| * ServerContainer. Web application wide defaults may be configured by setting |
| * the following servlet context initialisation parameters to the desired |
| * values. |
| * <ul> |
| * <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> |
| * <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> |
| * </ul> |
| */ |
| public class WsServerContainer extends WsWebSocketContainer |
| implements ServerContainer { |
| |
| private static final StringManager sm = |
| StringManager.getManager(Constants.PACKAGE_NAME); |
| private final Log log = LogFactory.getLog(WsServerContainer.class); // must not be static |
| |
| private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED = |
| new CloseReason(CloseCodes.VIOLATED_POLICY, |
| "This connection was established under an authenticated " + |
| "HTTP session that has ended."); |
| |
| private final WsWriteTimeout wsWriteTimeout = new WsWriteTimeout(); |
| |
| private final ServletContext servletContext; |
| private final Map<String,ServerEndpointConfig> configExactMatchMap = |
| new ConcurrentHashMap<>(); |
| private final ConcurrentMap<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap = |
| new ConcurrentHashMap<>(); |
| private volatile boolean enforceNoAddAfterHandshake = |
| org.apache.tomcat.websocket.Constants.STRICT_SPEC_COMPLIANCE; |
| private volatile boolean addAllowed = true; |
| private final ConcurrentMap<String,Set<WsSession>> authenticatedSessions = |
| new ConcurrentHashMap<>(); |
| private final ExecutorService executorService; |
| private final ThreadGroup threadGroup; |
| private volatile boolean endpointsRegistered = false; |
| |
| WsServerContainer(ServletContext servletContext) { |
| |
| this.servletContext = servletContext; |
| setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName())); |
| |
| // Configure servlet context wide defaults |
| String value = servletContext.getInitParameter( |
| Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); |
| if (value != null) { |
| setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value)); |
| } |
| |
| value = servletContext.getInitParameter( |
| Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); |
| if (value != null) { |
| setDefaultMaxTextMessageBufferSize(Integer.parseInt(value)); |
| } |
| |
| value = servletContext.getInitParameter( |
| Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM); |
| if (value != null) { |
| setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value)); |
| } |
| // Executor config |
| int executorCoreSize = 0; |
| long executorKeepAliveTimeSeconds = 60; |
| value = servletContext.getInitParameter( |
| Constants.EXECUTOR_CORE_SIZE_INIT_PARAM); |
| if (value != null) { |
| executorCoreSize = Integer.parseInt(value); |
| } |
| value = servletContext.getInitParameter( |
| Constants.EXECUTOR_KEEPALIVETIME_SECONDS_INIT_PARAM); |
| if (value != null) { |
| executorKeepAliveTimeSeconds = Long.parseLong(value); |
| } |
| |
| FilterRegistration.Dynamic fr = servletContext.addFilter( |
| "Tomcat WebSocket (JSR356) Filter", new WsFilter()); |
| fr.setAsyncSupported(true); |
| |
| EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST, |
| DispatcherType.FORWARD); |
| |
| fr.addMappingForUrlPatterns(types, true, "/*"); |
| |
| // Use a per web application executor for any threads that the WebSocket |
| // server code needs to create. Group all of the threads under a single |
| // ThreadGroup. |
| StringBuffer threadGroupName = new StringBuffer("WebSocketServer-"); |
| threadGroupName.append(servletContext.getVirtualServerName()); |
| threadGroupName.append('-'); |
| if ("".equals(servletContext.getContextPath())) { |
| threadGroupName.append("ROOT"); |
| } else { |
| threadGroupName.append(servletContext.getContextPath()); |
| } |
| threadGroup = new ThreadGroup(threadGroupName.toString()); |
| WsThreadFactory wsThreadFactory = new WsThreadFactory(threadGroup); |
| |
| executorService = new ThreadPoolExecutor(executorCoreSize, |
| Integer.MAX_VALUE, executorKeepAliveTimeSeconds, TimeUnit.SECONDS, |
| new SynchronousQueue<Runnable>(), wsThreadFactory); |
| } |
| |
| |
| /** |
| * Published the provided endpoint implementation at the specified path with |
| * the specified configuration. {@link #WsServerContainer(ServletContext)} |
| * must be called before calling this method. |
| * |
| * @param sec The configuration to use when creating endpoint instances |
| * @throws DeploymentException if the endpoint can not be published as |
| * requested |
| */ |
| @Override |
| public void addEndpoint(ServerEndpointConfig sec) |
| throws DeploymentException { |
| |
| if (enforceNoAddAfterHandshake && !addAllowed) { |
| throw new DeploymentException( |
| sm.getString("serverContainer.addNotAllowed")); |
| } |
| |
| if (servletContext == null) { |
| throw new DeploymentException( |
| sm.getString("serverContainer.servletContextMissing")); |
| } |
| String path = sec.getPath(); |
| |
| // Add method mapping to user properties |
| PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(), |
| sec.getDecoders(), path); |
| if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null |
| || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) { |
| sec.getUserProperties().put( |
| PojoEndpointServer.POJO_METHOD_MAPPING_KEY, |
| methodMapping); |
| } |
| |
| UriTemplate uriTemplate = new UriTemplate(path); |
| if (uriTemplate.hasParameters()) { |
| Integer key = Integer.valueOf(uriTemplate.getSegmentCount()); |
| SortedSet<TemplatePathMatch> templateMatches = |
| configTemplateMatchMap.get(key); |
| if (templateMatches == null) { |
| // Ensure that if concurrent threads execute this block they |
| // both end up using the same TreeSet instance |
| templateMatches = new TreeSet<>( |
| TemplatePathMatchComparator.getInstance()); |
| configTemplateMatchMap.putIfAbsent(key, templateMatches); |
| templateMatches = configTemplateMatchMap.get(key); |
| } |
| if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) { |
| // Duplicate uriTemplate; |
| throw new DeploymentException( |
| sm.getString("serverContainer.duplicatePaths", path, |
| sec.getEndpointClass(), |
| sec.getEndpointClass())); |
| } |
| } else { |
| // Exact match |
| ServerEndpointConfig old = configExactMatchMap.put(path, sec); |
| if (old != null) { |
| // Duplicate path mappings |
| throw new DeploymentException( |
| sm.getString("serverContainer.duplicatePaths", path, |
| old.getEndpointClass(), |
| sec.getEndpointClass())); |
| } |
| } |
| |
| endpointsRegistered = true; |
| } |
| |
| |
| /** |
| * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)} |
| * for publishing plain old java objects (POJOs) that have been annotated as |
| * WebSocket endpoints. |
| * |
| * @param pojo The annotated POJO |
| */ |
| @Override |
| public void addEndpoint(Class<?> pojo) throws DeploymentException { |
| |
| ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class); |
| if (annotation == null) { |
| throw new DeploymentException( |
| sm.getString("serverContainer.missingAnnotation", |
| pojo.getName())); |
| } |
| String path = annotation.value(); |
| |
| // Validate encoders |
| validateEncoders(annotation.encoders()); |
| |
| // ServerEndpointConfig |
| ServerEndpointConfig sec; |
| Class<? extends Configurator> configuratorClazz = |
| annotation.configurator(); |
| Configurator configurator = null; |
| if (!configuratorClazz.equals(Configurator.class)) { |
| try { |
| configurator = annotation.configurator().getConstructor().newInstance(); |
| } catch (ReflectiveOperationException e) { |
| throw new DeploymentException(sm.getString( |
| "serverContainer.configuratorFail", |
| annotation.configurator().getName(), |
| pojo.getClass().getName()), e); |
| } |
| } |
| sec = ServerEndpointConfig.Builder.create(pojo, path). |
| decoders(Arrays.asList(annotation.decoders())). |
| encoders(Arrays.asList(annotation.encoders())). |
| subprotocols(Arrays.asList(annotation.subprotocols())). |
| configurator(configurator). |
| build(); |
| |
| addEndpoint(sec); |
| } |
| |
| |
| @Override |
| public void destroy() { |
| shutdownExecutor(); |
| super.destroy(); |
| // If the executor hasn't fully shutdown it won't be possible to |
| // destroy this thread group as there will still be threads running. |
| // Mark the thread group as daemon one, so that it destroys itself |
| // when thread count reaches zero. |
| // Synchronization on threadGroup is needed, as there is a race between |
| // destroy() call from termination of the last thread in thread group |
| // marked as daemon versus the explicit destroy() call. |
| int threadCount = threadGroup.activeCount(); |
| boolean success = false; |
| try { |
| while (true) { |
| int oldThreadCount = threadCount; |
| synchronized (threadGroup) { |
| if (threadCount > 0) { |
| Thread.yield(); |
| threadCount = threadGroup.activeCount(); |
| } |
| if (threadCount > 0 && threadCount != oldThreadCount) { |
| // Value not stabilized. Retry. |
| continue; |
| } |
| if (threadCount > 0) { |
| threadGroup.setDaemon(true); |
| } else { |
| threadGroup.destroy(); |
| success = true; |
| } |
| break; |
| } |
| } |
| } catch (IllegalThreadStateException exception) { |
| // Fall-through |
| } |
| if (!success) { |
| log.warn(sm.getString("serverContainer.threadGroupNotDestroyed", |
| threadGroup.getName(), Integer.valueOf(threadCount))); |
| } |
| } |
| |
| |
| boolean areEndpointsRegistered() { |
| return endpointsRegistered; |
| } |
| |
| |
| /** |
| * Until the WebSocket specification provides such a mechanism, this Tomcat |
| * proprietary method is provided to enable applications to programmatically |
| * determine whether or not to upgrade an individual request to WebSocket. |
| * <p> |
| * Note: This method is not used by Tomcat but is used directly by |
| * third-party code and must not be removed. |
| * |
| * @param request The request object to be upgraded |
| * @param response The response object to be populated with the result of |
| * the upgrade |
| * @param sec The server endpoint to use to process the upgrade request |
| * @param pathParams The path parameters associated with the upgrade request |
| * |
| * @throws ServletException If a configuration error prevents the upgrade |
| * from taking place |
| * @throws IOException If an I/O error occurs during the upgrade process |
| */ |
| public void doUpgrade(HttpServletRequest request, |
| HttpServletResponse response, ServerEndpointConfig sec, |
| Map<String,String> pathParams) |
| throws ServletException, IOException { |
| UpgradeUtil.doUpgrade(this, request, response, sec, pathParams); |
| } |
| |
| |
| public WsMappingResult findMapping(String path) { |
| |
| // Prevent registering additional endpoints once the first attempt has |
| // been made to use one |
| if (addAllowed) { |
| addAllowed = false; |
| } |
| |
| // Check an exact match. Simple case as there are no templates. |
| ServerEndpointConfig sec = configExactMatchMap.get(path); |
| if (sec != null) { |
| return new WsMappingResult(sec, Collections.<String, String>emptyMap()); |
| } |
| |
| // No exact match. Need to look for template matches. |
| UriTemplate pathUriTemplate = null; |
| try { |
| pathUriTemplate = new UriTemplate(path); |
| } catch (DeploymentException e) { |
| // Path is not valid so can't be matched to a WebSocketEndpoint |
| return null; |
| } |
| |
| // Number of segments has to match |
| Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount()); |
| SortedSet<TemplatePathMatch> templateMatches = |
| configTemplateMatchMap.get(key); |
| |
| if (templateMatches == null) { |
| // No templates with an equal number of segments so there will be |
| // no matches |
| return null; |
| } |
| |
| // List is in alphabetical order of normalised templates. |
| // Correct match is the first one that matches. |
| Map<String,String> pathParams = null; |
| for (TemplatePathMatch templateMatch : templateMatches) { |
| pathParams = templateMatch.getUriTemplate().match(pathUriTemplate); |
| if (pathParams != null) { |
| sec = templateMatch.getConfig(); |
| break; |
| } |
| } |
| |
| if (sec == null) { |
| // No match |
| return null; |
| } |
| |
| return new WsMappingResult(sec, pathParams); |
| } |
| |
| |
| |
| public boolean isEnforceNoAddAfterHandshake() { |
| return enforceNoAddAfterHandshake; |
| } |
| |
| |
| public void setEnforceNoAddAfterHandshake( |
| boolean enforceNoAddAfterHandshake) { |
| this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake; |
| } |
| |
| |
| protected WsWriteTimeout getTimeout() { |
| return wsWriteTimeout; |
| } |
| |
| |
| /** |
| * {@inheritDoc} |
| * |
| * Overridden to make it visible to other classes in this package. |
| */ |
| @Override |
| protected void registerSession(Endpoint endpoint, WsSession wsSession) { |
| super.registerSession(endpoint, wsSession); |
| if (wsSession.isOpen() && |
| wsSession.getUserPrincipal() != null && |
| wsSession.getHttpSessionId() != null) { |
| registerAuthenticatedSession(wsSession, |
| wsSession.getHttpSessionId()); |
| } |
| } |
| |
| |
| /** |
| * {@inheritDoc} |
| * |
| * Overridden to make it visible to other classes in this package. |
| */ |
| @Override |
| protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { |
| if (wsSession.getUserPrincipal() != null && |
| wsSession.getHttpSessionId() != null) { |
| unregisterAuthenticatedSession(wsSession, |
| wsSession.getHttpSessionId()); |
| } |
| super.unregisterSession(endpoint, wsSession); |
| } |
| |
| |
| private void registerAuthenticatedSession(WsSession wsSession, |
| String httpSessionId) { |
| Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); |
| if (wsSessions == null) { |
| wsSessions = Collections.newSetFromMap( |
| new ConcurrentHashMap<WsSession,Boolean>()); |
| authenticatedSessions.putIfAbsent(httpSessionId, wsSessions); |
| wsSessions = authenticatedSessions.get(httpSessionId); |
| } |
| wsSessions.add(wsSession); |
| } |
| |
| |
| private void unregisterAuthenticatedSession(WsSession wsSession, |
| String httpSessionId) { |
| Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); |
| // wsSessions will be null if the HTTP session has ended |
| if (wsSessions != null) { |
| wsSessions.remove(wsSession); |
| } |
| } |
| |
| |
| public void closeAuthenticatedSession(String httpSessionId) { |
| Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId); |
| |
| if (wsSessions != null && !wsSessions.isEmpty()) { |
| for (WsSession wsSession : wsSessions) { |
| try { |
| wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED); |
| } catch (IOException e) { |
| // Any IOExceptions during close will have been caught and the |
| // onError method called. |
| } |
| } |
| } |
| } |
| |
| |
| ExecutorService getExecutorService() { |
| return executorService; |
| } |
| |
| |
| private void shutdownExecutor() { |
| if (executorService == null) { |
| return; |
| } |
| executorService.shutdown(); |
| try { |
| executorService.awaitTermination(10, TimeUnit.SECONDS); |
| } catch (InterruptedException e) { |
| // Ignore the interruption and carry on |
| } |
| } |
| |
| private static void validateEncoders(Class<? extends Encoder>[] encoders) |
| throws DeploymentException { |
| |
| for (Class<? extends Encoder> encoder : encoders) { |
| // Need to instantiate decoder to ensure it is valid and that |
| // deployment can be failed if it is not |
| @SuppressWarnings("unused") |
| Encoder instance; |
| try { |
| encoder.getConstructor().newInstance(); |
| } catch(ReflectiveOperationException e) { |
| throw new DeploymentException(sm.getString( |
| "serverContainer.encoderFail", encoder.getName()), e); |
| } |
| } |
| } |
| |
| |
| private static class TemplatePathMatch { |
| private final ServerEndpointConfig config; |
| private final UriTemplate uriTemplate; |
| |
| public TemplatePathMatch(ServerEndpointConfig config, |
| UriTemplate uriTemplate) { |
| this.config = config; |
| this.uriTemplate = uriTemplate; |
| } |
| |
| |
| public ServerEndpointConfig getConfig() { |
| return config; |
| } |
| |
| |
| public UriTemplate getUriTemplate() { |
| return uriTemplate; |
| } |
| } |
| |
| |
| /** |
| * This Comparator implementation is thread-safe so only create a single |
| * instance. |
| */ |
| private static class TemplatePathMatchComparator |
| implements Comparator<TemplatePathMatch> { |
| |
| private static final TemplatePathMatchComparator INSTANCE = |
| new TemplatePathMatchComparator(); |
| |
| public static TemplatePathMatchComparator getInstance() { |
| return INSTANCE; |
| } |
| |
| private TemplatePathMatchComparator() { |
| // Hide default constructor |
| } |
| |
| @Override |
| public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) { |
| return tpm1.getUriTemplate().getNormalizedPath().compareTo( |
| tpm2.getUriTemplate().getNormalizedPath()); |
| } |
| } |
| |
| |
| private static class WsThreadFactory implements ThreadFactory { |
| |
| private final ThreadGroup tg; |
| private final AtomicLong count = new AtomicLong(0); |
| |
| private WsThreadFactory(ThreadGroup tg) { |
| this.tg = tg; |
| } |
| |
| @Override |
| public Thread newThread(Runnable r) { |
| Thread t = new Thread(tg, r); |
| t.setName(tg.getName() + "-" + count.incrementAndGet()); |
| return t; |
| } |
| } |
| } |