| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.tomcat.websocket; |
| |
| import java.io.InputStream; |
| import java.io.Reader; |
| import java.lang.reflect.GenericArrayType; |
| import java.lang.reflect.Method; |
| import java.lang.reflect.ParameterizedType; |
| import java.lang.reflect.Type; |
| import java.lang.reflect.TypeVariable; |
| import java.nio.ByteBuffer; |
| import java.security.NoSuchAlgorithmException; |
| import java.security.SecureRandom; |
| import java.util.ArrayList; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Queue; |
| import java.util.Set; |
| import java.util.concurrent.ConcurrentLinkedQueue; |
| |
| import javax.naming.NamingException; |
| |
| import jakarta.websocket.CloseReason.CloseCode; |
| import jakarta.websocket.CloseReason.CloseCodes; |
| import jakarta.websocket.Decoder; |
| import jakarta.websocket.Decoder.Binary; |
| import jakarta.websocket.Decoder.BinaryStream; |
| import jakarta.websocket.Decoder.Text; |
| import jakarta.websocket.Decoder.TextStream; |
| import jakarta.websocket.DeploymentException; |
| import jakarta.websocket.Encoder; |
| import jakarta.websocket.EndpointConfig; |
| import jakarta.websocket.Extension; |
| import jakarta.websocket.MessageHandler; |
| import jakarta.websocket.PongMessage; |
| import jakarta.websocket.Session; |
| |
| import org.apache.tomcat.InstanceManager; |
| import org.apache.tomcat.util.res.StringManager; |
| import org.apache.tomcat.websocket.pojo.PojoMessageHandlerPartialBinary; |
| import org.apache.tomcat.websocket.pojo.PojoMessageHandlerWholeBinary; |
| import org.apache.tomcat.websocket.pojo.PojoMessageHandlerWholeText; |
| |
| /** |
| * Utility class for internal use only within the {@link org.apache.tomcat.websocket} package. |
| */ |
| public class Util { |
| |
// Message lookup obtained for Util.class; used throughout this class to build exception messages.
| private static final StringManager sm = StringManager.getManager(Util.class); |
// Pool of SecureRandom instances: generateMask() polls one, uses it and then adds it back.
| private static final Queue<SecureRandom> randoms = new ConcurrentLinkedQueue<>(); |
| |
/**
 * Not meant to be instantiated: every member of this class is static.
 */
private Util() {
    // Utility class - prevent instantiation
}
| |
| |
| static boolean isControl(byte opCode) { |
| return (opCode & 0x08) != 0; |
| } |
| |
| |
| static boolean isText(byte opCode) { |
| return opCode == Constants.OPCODE_TEXT; |
| } |
| |
| |
| static boolean isContinuation(byte opCode) { |
| return opCode == Constants.OPCODE_CONTINUATION; |
| } |
| |
| |
| static CloseCode getCloseCode(int code) { |
| if (code > 2999 && code < 5000) { |
| return CloseCodes.getCloseCode(code); |
| } |
| switch (code) { |
| case 1000: |
| return CloseCodes.NORMAL_CLOSURE; |
| case 1001: |
| return CloseCodes.GOING_AWAY; |
| case 1002: |
| return CloseCodes.PROTOCOL_ERROR; |
| case 1003: |
| return CloseCodes.CANNOT_ACCEPT; |
| case 1004: |
| // Should not be used in a close frame |
| // return CloseCodes.RESERVED; |
| return CloseCodes.PROTOCOL_ERROR; |
| case 1005: |
| // Should not be used in a close frame |
| // return CloseCodes.NO_STATUS_CODE; |
| return CloseCodes.PROTOCOL_ERROR; |
| case 1006: |
| // Should not be used in a close frame |
| // return CloseCodes.CLOSED_ABNORMALLY; |
| return CloseCodes.PROTOCOL_ERROR; |
| case 1007: |
| return CloseCodes.NOT_CONSISTENT; |
| case 1008: |
| return CloseCodes.VIOLATED_POLICY; |
| case 1009: |
| return CloseCodes.TOO_BIG; |
| case 1010: |
| return CloseCodes.NO_EXTENSION; |
| case 1011: |
| return CloseCodes.UNEXPECTED_CONDITION; |
| case 1012: |
| // Not in RFC6455 |
| // return CloseCodes.SERVICE_RESTART; |
| return CloseCodes.PROTOCOL_ERROR; |
| case 1013: |
| // Not in RFC6455 |
| // return CloseCodes.TRY_AGAIN_LATER; |
| return CloseCodes.PROTOCOL_ERROR; |
| case 1015: |
| // Should not be used in a close frame |
| // return CloseCodes.TLS_HANDSHAKE_FAILURE; |
| return CloseCodes.PROTOCOL_ERROR; |
| default: |
| return CloseCodes.PROTOCOL_ERROR; |
| } |
| } |
| |
| |
| static byte[] generateMask() { |
| // SecureRandom is not thread-safe so need to make sure only one thread |
| // uses it at a time. In theory, the pool could grow to the same size |
| // as the number of request processing threads. In reality it will be |
| // a lot smaller. |
| |
| // Get a SecureRandom from the pool |
| SecureRandom sr = randoms.poll(); |
| |
| // If one isn't available, generate a new one |
| if (sr == null) { |
| try { |
| sr = SecureRandom.getInstance("SHA1PRNG"); |
| } catch (NoSuchAlgorithmException e) { |
| // Fall back to platform default |
| sr = new SecureRandom(); |
| } |
| } |
| |
| // Generate the mask |
| byte[] result = new byte[4]; |
| sr.nextBytes(result); |
| |
| // Put the SecureRandom back in the poll |
| randoms.add(sr); |
| |
| return result; |
| } |
| |
| |
| static Class<?> getMessageType(MessageHandler listener) { |
| return getGenericType(MessageHandler.class, listener.getClass()).getClazz(); |
| } |
| |
| |
| private static Class<?> getDecoderType(Class<? extends Decoder> decoder) { |
| return getGenericType(Decoder.class, decoder).getClazz(); |
| } |
| |
| |
| static Class<?> getEncoderType(Class<? extends Encoder> encoder) { |
| return getGenericType(Encoder.class, encoder).getClazz(); |
| } |
| |
| |
| private static <T> TypeResult getGenericType(Class<T> type, Class<? extends T> clazz) { |
| |
| // Look to see if this class implements the interface of interest |
| |
| // Get all the interfaces |
| Type[] interfaces = clazz.getGenericInterfaces(); |
| for (Type iface : interfaces) { |
| // Only need to check interfaces that use generics |
| if (iface instanceof ParameterizedType) { |
| ParameterizedType pi = (ParameterizedType) iface; |
| // Look for the interface of interest |
| if (pi.getRawType() instanceof Class) { |
| if (type.isAssignableFrom((Class<?>) pi.getRawType())) { |
| return getTypeParameter(clazz, pi.getActualTypeArguments()[0]); |
| } |
| } |
| } |
| } |
| |
| // Interface not found on this class. Look at the superclass. |
| @SuppressWarnings("unchecked") |
| Class<? extends T> superClazz = (Class<? extends T>) clazz.getSuperclass(); |
| if (superClazz == null) { |
| // Finished looking up the class hierarchy without finding anything |
| return null; |
| } |
| |
| TypeResult superClassTypeResult = getGenericType(type, superClazz); |
| int dimension = superClassTypeResult.getDimension(); |
| if (superClassTypeResult.getIndex() == -1 && dimension == 0) { |
| // Superclass implements interface and defines explicit type for |
| // the interface of interest |
| return superClassTypeResult; |
| } |
| |
| if (superClassTypeResult.getIndex() > -1) { |
| // Superclass implements interface and defines unknown type for |
| // the interface of interest |
| // Map that unknown type to the generic types defined in this class |
| ParameterizedType superClassType = (ParameterizedType) clazz.getGenericSuperclass(); |
| TypeResult result = getTypeParameter(clazz, |
| superClassType.getActualTypeArguments()[superClassTypeResult.getIndex()]); |
| result.incrementDimension(superClassTypeResult.getDimension()); |
| if (result.getClazz() != null && result.getDimension() > 0) { |
| superClassTypeResult = result; |
| } else { |
| return result; |
| } |
| } |
| |
| if (superClassTypeResult.getDimension() > 0) { |
| StringBuilder className = new StringBuilder(); |
| for (int i = 0; i < dimension; i++) { |
| className.append('['); |
| } |
| className.append('L'); |
| className.append(superClassTypeResult.getClazz().getCanonicalName()); |
| className.append(';'); |
| |
| Class<?> arrayClazz; |
| try { |
| arrayClazz = Class.forName(className.toString()); |
| } catch (ClassNotFoundException e) { |
| throw new IllegalArgumentException(e); |
| } |
| |
| return new TypeResult(arrayClazz, -1, 0); |
| } |
| |
| // Error will be logged further up the call stack |
| return null; |
| } |
| |
| |
| /* |
| * For a generic parameter, return either the Class used or if the type is unknown, the index for the type in |
| * definition of the class |
| */ |
| private static TypeResult getTypeParameter(Class<?> clazz, Type argType) { |
| if (argType instanceof Class<?>) { |
| return new TypeResult((Class<?>) argType, -1, 0); |
| } else if (argType instanceof ParameterizedType) { |
| return new TypeResult((Class<?>) ((ParameterizedType) argType).getRawType(), -1, 0); |
| } else if (argType instanceof GenericArrayType) { |
| Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType(); |
| TypeResult result = getTypeParameter(clazz, arrayElementType); |
| result.incrementDimension(1); |
| return result; |
| } else { |
| TypeVariable<?>[] tvs = clazz.getTypeParameters(); |
| for (int i = 0; i < tvs.length; i++) { |
| if (tvs[i].equals(argType)) { |
| return new TypeResult(null, i, 0); |
| } |
| } |
| return null; |
| } |
| } |
| |
| |
| public static boolean isPrimitive(Class<?> clazz) { |
| if (clazz.isPrimitive()) { |
| return true; |
| } else if (clazz.equals(Boolean.class) || clazz.equals(Byte.class) || clazz.equals(Character.class) || |
| clazz.equals(Double.class) || clazz.equals(Float.class) || clazz.equals(Integer.class) || |
| clazz.equals(Long.class) || clazz.equals(Short.class)) { |
| return true; |
| } |
| return false; |
| } |
| |
| |
| public static Object coerceToType(Class<?> type, String value) { |
| if (type.equals(String.class)) { |
| return value; |
| } else if (type.equals(boolean.class) || type.equals(Boolean.class)) { |
| return Boolean.valueOf(value); |
| } else if (type.equals(byte.class) || type.equals(Byte.class)) { |
| return Byte.valueOf(value); |
| } else if (type.equals(char.class) || type.equals(Character.class)) { |
| return Character.valueOf(value.charAt(0)); |
| } else if (type.equals(double.class) || type.equals(Double.class)) { |
| return Double.valueOf(value); |
| } else if (type.equals(float.class) || type.equals(Float.class)) { |
| return Float.valueOf(value); |
| } else if (type.equals(int.class) || type.equals(Integer.class)) { |
| return Integer.valueOf(value); |
| } else if (type.equals(long.class) || type.equals(Long.class)) { |
| return Long.valueOf(value); |
| } else if (type.equals(short.class) || type.equals(Short.class)) { |
| return Short.valueOf(value); |
| } else { |
| throw new IllegalArgumentException(sm.getString("util.invalidType", value, type.getName())); |
| } |
| } |
| |
| |
| /** |
| * Build the list of decoder entries from a set of decoder implementations. |
| * |
| * @param decoderClazzes Decoder implementation classes |
| * @param instanceManager Instance manager to use to create Decoder instances |
| * |
| * @return List of mappings from target type to associated decoder |
| * |
| * @throws DeploymentException If a provided decoder class is not valid |
| */ |
| public static List<DecoderEntry> getDecoders(List<Class<? extends Decoder>> decoderClazzes, |
| InstanceManager instanceManager) throws DeploymentException { |
| |
| List<DecoderEntry> result = new ArrayList<>(); |
| if (decoderClazzes != null) { |
| for (Class<? extends Decoder> decoderClazz : decoderClazzes) { |
| // Need to instantiate decoder to ensure it is valid and that |
| // deployment can be failed if it is not |
| Decoder instance; |
| try { |
| if (instanceManager == null) { |
| instance = decoderClazz.getConstructor().newInstance(); |
| } else { |
| instance = (Decoder) instanceManager.newInstance(decoderClazz); |
| // Don't need this instance, so destroy it |
| instanceManager.destroyInstance(instance); |
| } |
| } catch (ReflectiveOperationException | IllegalArgumentException | SecurityException |
| | NamingException e) { |
| throw new DeploymentException( |
| sm.getString("pojoMethodMapping.invalidDecoder", decoderClazz.getName()), e); |
| } |
| DecoderEntry entry = new DecoderEntry(getDecoderType(decoderClazz), decoderClazz); |
| result.add(entry); |
| } |
| } |
| |
| return result; |
| } |
| |
| |
| static Set<MessageHandlerResult> getMessageHandlers(Class<?> target, MessageHandler listener, |
| EndpointConfig endpointConfig, Session session) { |
| |
| // Will never be more than 2 types |
| Set<MessageHandlerResult> results = new HashSet<>(2); |
| |
| // Simple cases - handlers already accepts one of the types expected by |
| // the frame handling code |
| if (String.class.isAssignableFrom(target)) { |
| MessageHandlerResult result = new MessageHandlerResult(listener, MessageHandlerResultType.TEXT); |
| results.add(result); |
| } else if (ByteBuffer.class.isAssignableFrom(target)) { |
| MessageHandlerResult result = new MessageHandlerResult(listener, MessageHandlerResultType.BINARY); |
| results.add(result); |
| } else if (PongMessage.class.isAssignableFrom(target)) { |
| MessageHandlerResult result = new MessageHandlerResult(listener, MessageHandlerResultType.PONG); |
| results.add(result); |
| // Handler needs wrapping and optional decoder to convert it to one of |
| // the types expected by the frame handling code |
| } else if (byte[].class.isAssignableFrom(target)) { |
| boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass()); |
| MessageHandlerResult result = new MessageHandlerResult(whole |
| ? new PojoMessageHandlerWholeBinary(listener, getOnMessageMethod(listener), session, endpointConfig, |
| matchDecoders(target, endpointConfig, true, ((WsSession) session).getInstanceManager()), |
| new Object[1], 0, true, -1, false, -1) |
| : new PojoMessageHandlerPartialBinary(listener, getOnMessagePartialMethod(listener), session, |
| new Object[2], 0, true, 1, -1, -1), |
| MessageHandlerResultType.BINARY); |
| results.add(result); |
| } else if (InputStream.class.isAssignableFrom(target)) { |
| MessageHandlerResult result = new MessageHandlerResult( |
| new PojoMessageHandlerWholeBinary(listener, getOnMessageMethod(listener), session, endpointConfig, |
| matchDecoders(target, endpointConfig, true, ((WsSession) session).getInstanceManager()), |
| new Object[1], 0, true, -1, true, -1), |
| MessageHandlerResultType.BINARY); |
| results.add(result); |
| } else if (Reader.class.isAssignableFrom(target)) { |
| MessageHandlerResult result = new MessageHandlerResult( |
| new PojoMessageHandlerWholeText(listener, getOnMessageMethod(listener), session, endpointConfig, |
| matchDecoders(target, endpointConfig, false, ((WsSession) session).getInstanceManager()), |
| new Object[1], 0, true, -1, -1), |
| MessageHandlerResultType.TEXT); |
| results.add(result); |
| } else { |
| // Handler needs wrapping and requires decoder to convert it to one |
| // of the types expected by the frame handling code |
| DecoderMatch decoderMatch = matchDecoders(target, endpointConfig, |
| ((WsSession) session).getInstanceManager()); |
| Method m = getOnMessageMethod(listener); |
| if (decoderMatch.getBinaryDecoders().size() > 0) { |
| MessageHandlerResult result = new MessageHandlerResult( |
| new PojoMessageHandlerWholeBinary(listener, m, session, endpointConfig, |
| decoderMatch.getBinaryDecoders(), new Object[1], 0, false, -1, false, -1), |
| MessageHandlerResultType.BINARY); |
| results.add(result); |
| } |
| if (decoderMatch.getTextDecoders().size() > 0) { |
| MessageHandlerResult result = new MessageHandlerResult( |
| new PojoMessageHandlerWholeText(listener, m, session, endpointConfig, |
| decoderMatch.getTextDecoders(), new Object[1], 0, false, -1, -1), |
| MessageHandlerResultType.TEXT); |
| results.add(result); |
| } |
| } |
| |
| if (results.size() == 0) { |
| throw new IllegalArgumentException(sm.getString("wsSession.unknownHandler", listener, target)); |
| } |
| |
| return results; |
| } |
| |
| private static List<Class<? extends Decoder>> matchDecoders(Class<?> target, EndpointConfig endpointConfig, |
| boolean binary, InstanceManager instanceManager) { |
| DecoderMatch decoderMatch = matchDecoders(target, endpointConfig, instanceManager); |
| if (binary) { |
| if (decoderMatch.getBinaryDecoders().size() > 0) { |
| return decoderMatch.getBinaryDecoders(); |
| } |
| } else if (decoderMatch.getTextDecoders().size() > 0) { |
| return decoderMatch.getTextDecoders(); |
| } |
| return null; |
| } |
| |
| private static DecoderMatch matchDecoders(Class<?> target, EndpointConfig endpointConfig, |
| InstanceManager instanceManager) { |
| DecoderMatch decoderMatch; |
| try { |
| List<Class<? extends Decoder>> decoders = endpointConfig.getDecoders(); |
| List<DecoderEntry> decoderEntries = getDecoders(decoders, instanceManager); |
| decoderMatch = new DecoderMatch(target, decoderEntries); |
| } catch (DeploymentException e) { |
| throw new IllegalArgumentException(e); |
| } |
| return decoderMatch; |
| } |
| |
| public static void parseExtensionHeader(List<Extension> extensions, String header) { |
| // The relevant ABNF for the Sec-WebSocket-Extensions is as follows: |
| // extension-list = 1#extension |
| // extension = extension-token *( ";" extension-param ) |
| // extension-token = registered-token |
| // registered-token = token |
| // extension-param = token [ "=" (token | quoted-string) ] |
| // ; When using the quoted-string syntax variant, the value |
| // ; after quoted-string unescaping MUST conform to the |
| // ; 'token' ABNF. |
| // |
| // The limiting of parameter values to tokens or "quoted tokens" makes |
| // the parsing of the header significantly simpler and allows a number |
| // of short-cuts to be taken. |
| |
| // Step one, split the header into individual extensions using ',' as a |
| // separator |
| String unparsedExtensions[] = header.split(","); |
| for (String unparsedExtension : unparsedExtensions) { |
| // Step two, split the extension into the registered name and |
| // parameter/value pairs using ';' as a separator |
| String unparsedParameters[] = unparsedExtension.split(";"); |
| WsExtension extension = new WsExtension(unparsedParameters[0].trim()); |
| |
| for (int i = 1; i < unparsedParameters.length; i++) { |
| int equalsPos = unparsedParameters[i].indexOf('='); |
| String name; |
| String value; |
| if (equalsPos == -1) { |
| name = unparsedParameters[i].trim(); |
| value = null; |
| } else { |
| name = unparsedParameters[i].substring(0, equalsPos).trim(); |
| value = unparsedParameters[i].substring(equalsPos + 1).trim(); |
| int len = value.length(); |
| if (len > 1) { |
| if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') { |
| value = value.substring(1, value.length() - 1); |
| } |
| } |
| } |
| // Make sure value doesn't contain any of the delimiters since |
| // that would indicate something went wrong |
| if (containsDelims(name) || containsDelims(value)) { |
| throw new IllegalArgumentException(sm.getString("util.notToken", name, value)); |
| } |
| if (value != null && (value.indexOf(',') > -1 || value.indexOf(';') > -1 || value.indexOf('\"') > -1 || |
| value.indexOf('=') > -1)) { |
| throw new IllegalArgumentException(sm.getString("util.invalidValue", value)); |
| } |
| extension.addParameter(new WsExtensionParameter(name, value)); |
| } |
| extensions.add(extension); |
| } |
| } |
| |
| |
| private static boolean containsDelims(String input) { |
| if (input == null || input.length() == 0) { |
| return false; |
| } |
| for (char c : input.toCharArray()) { |
| switch (c) { |
| case ',': |
| case ';': |
| case '\"': |
| case '=': |
| return true; |
| default: |
| // NO_OP |
| } |
| |
| } |
| return false; |
| } |
| |
| private static Method getOnMessageMethod(MessageHandler listener) { |
| try { |
| return listener.getClass().getMethod("onMessage", Object.class); |
| } catch (NoSuchMethodException | SecurityException e) { |
| throw new IllegalArgumentException(sm.getString("util.invalidMessageHandler"), e); |
| } |
| } |
| |
| private static Method getOnMessagePartialMethod(MessageHandler listener) { |
| try { |
| return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE); |
| } catch (NoSuchMethodException | SecurityException e) { |
| throw new IllegalArgumentException(sm.getString("util.invalidMessageHandler"), e); |
| } |
| } |
| |
| |
| public static class DecoderMatch { |
| |
| private final List<Class<? extends Decoder>> textDecoders = new ArrayList<>(); |
| private final List<Class<? extends Decoder>> binaryDecoders = new ArrayList<>(); |
| private final Class<?> target; |
| |
| public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) { |
| this.target = target; |
| for (DecoderEntry decoderEntry : decoderEntries) { |
| if (decoderEntry.getClazz().isAssignableFrom(target)) { |
| if (Binary.class.isAssignableFrom(decoderEntry.getDecoderClazz())) { |
| binaryDecoders.add(decoderEntry.getDecoderClazz()); |
| // willDecode() method means this decoder may or may not |
| // decode a message so need to carry on checking for |
| // other matches |
| } else if (BinaryStream.class.isAssignableFrom(decoderEntry.getDecoderClazz())) { |
| binaryDecoders.add(decoderEntry.getDecoderClazz()); |
| // Stream decoders have to process the message so no |
| // more decoders can be matched |
| break; |
| } else if (Text.class.isAssignableFrom(decoderEntry.getDecoderClazz())) { |
| textDecoders.add(decoderEntry.getDecoderClazz()); |
| // willDecode() method means this decoder may or may not |
| // decode a message so need to carry on checking for |
| // other matches |
| } else if (TextStream.class.isAssignableFrom(decoderEntry.getDecoderClazz())) { |
| textDecoders.add(decoderEntry.getDecoderClazz()); |
| // Stream decoders have to process the message so no |
| // more decoders can be matched |
| break; |
| } else { |
| throw new IllegalArgumentException(sm.getString("util.unknownDecoderType")); |
| } |
| } |
| } |
| } |
| |
| |
| public List<Class<? extends Decoder>> getTextDecoders() { |
| return textDecoders; |
| } |
| |
| |
| public List<Class<? extends Decoder>> getBinaryDecoders() { |
| return binaryDecoders; |
| } |
| |
| |
| public Class<?> getTarget() { |
| return target; |
| } |
| |
| |
| public boolean hasMatches() { |
| return (textDecoders.size() > 0) || (binaryDecoders.size() > 0); |
| } |
| } |
| |
| |
| private static class TypeResult { |
| private final Class<?> clazz; |
| private final int index; |
| private int dimension; |
| |
| TypeResult(Class<?> clazz, int index, int dimension) { |
| this.clazz = clazz; |
| this.index = index; |
| this.dimension = dimension; |
| } |
| |
| public Class<?> getClazz() { |
| return clazz; |
| } |
| |
| public int getIndex() { |
| return index; |
| } |
| |
| public int getDimension() { |
| return dimension; |
| } |
| |
| public void incrementDimension(int inc) { |
| dimension += inc; |
| } |
| } |
| } |