Util.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tomcat.websocket;

import java.io.InputStream;
import java.io.Reader;
import java.lang.reflect.Array;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;

import javax.naming.NamingException;

import jakarta.websocket.CloseReason.CloseCode;
import jakarta.websocket.CloseReason.CloseCodes;
import jakarta.websocket.Decoder;
import jakarta.websocket.Decoder.Binary;
import jakarta.websocket.Decoder.BinaryStream;
import jakarta.websocket.Decoder.Text;
import jakarta.websocket.Decoder.TextStream;
import jakarta.websocket.DeploymentException;
import jakarta.websocket.Encoder;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.Extension;
import jakarta.websocket.MessageHandler;
import jakarta.websocket.PongMessage;
import jakarta.websocket.Session;

import org.apache.tomcat.InstanceManager;
import org.apache.tomcat.util.res.StringManager;
import org.apache.tomcat.websocket.pojo.PojoMessageHandlerPartialBinary;
import org.apache.tomcat.websocket.pojo.PojoMessageHandlerWholeBinary;
import org.apache.tomcat.websocket.pojo.PojoMessageHandlerWholeText;

/**
 * Utility class for internal use only within the {@link org.apache.tomcat.websocket} package.
 */
public class Util {

    private static final StringManager sm = StringManager.getManager(Util.class);
    private static final Queue<SecureRandom> randoms = new ConcurrentLinkedQueue<>();

    private Util() {
        // Hide default constructor
    }


    static boolean isControl(byte opCode) {
        return (opCode & 0x08) != 0;
    }


    static boolean isText(byte opCode) {
        return opCode == Constants.OPCODE_TEXT;
    }


    static boolean isContinuation(byte opCode) {
        return opCode == Constants.OPCODE_CONTINUATION;
    }


    static CloseCode getCloseCode(int code) {
        if (code > 2999 && code < 5000) {
            return CloseCodes.getCloseCode(code);
        }
        switch (code) {
            case 1000:
                return CloseCodes.NORMAL_CLOSURE;
            case 1001:
                return CloseCodes.GOING_AWAY;
            case 1002:
                return CloseCodes.PROTOCOL_ERROR;
            case 1003:
                return CloseCodes.CANNOT_ACCEPT;
            case 1004:
                // Should not be used in a close frame
                // return CloseCodes.RESERVED;
                return CloseCodes.PROTOCOL_ERROR;
            case 1005:
                // Should not be used in a close frame
                // return CloseCodes.NO_STATUS_CODE;
                return CloseCodes.PROTOCOL_ERROR;
            case 1006:
                // Should not be used in a close frame
                // return CloseCodes.CLOSED_ABNORMALLY;
                return CloseCodes.PROTOCOL_ERROR;
            case 1007:
                return CloseCodes.NOT_CONSISTENT;
            case 1008:
                return CloseCodes.VIOLATED_POLICY;
            case 1009:
                return CloseCodes.TOO_BIG;
            case 1010:
                return CloseCodes.NO_EXTENSION;
            case 1011:
                return CloseCodes.UNEXPECTED_CONDITION;
            case 1012:
                // Not in RFC6455
                // return CloseCodes.SERVICE_RESTART;
                return CloseCodes.PROTOCOL_ERROR;
            case 1013:
                // Not in RFC6455
                // return CloseCodes.TRY_AGAIN_LATER;
                return CloseCodes.PROTOCOL_ERROR;
            case 1015:
                // Should not be used in a close frame
                // return CloseCodes.TLS_HANDSHAKE_FAILURE;
                return CloseCodes.PROTOCOL_ERROR;
            default:
                return CloseCodes.PROTOCOL_ERROR;
        }
    }


    static byte[] generateMask() {
        // SecureRandom is not thread-safe so need to make sure only one thread
        // uses it at a time. In theory, the pool could grow to the same size
        // as the number of request processing threads. In reality it will be
        // a lot smaller.

        // Get a SecureRandom from the pool
        SecureRandom sr = randoms.poll();

        // If one isn't available, generate a new one
        if (sr == null) {
            try {
                sr = SecureRandom.getInstance("SHA1PRNG");
            } catch (NoSuchAlgorithmException e) {
                // Fall back to platform default
                sr = new SecureRandom();
            }
        }

        // Generate the mask
        byte[] result = new byte[4];
        sr.nextBytes(result);

        // Put the SecureRandom back in the poll
        randoms.add(sr);

        return result;
    }


    static Class<?> getMessageType(MessageHandler listener) {
        return getGenericType(MessageHandler.class, listener.getClass()).getClazz();
    }


    private static Class<?> getDecoderType(Class<? extends Decoder> decoder) {
        return getGenericType(Decoder.class, decoder).getClazz();
    }


    static Class<?> getEncoderType(Class<? extends Encoder> encoder) {
        return getGenericType(Encoder.class, encoder).getClazz();
    }


    private static <T> TypeResult getGenericType(Class<T> type, Class<? extends T> clazz) {

        // Look to see if this class implements the interface of interest

        // Get all the interfaces
        Type[] interfaces = clazz.getGenericInterfaces();
        for (Type iface : interfaces) {
            // Only need to check interfaces that use generics
            if (iface instanceof ParameterizedType) {
                ParameterizedType pi = (ParameterizedType) iface;
                // Look for the interface of interest
                if (pi.getRawType() instanceof Class) {
                    if (type.isAssignableFrom((Class<?>) pi.getRawType())) {
                        return getTypeParameter(clazz, pi.getActualTypeArguments()[0]);
                    }
                }
            }
        }

        // Interface not found on this class. Look at the superclass.
        @SuppressWarnings("unchecked")
        Class<? extends T> superClazz = (Class<? extends T>) clazz.getSuperclass();
        if (superClazz == null) {
            // Finished looking up the class hierarchy without finding anything
            return null;
        }

        TypeResult superClassTypeResult = getGenericType(type, superClazz);
        int dimension = superClassTypeResult.getDimension();
        if (superClassTypeResult.getIndex() == -1 && dimension == 0) {
            // Superclass implements interface and defines explicit type for
            // the interface of interest
            return superClassTypeResult;
        }

        if (superClassTypeResult.getIndex() > -1) {
            // Superclass implements interface and defines unknown type for
            // the interface of interest
            // Map that unknown type to the generic types defined in this class
            ParameterizedType superClassType = (ParameterizedType) clazz.getGenericSuperclass();
            TypeResult result = getTypeParameter(clazz,
                    superClassType.getActualTypeArguments()[superClassTypeResult.getIndex()]);
            result.incrementDimension(superClassTypeResult.getDimension());
            if (result.getClazz() != null && result.getDimension() > 0) {
                superClassTypeResult = result;
            } else {
                return result;
            }
        }

        if (superClassTypeResult.getDimension() > 0) {
            StringBuilder className = new StringBuilder();
            for (int i = 0; i < dimension; i++) {
                className.append('[');
            }
            className.append('L');
            className.append(superClassTypeResult.getClazz().getCanonicalName());
            className.append(';');

            Class<?> arrayClazz;
            try {
                arrayClazz = Class.forName(className.toString());
            } catch (ClassNotFoundException e) {
                throw new IllegalArgumentException(e);
            }

            return new TypeResult(arrayClazz, -1, 0);
        }

        // Error will be logged further up the call stack
        return null;
    }


    /*
     * For a generic parameter, return either the Class used or if the type is unknown, the index for the type in
     * definition of the class
     */
    private static TypeResult getTypeParameter(Class<?> clazz, Type argType) {
        if (argType instanceof Class<?>) {
            return new TypeResult((Class<?>) argType, -1, 0);
        } else if (argType instanceof ParameterizedType) {
            return new TypeResult((Class<?>) ((ParameterizedType) argType).getRawType(), -1, 0);
        } else if (argType instanceof GenericArrayType) {
            Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType();
            TypeResult result = getTypeParameter(clazz, arrayElementType);
            result.incrementDimension(1);
            return result;
        } else {
            TypeVariable<?>[] tvs = clazz.getTypeParameters();
            for (int i = 0; i < tvs.length; i++) {
                if (tvs[i].equals(argType)) {
                    return new TypeResult(null, i, 0);
                }
            }
            return null;
        }
    }


    public static boolean isPrimitive(Class<?> clazz) {
        if (clazz.isPrimitive()) {
            return true;
        } else if (clazz.equals(Boolean.class) || clazz.equals(Byte.class) || clazz.equals(Character.class) ||
                clazz.equals(Double.class) || clazz.equals(Float.class) || clazz.equals(Integer.class) ||
                clazz.equals(Long.class) || clazz.equals(Short.class)) {
            return true;
        }
        return false;
    }


    public static Object coerceToType(Class<?> type, String value) {
        if (type.equals(String.class)) {
            return value;
        } else if (type.equals(boolean.class) || type.equals(Boolean.class)) {
            return Boolean.valueOf(value);
        } else if (type.equals(byte.class) || type.equals(Byte.class)) {
            return Byte.valueOf(value);
        } else if (type.equals(char.class) || type.equals(Character.class)) {
            return Character.valueOf(value.charAt(0));
        } else if (type.equals(double.class) || type.equals(Double.class)) {
            return Double.valueOf(value);
        } else if (type.equals(float.class) || type.equals(Float.class)) {
            return Float.valueOf(value);
        } else if (type.equals(int.class) || type.equals(Integer.class)) {
            return Integer.valueOf(value);
        } else if (type.equals(long.class) || type.equals(Long.class)) {
            return Long.valueOf(value);
        } else if (type.equals(short.class) || type.equals(Short.class)) {
            return Short.valueOf(value);
        } else {
            throw new IllegalArgumentException(sm.getString("util.invalidType", value, type.getName()));
        }
    }


    /**
     * Build the list of decoder entries from a set of decoder implementations.
     *
     * @param decoderClazzes  Decoder implementation classes
     * @param instanceManager Instance manager to use to create Decoder instances
     *
     * @return List of mappings from target type to associated decoder
     *
     * @throws DeploymentException If a provided decoder class is not valid
     */
    public static List<DecoderEntry> getDecoders(List<Class<? extends Decoder>> decoderClazzes,
            InstanceManager instanceManager) throws DeploymentException {

        List<DecoderEntry> result = new ArrayList<>();
        if (decoderClazzes != null) {
            for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
                // Need to instantiate decoder to ensure it is valid and that
                // deployment can be failed if it is not
                Decoder instance;
                try {
                    if (instanceManager == null) {
                        instance = decoderClazz.getConstructor().newInstance();
                    } else {
                        instance = (Decoder) instanceManager.newInstance(decoderClazz);
                        // Don't need this instance, so destroy it
                        instanceManager.destroyInstance(instance);
                    }
                } catch (ReflectiveOperationException | IllegalArgumentException | SecurityException
                        | NamingException e) {
                    throw new DeploymentException(
                            sm.getString("pojoMethodMapping.invalidDecoder", decoderClazz.getName()), e);
                }
                DecoderEntry entry = new DecoderEntry(getDecoderType(decoderClazz), decoderClazz);
                result.add(entry);
            }
        }

        return result;
    }


    static Set<MessageHandlerResult> getMessageHandlers(Class<?> target, MessageHandler listener,
            EndpointConfig endpointConfig, Session session) {

        // Will never be more than 2 types
        Set<MessageHandlerResult> results = new HashSet<>(2);

        // Simple cases - handlers already accepts one of the types expected by
        // the frame handling code
        if (String.class.isAssignableFrom(target)) {
            MessageHandlerResult result = new MessageHandlerResult(listener, MessageHandlerResultType.TEXT);
            results.add(result);
        } else if (ByteBuffer.class.isAssignableFrom(target)) {
            MessageHandlerResult result = new MessageHandlerResult(listener, MessageHandlerResultType.BINARY);
            results.add(result);
        } else if (PongMessage.class.isAssignableFrom(target)) {
            MessageHandlerResult result = new MessageHandlerResult(listener, MessageHandlerResultType.PONG);
            results.add(result);
            // Handler needs wrapping and optional decoder to convert it to one of
            // the types expected by the frame handling code
        } else if (byte[].class.isAssignableFrom(target)) {
            boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass());
            MessageHandlerResult result = new MessageHandlerResult(whole
                    ? new PojoMessageHandlerWholeBinary(listener, getOnMessageMethod(listener), session, endpointConfig,
                            matchDecoders(target, endpointConfig, true, ((WsSession) session).getInstanceManager()),
                            new Object[1], 0, true, -1, false, -1)
                    : new PojoMessageHandlerPartialBinary(listener, getOnMessagePartialMethod(listener), session,
                            new Object[2], 0, true, 1, -1, -1),
                    MessageHandlerResultType.BINARY);
            results.add(result);
        } else if (InputStream.class.isAssignableFrom(target)) {
            MessageHandlerResult result = new MessageHandlerResult(
                    new PojoMessageHandlerWholeBinary(listener, getOnMessageMethod(listener), session, endpointConfig,
                            matchDecoders(target, endpointConfig, true, ((WsSession) session).getInstanceManager()),
                            new Object[1], 0, true, -1, true, -1),
                    MessageHandlerResultType.BINARY);
            results.add(result);
        } else if (Reader.class.isAssignableFrom(target)) {
            MessageHandlerResult result = new MessageHandlerResult(
                    new PojoMessageHandlerWholeText(listener, getOnMessageMethod(listener), session, endpointConfig,
                            matchDecoders(target, endpointConfig, false, ((WsSession) session).getInstanceManager()),
                            new Object[1], 0, true, -1, -1),
                    MessageHandlerResultType.TEXT);
            results.add(result);
        } else {
            // Handler needs wrapping and requires decoder to convert it to one
            // of the types expected by the frame handling code
            DecoderMatch decoderMatch = matchDecoders(target, endpointConfig,
                    ((WsSession) session).getInstanceManager());
            Method m = getOnMessageMethod(listener);
            if (decoderMatch.getBinaryDecoders().size() > 0) {
                MessageHandlerResult result = new MessageHandlerResult(
                        new PojoMessageHandlerWholeBinary(listener, m, session, endpointConfig,
                                decoderMatch.getBinaryDecoders(), new Object[1], 0, false, -1, false, -1),
                        MessageHandlerResultType.BINARY);
                results.add(result);
            }
            if (decoderMatch.getTextDecoders().size() > 0) {
                MessageHandlerResult result = new MessageHandlerResult(
                        new PojoMessageHandlerWholeText(listener, m, session, endpointConfig,
                                decoderMatch.getTextDecoders(), new Object[1], 0, false, -1, -1),
                        MessageHandlerResultType.TEXT);
                results.add(result);
            }
        }

        if (results.size() == 0) {
            throw new IllegalArgumentException(sm.getString("wsSession.unknownHandler", listener, target));
        }

        return results;
    }

    private static List<Class<? extends Decoder>> matchDecoders(Class<?> target, EndpointConfig endpointConfig,
            boolean binary, InstanceManager instanceManager) {
        DecoderMatch decoderMatch = matchDecoders(target, endpointConfig, instanceManager);
        if (binary) {
            if (decoderMatch.getBinaryDecoders().size() > 0) {
                return decoderMatch.getBinaryDecoders();
            }
        } else if (decoderMatch.getTextDecoders().size() > 0) {
            return decoderMatch.getTextDecoders();
        }
        return null;
    }

    private static DecoderMatch matchDecoders(Class<?> target, EndpointConfig endpointConfig,
            InstanceManager instanceManager) {
        DecoderMatch decoderMatch;
        try {
            List<Class<? extends Decoder>> decoders = endpointConfig.getDecoders();
            List<DecoderEntry> decoderEntries = getDecoders(decoders, instanceManager);
            decoderMatch = new DecoderMatch(target, decoderEntries);
        } catch (DeploymentException e) {
            throw new IllegalArgumentException(e);
        }
        return decoderMatch;
    }

    public static void parseExtensionHeader(List<Extension> extensions, String header) {
        // The relevant ABNF for the Sec-WebSocket-Extensions is as follows:
        // extension-list = 1#extension
        // extension = extension-token *( ";" extension-param )
        // extension-token = registered-token
        // registered-token = token
        // extension-param = token [ "=" (token | quoted-string) ]
        // ; When using the quoted-string syntax variant, the value
        // ; after quoted-string unescaping MUST conform to the
        // ; 'token' ABNF.
        //
        // The limiting of parameter values to tokens or "quoted tokens" makes
        // the parsing of the header significantly simpler and allows a number
        // of short-cuts to be taken.

        // Step one, split the header into individual extensions using ',' as a
        // separator
        String unparsedExtensions[] = header.split(",");
        for (String unparsedExtension : unparsedExtensions) {
            // Step two, split the extension into the registered name and
            // parameter/value pairs using ';' as a separator
            String unparsedParameters[] = unparsedExtension.split(";");
            WsExtension extension = new WsExtension(unparsedParameters[0].trim());

            for (int i = 1; i < unparsedParameters.length; i++) {
                int equalsPos = unparsedParameters[i].indexOf('=');
                String name;
                String value;
                if (equalsPos == -1) {
                    name = unparsedParameters[i].trim();
                    value = null;
                } else {
                    name = unparsedParameters[i].substring(0, equalsPos).trim();
                    value = unparsedParameters[i].substring(equalsPos + 1).trim();
                    int len = value.length();
                    if (len > 1) {
                        if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') {
                            value = value.substring(1, value.length() - 1);
                        }
                    }
                }
                // Make sure value doesn't contain any of the delimiters since
                // that would indicate something went wrong
                if (containsDelims(name) || containsDelims(value)) {
                    throw new IllegalArgumentException(sm.getString("util.notToken", name, value));
                }
                if (value != null && (value.indexOf(',') > -1 || value.indexOf(';') > -1 || value.indexOf('\"') > -1 ||
                        value.indexOf('=') > -1)) {
                    throw new IllegalArgumentException(sm.getString("util.invalidValue", value));
                }
                extension.addParameter(new WsExtensionParameter(name, value));
            }
            extensions.add(extension);
        }
    }


    private static boolean containsDelims(String input) {
        if (input == null || input.length() == 0) {
            return false;
        }
        for (char c : input.toCharArray()) {
            switch (c) {
                case ',':
                case ';':
                case '\"':
                case '=':
                    return true;
                default:
                    // NO_OP
            }

        }
        return false;
    }

    private static Method getOnMessageMethod(MessageHandler listener) {
        try {
            return listener.getClass().getMethod("onMessage", Object.class);
        } catch (NoSuchMethodException | SecurityException e) {
            throw new IllegalArgumentException(sm.getString("util.invalidMessageHandler"), e);
        }
    }

    private static Method getOnMessagePartialMethod(MessageHandler listener) {
        try {
            return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE);
        } catch (NoSuchMethodException | SecurityException e) {
            throw new IllegalArgumentException(sm.getString("util.invalidMessageHandler"), e);
        }
    }


    public static class DecoderMatch {

        private final List<Class<? extends Decoder>> textDecoders = new ArrayList<>();
        private final List<Class<? extends Decoder>> binaryDecoders = new ArrayList<>();
        private final Class<?> target;

        public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) {
            this.target = target;
            for (DecoderEntry decoderEntry : decoderEntries) {
                if (decoderEntry.getClazz().isAssignableFrom(target)) {
                    if (Binary.class.isAssignableFrom(decoderEntry.getDecoderClazz())) {
                        binaryDecoders.add(decoderEntry.getDecoderClazz());
                        // willDecode() method means this decoder may or may not
                        // decode a message so need to carry on checking for
                        // other matches
                    } else if (BinaryStream.class.isAssignableFrom(decoderEntry.getDecoderClazz())) {
                        binaryDecoders.add(decoderEntry.getDecoderClazz());
                        // Stream decoders have to process the message so no
                        // more decoders can be matched
                        break;
                    } else if (Text.class.isAssignableFrom(decoderEntry.getDecoderClazz())) {
                        textDecoders.add(decoderEntry.getDecoderClazz());
                        // willDecode() method means this decoder may or may not
                        // decode a message so need to carry on checking for
                        // other matches
                    } else if (TextStream.class.isAssignableFrom(decoderEntry.getDecoderClazz())) {
                        textDecoders.add(decoderEntry.getDecoderClazz());
                        // Stream decoders have to process the message so no
                        // more decoders can be matched
                        break;
                    } else {
                        throw new IllegalArgumentException(sm.getString("util.unknownDecoderType"));
                    }
                }
            }
        }


        public List<Class<? extends Decoder>> getTextDecoders() {
            return textDecoders;
        }


        public List<Class<? extends Decoder>> getBinaryDecoders() {
            return binaryDecoders;
        }


        public Class<?> getTarget() {
            return target;
        }


        public boolean hasMatches() {
            return (textDecoders.size() > 0) || (binaryDecoders.size() > 0);
        }
    }


    private static class TypeResult {
        private final Class<?> clazz;
        private final int index;
        private int dimension;

        TypeResult(Class<?> clazz, int index, int dimension) {
            this.clazz = clazz;
            this.index = index;
            this.dimension = dimension;
        }

        public Class<?> getClazz() {
            return clazz;
        }

        public int getIndex() {
            return index;
        }

        public int getDimension() {
            return dimension;
        }

        public void incrementDimension(int inc) {
            dimension += inc;
        }
    }
}