Util.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.websocket;
import java.io.InputStream;
import java.io.Reader;
import java.lang.reflect.Array;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.naming.NamingException;
import jakarta.websocket.CloseReason.CloseCode;
import jakarta.websocket.CloseReason.CloseCodes;
import jakarta.websocket.Decoder;
import jakarta.websocket.Decoder.Binary;
import jakarta.websocket.Decoder.BinaryStream;
import jakarta.websocket.Decoder.Text;
import jakarta.websocket.Decoder.TextStream;
import jakarta.websocket.DeploymentException;
import jakarta.websocket.Encoder;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Extension;
import jakarta.websocket.MessageHandler;
import jakarta.websocket.PongMessage;
import jakarta.websocket.Session;
import org.apache.tomcat.InstanceManager;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.websocket.pojo.PojoMessageHandlerPartialBinary;
import org.apache.tomcat.websocket.pojo.PojoMessageHandlerWholeBinary;
import org.apache.tomcat.websocket.pojo.PojoMessageHandlerWholeText;
/**
* Utility class for internal use only within the {@link org.apache.tomcat.websocket} package.
*/
public class Util {
private static final StringManager sm = StringManager.getManager(Util.class);

// Pool of SecureRandom instances. SecureRandom is not thread-safe, so an instance is
// taken from this (thread-safe) queue, used by a single thread and then put back.
private static final Queue<SecureRandom> randoms = new ConcurrentLinkedQueue<SecureRandom>();

private Util() {
    // Static utility methods only - instances are never created
}
/*
 * Control frame opcodes are the ones with bit 3 (0x08) set.
 */
static boolean isControl(byte opCode) {
    final int controlBit = 0x08;
    return (opCode & controlBit) == controlBit;
}
/*
 * True if the opcode identifies a text frame.
 */
static boolean isText(byte opCode) {
    return Constants.OPCODE_TEXT == opCode;
}
/*
 * True if the opcode identifies a continuation frame.
 */
static boolean isContinuation(byte opCode) {
    return Constants.OPCODE_CONTINUATION == opCode;
}
/*
 * Converts a numeric WebSocket close code to a CloseCode.
 *
 * Codes in the range 3000-4999 (inclusive) are passed through via CloseCodes.getCloseCode(). Of the remaining
 * codes, only those mapped explicitly below are accepted. Everything else is reported as PROTOCOL_ERROR.
 */
static CloseCode getCloseCode(int code) {
    if (code >= 3000 && code <= 4999) {
        return CloseCodes.getCloseCode(code);
    }

    switch (code) {
        case 1000:
            return CloseCodes.NORMAL_CLOSURE;
        case 1001:
            return CloseCodes.GOING_AWAY;
        case 1002:
            return CloseCodes.PROTOCOL_ERROR;
        case 1003:
            return CloseCodes.CANNOT_ACCEPT;
        case 1007:
            return CloseCodes.NOT_CONSISTENT;
        case 1008:
            return CloseCodes.VIOLATED_POLICY;
        case 1009:
            return CloseCodes.TOO_BIG;
        case 1010:
            return CloseCodes.NO_EXTENSION;
        case 1011:
            return CloseCodes.UNEXPECTED_CONDITION;
        default:
            // Also covers codes that are deliberately NOT mapped to their named CloseCodes:
            // - 1004 (RESERVED), 1005 (NO_STATUS_CODE), 1006 (CLOSED_ABNORMALLY) and
            //   1015 (TLS_HANDSHAKE_FAILURE) should not be used in a close frame
            // - 1012 (SERVICE_RESTART) and 1013 (TRY_AGAIN_LATER) are not in RFC 6455
            return CloseCodes.PROTOCOL_ERROR;
    }
}
/*
 * Returns a new 4-byte mask filled with random bytes from a pooled SecureRandom.
 *
 * SecureRandom is not thread-safe, so each instance is taken out of the pool, used by exactly one thread and then
 * returned. In theory the pool could grow to the number of request processing threads; in practice it stays much
 * smaller.
 */
static byte[] generateMask() {
    SecureRandom random = randoms.poll();

    if (random == null) {
        // Pool is empty: create a new instance, preferring SHA1PRNG and falling back to the platform default
        try {
            random = SecureRandom.getInstance("SHA1PRNG");
        } catch (NoSuchAlgorithmException e) {
            random = new SecureRandom();
        }
    }

    byte[] mask = new byte[4];
    random.nextBytes(mask);

    // Return the instance to the pool so it can be reused
    randoms.add(random);

    return mask;
}
/*
 * Returns the message type the handler declares through its MessageHandler type argument.
 * NOTE(review): getGenericType() may return null (type argument not found), in which case this throws
 * NullPointerException.
 */
static Class<?> getMessageType(MessageHandler listener) {
    Class<? extends MessageHandler> listenerClass = listener.getClass();
    TypeResult typeResult = getGenericType(MessageHandler.class, listenerClass);
    return typeResult.getClazz();
}
/*
 * Returns the type the decoder produces, as declared through its Decoder type argument.
 * NOTE(review): throws NullPointerException if getGenericType() cannot resolve the type argument.
 */
private static Class<?> getDecoderType(Class<? extends Decoder> decoder) {
    TypeResult decodedType = getGenericType(Decoder.class, decoder);
    return decodedType.getClazz();
}
/*
 * Returns the type the encoder accepts, as declared through its Encoder type argument.
 * NOTE(review): throws NullPointerException if getGenericType() cannot resolve the type argument.
 */
static Class<?> getEncoderType(Class<? extends Encoder> encoder) {
    TypeResult encodedType = getGenericType(Encoder.class, encoder);
    return encodedType.getClazz();
}
/*
 * Determines the type argument that clazz (directly, or via one of its superclasses) supplies as the first type
 * argument of the generic interface type.
 *
 * Returns:
 * - a TypeResult with a non-null class and index -1 when the argument resolves to a concrete class. Array type
 *   arguments are reported as the corresponding array class.
 * - a TypeResult with a null class and an index >= 0 when the argument is one of clazz's own type variables. The
 *   index identifies that type variable and the dimension records any array wrapping, so a subclass can resolve it.
 * - null if no usable type argument could be found. The error is reported further up the call stack.
 */
private static <T> TypeResult getGenericType(Class<T> type, Class<? extends T> clazz) {
    // Look to see if this class directly implements the interface of interest.
    // Only interfaces that use generics need to be checked.
    for (Type iface : clazz.getGenericInterfaces()) {
        if (iface instanceof ParameterizedType) {
            ParameterizedType pi = (ParameterizedType) iface;
            if (pi.getRawType() instanceof Class && type.isAssignableFrom((Class<?>) pi.getRawType())) {
                TypeResult result = getTypeParameter(clazz, pi.getActualTypeArguments()[0]);
                if (result != null && result.getClazz() != null && result.getDimension() > 0) {
                    // Generic array type argument (e.g. List<String>[]) with a known component class.
                    // Report the array class, not the component class.
                    return toArrayTypeResult(result);
                }
                return result;
            }
        }
    }

    // Interface not found on this class. Look at the superclass.
    @SuppressWarnings("unchecked")
    Class<? extends T> superClazz = (Class<? extends T>) clazz.getSuperclass();
    if (superClazz == null) {
        // Finished looking up the class hierarchy without finding anything
        return null;
    }

    TypeResult superClassTypeResult = getGenericType(type, superClazz);
    if (superClassTypeResult == null) {
        // Nothing usable found further up the hierarchy either (for example, when the interface is only
        // implemented through a non-parameterized sub-interface). Error will be logged further up the call stack.
        return null;
    }

    if (superClassTypeResult.getIndex() == -1 && superClassTypeResult.getDimension() == 0) {
        // Superclass implements interface and defines explicit type for the interface of interest
        return superClassTypeResult;
    }

    if (superClassTypeResult.getIndex() > -1) {
        // Superclass implements interface and defines unknown type for the interface of interest.
        // Map that unknown type to the type arguments this class supplies to its superclass.
        Type genericSuperclass = clazz.getGenericSuperclass();
        if (!(genericSuperclass instanceof ParameterizedType)) {
            // Raw superclass - the type argument cannot be determined
            return null;
        }
        ParameterizedType superClassType = (ParameterizedType) genericSuperclass;
        TypeResult result = getTypeParameter(clazz,
                superClassType.getActualTypeArguments()[superClassTypeResult.getIndex()]);
        if (result == null) {
            return null;
        }
        result.incrementDimension(superClassTypeResult.getDimension());
        if (result.getClazz() == null || result.getDimension() <= 0) {
            // Either resolved to a plain class or still a type variable for a subclass to resolve
            return result;
        }
        superClassTypeResult = result;
    }

    if (superClassTypeResult.getDimension() > 0) {
        // Known component class wrapped in one or more array dimensions (the combined dimension is used)
        return toArrayTypeResult(superClassTypeResult);
    }

    // Error will be logged further up the call stack
    return null;
}

/*
 * Converts a TypeResult describing an array as (component class, dimension > 0) into a TypeResult for the array
 * class itself.
 *
 * java.lang.reflect.Array is used instead of Class.forName() so that nested component classes, components that are
 * themselves arrays, and classes not visible to the class loader that loaded this class are all handled.
 */
private static TypeResult toArrayTypeResult(TypeResult componentResult) {
    Class<?> arrayClazz =
            Array.newInstance(componentResult.getClazz(), new int[componentResult.getDimension()]).getClass();
    return new TypeResult(arrayClazz, -1, 0);
}
/*
 * For a generic parameter, return either the Class used or, if the type is unknown, the index of the type variable
 * in the definition of clazz. Each level of generic array wrapping increments the dimension of the result.
 *
 * Returns null if argType, or the component type of a generic array, is some other type (e.g. a type variable not
 * declared by clazz itself), since neither a class nor an index can be determined for it.
 */
private static TypeResult getTypeParameter(Class<?> clazz, Type argType) {
    if (argType instanceof Class<?>) {
        return new TypeResult((Class<?>) argType, -1, 0);
    } else if (argType instanceof ParameterizedType) {
        return new TypeResult((Class<?>) ((ParameterizedType) argType).getRawType(), -1, 0);
    } else if (argType instanceof GenericArrayType) {
        Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType();
        TypeResult result = getTypeParameter(clazz, arrayElementType);
        if (result == null) {
            // Component type could not be resolved - propagate instead of failing with a NullPointerException
            return null;
        }
        result.incrementDimension(1);
        return result;
    } else {
        TypeVariable<?>[] tvs = clazz.getTypeParameters();
        for (int i = 0; i < tvs.length; i++) {
            if (tvs[i].equals(argType)) {
                return new TypeResult(null, i, 0);
            }
        }
        return null;
    }
}
/**
 * Determines whether a class is a primitive type or one of the primitive wrapper classes (excluding {@code Void}).
 *
 * @param clazz The class to test
 *
 * @return {@code true} if {@code clazz} is a primitive or a primitive wrapper
 */
public static boolean isPrimitive(Class<?> clazz) {
    return clazz.isPrimitive() || clazz.equals(Boolean.class) || clazz.equals(Byte.class) ||
            clazz.equals(Character.class) || clazz.equals(Double.class) || clazz.equals(Float.class) ||
            clazz.equals(Integer.class) || clazz.equals(Long.class) || clazz.equals(Short.class);
}
/**
 * Converts a String to an instance of the requested type.
 * <p>
 * Supported types are {@code String} and the primitive types together with their wrappers. For {@code char} /
 * {@code Character} only the first character of {@code value} is used, so an empty {@code value} results in a
 * {@code StringIndexOutOfBoundsException}. Numeric conversions throw {@code NumberFormatException} for malformed
 * input.
 *
 * @param type  The target type
 * @param value The value to convert
 *
 * @return The converted value
 *
 * @throws IllegalArgumentException If {@code type} is not supported
 */
public static Object coerceToType(Class<?> type, String value) {
    if (type.equals(String.class)) {
        return value;
    }
    if (type.equals(boolean.class) || type.equals(Boolean.class)) {
        return Boolean.valueOf(value);
    }
    if (type.equals(byte.class) || type.equals(Byte.class)) {
        return Byte.valueOf(value);
    }
    if (type.equals(char.class) || type.equals(Character.class)) {
        char first = value.charAt(0);
        return Character.valueOf(first);
    }
    if (type.equals(double.class) || type.equals(Double.class)) {
        return Double.valueOf(value);
    }
    if (type.equals(float.class) || type.equals(Float.class)) {
        return Float.valueOf(value);
    }
    if (type.equals(int.class) || type.equals(Integer.class)) {
        return Integer.valueOf(value);
    }
    if (type.equals(long.class) || type.equals(Long.class)) {
        return Long.valueOf(value);
    }
    if (type.equals(short.class) || type.equals(Short.class)) {
        return Short.valueOf(value);
    }
    throw new IllegalArgumentException(sm.getString("util.invalidType", value, type.getName()));
}
/**
 * Build the list of decoder entries from a set of decoder implementations.
 * <p>
 * Each decoder class is instantiated once (and, when an instance manager is used, immediately destroyed) purely
 * to confirm that it is valid, so that deployment fails early if it is not.
 *
 * @param decoderClazzes  Decoder implementation classes, may be {@code null}
 * @param instanceManager Instance manager to use to create Decoder instances, may be {@code null} in which case
 *                            the public no-arg constructor is used
 *
 * @return List of mappings from target type to associated decoder (empty if {@code decoderClazzes} is
 *             {@code null})
 *
 * @throws DeploymentException If a provided decoder class is not valid
 */
public static List<DecoderEntry> getDecoders(List<Class<? extends Decoder>> decoderClazzes,
        InstanceManager instanceManager) throws DeploymentException {
    List<DecoderEntry> entries = new ArrayList<>();
    if (decoderClazzes == null) {
        return entries;
    }

    for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
        try {
            if (instanceManager != null) {
                Decoder probe = (Decoder) instanceManager.newInstance(decoderClazz);
                // The instance is not needed beyond this check
                instanceManager.destroyInstance(probe);
            } else {
                decoderClazz.getConstructor().newInstance();
            }
        } catch (ReflectiveOperationException | IllegalArgumentException | SecurityException
                | NamingException e) {
            throw new DeploymentException(
                    sm.getString("pojoMethodMapping.invalidDecoder", decoderClazz.getName()), e);
        }
        entries.add(new DecoderEntry(getDecoderType(decoderClazz), decoderClazz));
    }
    return entries;
}
/*
 * Creates the handler(s) needed to deliver messages of type target to listener.
 *
 * Listeners for String, ByteBuffer and PongMessage are used as-is. Listeners for byte[], InputStream and Reader
 * are wrapped in POJO message handlers (with optional decoders). Listeners for any other type are wrapped with the
 * binary and/or text decoders from endpointConfig that match target. At most two results are produced.
 *
 * Throws IllegalArgumentException if no handler can be created.
 */
static Set<MessageHandlerResult> getMessageHandlers(Class<?> target, MessageHandler listener,
        EndpointConfig endpointConfig, Session session) {
    Set<MessageHandlerResult> handlers = new HashSet<>(2);

    if (String.class.isAssignableFrom(target)) {
        handlers.add(new MessageHandlerResult(listener, MessageHandlerResultType.TEXT));
    } else if (ByteBuffer.class.isAssignableFrom(target)) {
        handlers.add(new MessageHandlerResult(listener, MessageHandlerResultType.BINARY));
    } else if (PongMessage.class.isAssignableFrom(target)) {
        handlers.add(new MessageHandlerResult(listener, MessageHandlerResultType.PONG));
    } else if (byte[].class.isAssignableFrom(target)) {
        // Whole or partial byte[] handlers must be wrapped so the payload can be converted
        MessageHandler wrapped;
        if (MessageHandler.Whole.class.isAssignableFrom(listener.getClass())) {
            Method onMessage = getOnMessageMethod(listener);
            List<Class<? extends Decoder>> decoders =
                    matchDecoders(target, endpointConfig, true, ((WsSession) session).getInstanceManager());
            wrapped = new PojoMessageHandlerWholeBinary(listener, onMessage, session, endpointConfig, decoders,
                    new Object[1], 0, true, -1, false, -1);
        } else {
            Method onMessage = getOnMessagePartialMethod(listener);
            wrapped = new PojoMessageHandlerPartialBinary(listener, onMessage, session, new Object[2], 0, true, 1,
                    -1, -1);
        }
        handlers.add(new MessageHandlerResult(wrapped, MessageHandlerResultType.BINARY));
    } else if (InputStream.class.isAssignableFrom(target)) {
        Method onMessage = getOnMessageMethod(listener);
        List<Class<? extends Decoder>> decoders =
                matchDecoders(target, endpointConfig, true, ((WsSession) session).getInstanceManager());
        MessageHandler wrapped = new PojoMessageHandlerWholeBinary(listener, onMessage, session, endpointConfig,
                decoders, new Object[1], 0, true, -1, true, -1);
        handlers.add(new MessageHandlerResult(wrapped, MessageHandlerResultType.BINARY));
    } else if (Reader.class.isAssignableFrom(target)) {
        Method onMessage = getOnMessageMethod(listener);
        List<Class<? extends Decoder>> decoders =
                matchDecoders(target, endpointConfig, false, ((WsSession) session).getInstanceManager());
        MessageHandler wrapped = new PojoMessageHandlerWholeText(listener, onMessage, session, endpointConfig,
                decoders, new Object[1], 0, true, -1, -1);
        handlers.add(new MessageHandlerResult(wrapped, MessageHandlerResultType.TEXT));
    } else {
        // Any other type can only be handled via decoders
        DecoderMatch decoderMatch =
                matchDecoders(target, endpointConfig, ((WsSession) session).getInstanceManager());
        Method onMessage = getOnMessageMethod(listener);

        List<Class<? extends Decoder>> binaryDecoders = decoderMatch.getBinaryDecoders();
        if (!binaryDecoders.isEmpty()) {
            MessageHandler wrapped = new PojoMessageHandlerWholeBinary(listener, onMessage, session,
                    endpointConfig, binaryDecoders, new Object[1], 0, false, -1, false, -1);
            handlers.add(new MessageHandlerResult(wrapped, MessageHandlerResultType.BINARY));
        }

        List<Class<? extends Decoder>> textDecoders = decoderMatch.getTextDecoders();
        if (!textDecoders.isEmpty()) {
            MessageHandler wrapped = new PojoMessageHandlerWholeText(listener, onMessage, session,
                    endpointConfig, textDecoders, new Object[1], 0, false, -1, -1);
            handlers.add(new MessageHandlerResult(wrapped, MessageHandlerResultType.TEXT));
        }
    }

    if (handlers.isEmpty()) {
        throw new IllegalArgumentException(sm.getString("wsSession.unknownHandler", listener, target));
    }
    return handlers;
}
/*
 * Returns the matching binary (if binary is true) or text decoders for target, or null if there are none.
 */
private static List<Class<? extends Decoder>> matchDecoders(Class<?> target, EndpointConfig endpointConfig,
        boolean binary, InstanceManager instanceManager) {
    DecoderMatch decoderMatch = matchDecoders(target, endpointConfig, instanceManager);
    List<Class<? extends Decoder>> candidates =
            binary ? decoderMatch.getBinaryDecoders() : decoderMatch.getTextDecoders();
    return candidates.isEmpty() ? null : candidates;
}
/*
 * Matches the decoders configured on endpointConfig against target.
 * A DeploymentException raised while validating the decoders is rethrown as an IllegalArgumentException.
 */
private static DecoderMatch matchDecoders(Class<?> target, EndpointConfig endpointConfig,
        InstanceManager instanceManager) {
    try {
        return new DecoderMatch(target, getDecoders(endpointConfig.getDecoders(), instanceManager));
    } catch (DeploymentException e) {
        throw new IllegalArgumentException(e);
    }
}
/**
 * Parses a {@code Sec-WebSocket-Extensions} header value and appends each extension it describes to
 * {@code extensions}. Empty list elements are ignored.
 *
 * @param extensions The list to which parsed extensions are added
 * @param header     The header value to parse
 *
 * @throws IllegalArgumentException If a parameter name or (unquoted) value contains one of the delimiter
 *                                      characters {@code , ; " =}
 */
public static void parseExtensionHeader(List<Extension> extensions, String header) {
    // The relevant ABNF for the Sec-WebSocket-Extensions is as follows:
    // extension-list = 1#extension
    // extension = extension-token *( ";" extension-param )
    // extension-token = registered-token
    // registered-token = token
    // extension-param = token [ "=" (token | quoted-string) ]
    // ; When using the quoted-string syntax variant, the value
    // ; after quoted-string unescaping MUST conform to the
    // ; 'token' ABNF.
    //
    // The limiting of parameter values to tokens or "quoted tokens" makes
    // the parsing of the header significantly simpler and allows a number
    // of short-cuts to be taken.

    // Step one, split the header into individual extensions using ',' as a
    // separator
    String[] unparsedExtensions = header.split(",");
    for (String unparsedExtension : unparsedExtensions) {
        if (unparsedExtension.trim().isEmpty()) {
            // The #rule list syntax allows empty list elements (e.g. "a, ,b" or a trailing ", ") and recipients
            // must ignore them rather than treat them as an extension with an empty name
            continue;
        }

        // Step two, split the extension into the registered name and
        // parameter/value pairs using ';' as a separator
        String[] unparsedParameters = unparsedExtension.split(";");
        WsExtension extension = new WsExtension(unparsedParameters[0].trim());

        for (int i = 1; i < unparsedParameters.length; i++) {
            String unparsedParameter = unparsedParameters[i];
            int equalsPos = unparsedParameter.indexOf('=');
            String name;
            String value;
            if (equalsPos == -1) {
                name = unparsedParameter.trim();
                value = null;
            } else {
                name = unparsedParameter.substring(0, equalsPos).trim();
                value = unparsedParameter.substring(equalsPos + 1).trim();
                int len = value.length();
                if (len > 1 && value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') {
                    // Strip the quotes of a quoted-string value
                    value = value.substring(1, len - 1);
                }
            }

            // Make sure neither name nor (unquoted) value contains any of the
            // delimiters since that would indicate something went wrong.
            // containsDelims() checks for ',', ';', '"' and '=', which also covers
            // every invalid character in a value.
            if (containsDelims(name) || containsDelims(value)) {
                throw new IllegalArgumentException(sm.getString("util.notToken", name, value));
            }
            extension.addParameter(new WsExtensionParameter(name, value));
        }
        extensions.add(extension);
    }
}
/*
 * Returns true if input contains any of the extension header delimiters ',', ';', '"' or '='.
 * A null or empty input contains none.
 */
private static boolean containsDelims(String input) {
    if (input == null) {
        return false;
    }
    for (int pos = 0; pos < input.length(); pos++) {
        char c = input.charAt(pos);
        if (c == ',' || c == ';' || c == '\"' || c == '=') {
            return true;
        }
    }
    return false;
}
/*
 * Looks up the public onMessage(Object) method of the listener's class.
 * Throws IllegalArgumentException if it cannot be found or accessed.
 */
private static Method getOnMessageMethod(MessageHandler listener) {
    Class<?> listenerClass = listener.getClass();
    try {
        return listenerClass.getMethod("onMessage", Object.class);
    } catch (NoSuchMethodException | SecurityException e) {
        throw new IllegalArgumentException(sm.getString("util.invalidMessageHandler"), e);
    }
}
/*
 * Looks up the public onMessage(Object, boolean) method used by partial message handlers.
 * Throws IllegalArgumentException if it cannot be found or accessed.
 */
private static Method getOnMessagePartialMethod(MessageHandler listener) {
    Class<?> listenerClass = listener.getClass();
    try {
        return listenerClass.getMethod("onMessage", Object.class, boolean.class);
    } catch (NoSuchMethodException | SecurityException e) {
        throw new IllegalArgumentException(sm.getString("util.invalidMessageHandler"), e);
    }
}
/**
 * The decoders, from a list of candidate decoder entries, that can produce a given target type, split into text
 * and binary decoders.
 * <p>
 * Candidates are examined in order. Non-stream decoders ({@code Binary} / {@code Text}) may decline a message via
 * {@code willDecode()}, so matching continues after them. A stream decoder ({@code BinaryStream} /
 * {@code TextStream}) must process the message, so matching stops at the first one found.
 */
public static class DecoderMatch {

    private final List<Class<? extends Decoder>> textDecoders = new ArrayList<>();
    private final List<Class<? extends Decoder>> binaryDecoders = new ArrayList<>();
    private final Class<?> target;

    public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) {
        this.target = target;
        for (DecoderEntry decoderEntry : decoderEntries) {
            if (!decoderEntry.getClazz().isAssignableFrom(target)) {
                continue;
            }
            Class<? extends Decoder> decoderClazz = decoderEntry.getDecoderClazz();
            if (Binary.class.isAssignableFrom(decoderClazz)) {
                binaryDecoders.add(decoderClazz);
            } else if (BinaryStream.class.isAssignableFrom(decoderClazz)) {
                binaryDecoders.add(decoderClazz);
                break;
            } else if (Text.class.isAssignableFrom(decoderClazz)) {
                textDecoders.add(decoderClazz);
            } else if (TextStream.class.isAssignableFrom(decoderClazz)) {
                textDecoders.add(decoderClazz);
                break;
            } else {
                throw new IllegalArgumentException(sm.getString("util.unknownDecoderType"));
            }
        }
    }

    public List<Class<? extends Decoder>> getTextDecoders() {
        return textDecoders;
    }

    public List<Class<? extends Decoder>> getBinaryDecoders() {
        return binaryDecoders;
    }

    public Class<?> getTarget() {
        return target;
    }

    public boolean hasMatches() {
        return !textDecoders.isEmpty() || !binaryDecoders.isEmpty();
    }
}
/*
 * Outcome of resolving a generic type argument. It is either a concrete class (with index -1), or the index of a
 * type variable still to be resolved (with a null class). In both cases the dimension records how many array levels
 * wrap the type.
 */
private static class TypeResult {

    private final Class<?> clazz;
    private final int index;
    private int dimension;

    TypeResult(Class<?> clazz, int index, int dimension) {
        this.clazz = clazz;
        this.index = index;
        this.dimension = dimension;
    }

    public Class<?> getClazz() {
        return clazz;
    }

    public int getIndex() {
        return index;
    }

    public int getDimension() {
        return dimension;
    }

    public void incrementDimension(int inc) {
        this.dimension = this.dimension + inc;
    }
}
}