PerMessageDeflate.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tomcat.websocket;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

import jakarta.websocket.Extension;
import jakarta.websocket.Extension.Parameter;
import jakarta.websocket.SendHandler;

import org.apache.tomcat.util.res.StringManager;

public class PerMessageDeflate implements Transformation {

    // Source of the localized messages used in the exceptions thrown by this
    // class
    private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class);

    // Names of the extension parameters recognised during negotiation
    private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
    private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
    private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
    private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";

    // Bit within a frame's RSV value that this extension uses to mark a message
    // as compressed
    private static final int RSV_BITMASK = 0b100;
    // The four bytes (0x00 0x00 0xff 0xff) that are dropped from the end of
    // each outgoing compressed message and supplied to the inflater at the end
    // of each incoming compressed message
    private static final byte[] EOM_BYTES = new byte[] { 0, 0, -1, -1 };

    // Name used for the extension in the response built by
    // getExtensionResponse()
    public static final String NAME = "permessage-deflate";

    // Negotiated settings. A window bits value of -1 means the parameter was
    // not present in the negotiated parameter set.
    private final boolean serverContextTakeover;
    private final int serverMaxWindowBits;
    private final boolean clientContextTakeover;
    private final int clientMaxWindowBits;
    // true if this instance is used by the server end of the connection
    private final boolean isServer;
    // Raw (nowrap) inflater used for incoming messages
    private final Inflater inflater = new Inflater(true);
    // Holds compressed bytes obtained from the next transformation before they
    // are handed to the inflater
    private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
    // Raw (nowrap) deflater used for outgoing messages
    private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
    // Scratch buffer, one byte larger than EOM_BYTES, used when sending to
    // find out whether the deflater has anything more than the trailing EOM
    // bytes left to write once the write buffer has filled up
    private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1];

    // Next transformation in the chain
    private volatile Transformation next;
    // Set on the first frame of each incoming message: true if that frame did
    // not have the RSV bit set and the message is therefore passed through
    private volatile boolean skipDecompression = false;
    // Accumulates compressed output. Replaced with a new buffer each time its
    // content is handed off as a message part.
    private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
    // true once the RSV bit has been set on a frame of the current outgoing
    // message so that it is not set again on later frames
    private volatile boolean firstCompressedFrameWritten = false;
    // true while every part seen so far of the current outgoing message has
    // had no payload
    private volatile boolean emptyMessage = true;

    /**
     * Creates a transformation from the first set of extension parameters that
     * this endpoint is able to support.
     *
     * @param preferences the alternative parameter sets, in order of preference
     * @param isServer    {@code true} if the transformation is being created
     *                        for the server end of the connection
     *
     * @return the configured transformation, or {@code null} if none of the
     *             parameter sets can be supported
     *
     * @throws IllegalArgumentException if a parameter set contains an unknown
     *                                      parameter, the same parameter more
     *                                      than once, or a window size that is
     *                                      not a number in the range 8 to 15
     */
    static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) {
        // The first parameter set that can be supported wins
        nextPreference: for (List<Parameter> candidate : preferences) {
            boolean serverTakeover = true;
            boolean clientTakeover = true;
            // -1 indicates that the window size has not been specified
            int serverBits = -1;
            int clientBits = -1;

            for (Parameter parameter : candidate) {
                String name = parameter.getName();

                if (SERVER_NO_CONTEXT_TAKEOVER.equals(name)) {
                    if (!serverTakeover) {
                        // Duplicate definition
                        throw new IllegalArgumentException(
                                sm.getString("perMessageDeflate.duplicateParameter", SERVER_NO_CONTEXT_TAKEOVER));
                    }
                    serverTakeover = false;
                } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(name)) {
                    if (!clientTakeover) {
                        // Duplicate definition
                        throw new IllegalArgumentException(
                                sm.getString("perMessageDeflate.duplicateParameter", CLIENT_NO_CONTEXT_TAKEOVER));
                    }
                    clientTakeover = false;
                } else if (SERVER_MAX_WINDOW_BITS.equals(name)) {
                    if (serverBits != -1) {
                        // Duplicate definition
                        throw new IllegalArgumentException(
                                sm.getString("perMessageDeflate.duplicateParameter", SERVER_MAX_WINDOW_BITS));
                    }
                    serverBits = Integer.parseInt(parameter.getValue());
                    if (serverBits < 8 || serverBits > 15) {
                        throw new IllegalArgumentException(sm.getString("perMessageDeflate.invalidWindowSize",
                                SERVER_MAX_WINDOW_BITS, Integer.valueOf(serverBits)));
                    }
                    // Java SE API (as of Java 11) does not expose the API to
                    // control the Window size. It is effectively hard-coded
                    // to 15. The server window size is not an issue for the
                    // client since the client will assume 15 and if the server
                    // uses a smaller window everything will still work.
                    if (isServer && serverBits != 15) {
                        continue nextPreference;
                    }
                } else if (CLIENT_MAX_WINDOW_BITS.equals(name)) {
                    if (clientBits != -1) {
                        // Duplicate definition
                        throw new IllegalArgumentException(
                                sm.getString("perMessageDeflate.duplicateParameter", CLIENT_MAX_WINDOW_BITS));
                    }
                    String value = parameter.getValue();
                    if (value == null) {
                        // Hint to server that the client supports this
                        // option. Java SE API (as of Java 11) does not expose
                        // the API to control the Window size. It is
                        // effectively hard-coded to 15
                        clientBits = 15;
                    } else {
                        clientBits = Integer.parseInt(value);
                        if (clientBits < 8 || clientBits > 15) {
                            throw new IllegalArgumentException(sm.getString("perMessageDeflate.invalidWindowSize",
                                    CLIENT_MAX_WINDOW_BITS, Integer.valueOf(clientBits)));
                        }
                    }
                    // Java SE API (as of Java 11) does not expose the API to
                    // control the Window size. It is effectively hard-coded
                    // to 15. The client window size is not an issue for the
                    // server since the server will assume 15 and if the client
                    // uses a smaller window everything will still work.
                    if (!isServer && clientBits != 15) {
                        continue nextPreference;
                    }
                } else {
                    // Unknown parameter
                    throw new IllegalArgumentException(sm.getString("perMessageDeflate.unknownParameter", name));
                }
            }

            // Every parameter in this set is acceptable
            return new PerMessageDeflate(serverTakeover, serverBits, clientTakeover, clientBits, isServer);
        }

        // Failed to negotiate agreeable terms
        return null;
    }


    /**
     * Stores the negotiated settings. Instances are only created by
     * {@link #negotiate(List, boolean)}.
     *
     * @param serverContextTakeover {@code false} if server_no_context_takeover
     *                                  was negotiated
     * @param serverMaxWindowBits   the negotiated server_max_window_bits value,
     *                                  or -1 if it was not specified
     * @param clientContextTakeover {@code false} if client_no_context_takeover
     *                                  was negotiated
     * @param clientMaxWindowBits   the negotiated client_max_window_bits value,
     *                                  or -1 if it was not specified
     * @param isServer              {@code true} if this instance is for the
     *                                  server end of the connection
     */
    private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, boolean clientContextTakeover,
            int clientMaxWindowBits, boolean isServer) {
        this.isServer = isServer;

        // Context takeover settings
        this.serverContextTakeover = serverContextTakeover;
        this.clientContextTakeover = clientContextTakeover;

        // Window sizes
        this.serverMaxWindowBits = serverMaxWindowBits;
        this.clientMaxWindowBits = clientMaxWindowBits;
    }


    /**
     * Obtains data from the next transformation, inflates it if the current
     * message is compressed and writes the result to the destination buffer.
     * Control frames and messages whose first frame did not have this
     * extension's RSV bit set are delegated to the next transformation
     * unchanged.
     *
     * @param opCode the opcode of the frame currently being processed
     * @param fin    {@code true} if the frame is the final frame of the message
     * @param rsv    the reserved bits of the frame currently being processed
     * @param dest   the buffer to write the data to; must be backed by an
     *                   accessible array when the message is compressed
     *
     * @return the result reported by the next transformation for frames that
     *             are passed through; otherwise {@code UNDERFLOW} if the next
     *             transformation reported it while there was still space in
     *             the destination, {@code END_OF_FRAME} once the frame has
     *             been fully inflated, or {@code OVERFLOW} if the destination
     *             filled up first
     *
     * @throws IOException if the compressed data is invalid or the inflater
     *                         has already been closed
     */
    @Override
    public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) throws IOException {
        // Control frames are never compressed and may appear in the middle of
        // a WebSocket message. Pass them straight through.
        if (Util.isControl(opCode)) {
            return next.getMoreData(opCode, fin, rsv, dest);
        }

        if (!Util.isContinuation(opCode)) {
            // First frame in new message. The decision made here also applies
            // to any continuation frames that follow.
            skipDecompression = (rsv & RSV_BITMASK) == 0;
        }

        // Pass uncompressed frames straight through.
        if (skipDecompression) {
            return next.getMoreData(opCode, fin, rsv, dest);
        }

        int written;
        // Set once the EOM bytes have been given to the inflater. No further
        // input is requested from the next transformation after that.
        boolean usedEomBytes = false;

        // The loop is also entered with a full destination when the EOM bytes
        // have just been supplied so that the inflater gets to process them.
        while (dest.remaining() > 0 || usedEomBytes) {
            // Space available in destination. Try and fill it.
            try {
                written = inflater.inflate(dest.array(), dest.arrayOffset() + dest.position(), dest.remaining());
            } catch (DataFormatException e) {
                throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e);
            } catch (NullPointerException e) {
                // Presumably thrown by the inflater once end() has been called
                // on it (see close())
                throw new IOException(sm.getString("perMessageDeflate.alreadyClosed"), e);
            }
            dest.position(dest.position() + written);

            if (inflater.needsInput() && !usedEomBytes) {
                // Refill the read buffer from the next transformation and hand
                // whatever was obtained to the inflater. This extension's RSV
                // bit is toggled in the value passed on.
                readBuffer.clear();
                TransformationResult nextResult = next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer);
                inflater.setInput(readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position());
                if (dest.hasRemaining()) {
                    if (TransformationResult.UNDERFLOW.equals(nextResult)) {
                        // More compressed data is needed than is currently
                        // available
                        return nextResult;
                    } else if (TransformationResult.END_OF_FRAME.equals(nextResult) && readBuffer.position() == 0) {
                        // The frame has no more compressed data
                        if (fin) {
                            // End of the message. Supply the EOM bytes that
                            // the sender dropped.
                            inflater.setInput(EOM_BYTES);
                            usedEomBytes = true;
                        } else {
                            return TransformationResult.END_OF_FRAME;
                        }
                    }
                } else if (readBuffer.position() > 0) {
                    // Destination is full and there is compressed data still
                    // to be inflated
                    return TransformationResult.OVERFLOW;
                } else if (fin) {
                    // Destination is full and no further compressed data was
                    // obtained for the final frame. Supply the EOM bytes.
                    inflater.setInput(EOM_BYTES);
                    usedEomBytes = true;
                }
            } else if (written == 0) {
                // Nothing more was produced and no more input is going to be
                // requested. If the message is complete and the sending end
                // does not keep its compression context between messages,
                // reset the inflater ready for the next message.
                if (fin && (isServer && !clientContextTakeover || !isServer && !serverContextTakeover)) {
                    try {
                        inflater.reset();
                    } catch (NullPointerException e) {
                        throw new IOException(sm.getString("perMessageDeflate.alreadyClosed"), e);
                    }
                }
                return TransformationResult.END_OF_FRAME;
            }
        }

        // Destination is full
        return TransformationResult.OVERFLOW;
    }


    /**
     * Checks the reserved bits of a received frame. A control frame must not
     * have this extension's RSV bit set. For every other case the check is
     * delegated to the next transformation, if there is one, with this
     * extension's bit cleared.
     *
     * @param rsv    the reserved bits of the received frame
     * @param opCode the opcode of the received frame
     *
     * @return {@code false} if a control frame has this extension's RSV bit
     *             set, {@code true} if there is no next transformation,
     *             otherwise the result from the next transformation
     */
    @Override
    public boolean validateRsv(int rsv, byte opCode) {
        boolean compressionBitSet = (rsv & RSV_BITMASK) != 0;

        // Control frames are never compressed
        if (Util.isControl(opCode) && compressionBitSet) {
            return false;
        }

        Transformation downstream = next;
        if (downstream == null) {
            return true;
        }

        // Clearing the bit leaves the value unchanged if it was not set
        return downstream.validateRsv(rsv & ~RSV_BITMASK, opCode);
    }


    /**
     * Builds the extension description for the settings this transformation
     * was created with. Context takeover parameters are only included when
     * takeover is disabled and window size parameters are only included when a
     * size was specified.
     *
     * @return the extension, named {@link #NAME}, with its parameters
     */
    @Override
    public Extension getExtensionResponse() {
        Extension response = new WsExtension(NAME);
        List<Extension.Parameter> negotiated = response.getParameters();

        // Server side settings
        if (!serverContextTakeover) {
            negotiated.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null));
        }
        boolean serverWindowSpecified = serverMaxWindowBits != -1;
        if (serverWindowSpecified) {
            String serverBits = String.valueOf(serverMaxWindowBits);
            negotiated.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS, serverBits));
        }

        // Client side settings
        if (!clientContextTakeover) {
            negotiated.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null));
        }
        boolean clientWindowSpecified = clientMaxWindowBits != -1;
        if (clientWindowSpecified) {
            String clientBits = String.valueOf(clientMaxWindowBits);
            negotiated.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS, clientBits));
        }

        return response;
    }


    /**
     * Adds a transformation to the end of the chain. If this transformation
     * already has a successor the request is forwarded to it.
     *
     * @param t the transformation to append
     */
    @Override
    public void setNext(Transformation t) {
        Transformation successor = next;
        if (successor != null) {
            // Not the end of the chain. Pass it along.
            successor.setNext(t);
            return;
        }
        next = t;
    }


    /**
     * Checks that the RSV bit used by this extension has not already been
     * claimed and then asks the next transformation, if any, to validate with
     * this extension's bit marked as in use.
     *
     * @param i the RSV bits that are already in use
     *
     * @return {@code false} if this extension's RSV bit is already in use,
     *             {@code true} if there is no next transformation, otherwise
     *             the result from the next transformation
     */
    @Override
    public boolean validateRsvBits(int i) {
        boolean alreadyClaimed = (i & RSV_BITMASK) != 0;
        if (alreadyClaimed) {
            return false;
        }

        Transformation downstream = next;
        return downstream == null || downstream.validateRsvBits(i | RSV_BITMASK);
    }


    /**
     * Compresses the data message parts in the provided list and passes the
     * result to the next transformation, if there is one. Control message
     * parts and the final part of a message that has had no payload at all are
     * passed on unchanged.
     * <p>
     * Compressed output is accumulated in the write buffer. A compressed part
     * is only created once that buffer is full or the final part of a message
     * has been processed, so a single uncompressed part may result in zero,
     * one or several compressed parts.
     *
     * @param uncompressedParts the message parts to process
     *
     * @return the parts to send, after processing by the next transformation
     *             if there is one
     *
     * @throws IOException           if the deflater has already been closed
     * @throws IllegalStateException if the deflater and write buffer end up in
     *                                   a combination of states that is not
     *                                   expected
     */
    @Override
    public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) throws IOException {
        List<MessagePart> allCompressedParts = new ArrayList<>();

        for (MessagePart uncompressedPart : uncompressedParts) {
            byte opCode = uncompressedPart.getOpCode();
            if (Util.isControl(opCode)) {
                // Control messages can appear in the middle of other messages
                // and must not be compressed. Pass it straight through.
                allCompressedParts.add(uncompressedPart);
                continue;
            }

            // The data that gets compressed is the content between position
            // and limit so that is what determines if the part is empty. A
            // buffer with position == limit > 0 has nothing to compress and
            // must not be treated as content.
            if (uncompressedPart.getPayload().hasRemaining()) {
                emptyMessage = false;
            }
            if (emptyMessage && uncompressedPart.isFin()) {
                // Zero length messages can't be compressed so pass the
                // final (empty) part straight through.
                allCompressedParts.add(uncompressedPart);
            } else {
                List<MessagePart> compressedParts = new ArrayList<>();
                ByteBuffer uncompressedPayload = uncompressedPart.getPayload();
                SendHandler uncompressedIntermediateHandler = uncompressedPart.getIntermediateHandler();

                if (uncompressedPayload.hasArray()) {
                    deflater.setInput(uncompressedPayload.array(),
                            uncompressedPayload.arrayOffset() + uncompressedPayload.position(),
                            uncompressedPayload.remaining());
                } else {
                    // No accessible backing array. Copy the content out.
                    byte[] bytes = new byte[uncompressedPayload.remaining()];
                    uncompressedPayload.get(bytes);
                    deflater.setInput(bytes, 0, bytes.length);
                }

                // Only force the deflater to output everything it is holding
                // at the end of the message
                int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
                boolean deflateRequired = true;

                while (deflateRequired) {
                    ByteBuffer compressedPayload = writeBuffer;

                    try {
                        int written = deflater.deflate(compressedPayload.array(),
                                compressedPayload.arrayOffset() + compressedPayload.position(),
                                compressedPayload.remaining(), flush);
                        compressedPayload.position(compressedPayload.position() + written);
                    } catch (NullPointerException e) {
                        // Presumably thrown by the deflater once end() has
                        // been called on it (see close())
                        throw new IOException(sm.getString("perMessageDeflate.alreadyClosed"), e);
                    }

                    if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) {
                        // This message part has been fully processed by the
                        // deflater and there is still room in the write
                        // buffer. No compressed part is created. Move on to
                        // the next message part.
                        break;
                    }

                    // If this point is reached, a new compressed message part
                    // will be created...
                    MessagePart compressedPart;

                    // .. and a new writeBuffer will be required.
                    writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);

                    // Flip the compressed payload ready for writing
                    compressedPayload.flip();

                    boolean fin = uncompressedPart.isFin();
                    boolean full = compressedPayload.limit() == compressedPayload.capacity();
                    boolean needsInput = deflater.needsInput();
                    long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry();

                    if (fin && !full && needsInput) {
                        // End of compressed message. Drop EOM bytes and output.
                        compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length);
                        compressedPart = new MessagePart(true, getRsv(uncompressedPart), opCode, compressedPayload,
                                uncompressedIntermediateHandler, uncompressedIntermediateHandler,
                                blockingWriteTimeoutExpiry);
                        deflateRequired = false;
                        startNewMessage();
                    } else if (full && !needsInput) {
                        // Write buffer full and input message part not fully
                        // read. Output and start new compressed part.
                        compressedPart = new MessagePart(false, getRsv(uncompressedPart), opCode, compressedPayload,
                                uncompressedIntermediateHandler, uncompressedIntermediateHandler,
                                blockingWriteTimeoutExpiry);
                    } else if (!fin && full && needsInput) {
                        // Write buffer full and this part has been fully read
                        // but the message is not complete. Output and get more
                        // data.
                        compressedPart = new MessagePart(false, getRsv(uncompressedPart), opCode, compressedPayload,
                                uncompressedIntermediateHandler, uncompressedIntermediateHandler,
                                blockingWriteTimeoutExpiry);
                        deflateRequired = false;
                    } else if (fin && full && needsInput) {
                        // Write buffer full. Input fully read. Deflater may be
                        // in one of four states:
                        // - output complete (just happened to align with end of
                        // buffer
                        // - in middle of EOM bytes
                        // - about to write EOM bytes
                        // - more data to write
                        // Ask for one byte more than the EOM bytes to tell
                        // these states apart.
                        int eomBufferWritten;
                        try {
                            eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH);
                        } catch (NullPointerException e) {
                            throw new IOException(sm.getString("perMessageDeflate.alreadyClosed"), e);
                        }
                        if (eomBufferWritten < EOM_BUFFER.length) {
                            // EOM has just been completed. Drop whatever part
                            // of the EOM bytes made it into the payload.
                            compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten);
                            compressedPart = new MessagePart(true, getRsv(uncompressedPart), opCode, compressedPayload,
                                    uncompressedIntermediateHandler, uncompressedIntermediateHandler,
                                    blockingWriteTimeoutExpiry);
                            deflateRequired = false;
                            startNewMessage();
                        } else {
                            // More data to write
                            // Copy bytes to new write buffer
                            writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten);
                            compressedPart = new MessagePart(false, getRsv(uncompressedPart), opCode, compressedPayload,
                                    uncompressedIntermediateHandler, uncompressedIntermediateHandler,
                                    blockingWriteTimeoutExpiry);
                        }
                    } else {
                        throw new IllegalStateException(sm.getString("perMessageDeflate.invalidState"));
                    }

                    // Add the newly created compressed part to the set of parts
                    // to pass on to the next transformation.
                    compressedParts.add(compressedPart);
                }

                // The end handler of the uncompressed part is attached to the
                // last compressed part created from it, if any were created
                SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler();
                int size = compressedParts.size();
                if (size > 0) {
                    compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler);
                }

                allCompressedParts.addAll(compressedParts);
            }
        }

        if (next == null) {
            return allCompressedParts;
        } else {
            return next.sendMessagePart(allCompressedParts);
        }
    }


    /**
     * Resets the per-message state used when sending. The deflater is reset as
     * well if this end of the connection must not retain its compression
     * context between messages.
     *
     * @throws IOException if the deflater has already been closed
     */
    private void startNewMessage() throws IOException {
        emptyMessage = true;
        firstCompressedFrameWritten = false;

        // The setting that applies to outgoing messages is the one for this
        // end of the connection
        boolean contextTakeover = isServer ? serverContextTakeover : clientContextTakeover;
        if (contextTakeover) {
            return;
        }

        try {
            deflater.reset();
        } catch (NullPointerException e) {
            throw new IOException(sm.getString("perMessageDeflate.alreadyClosed"), e);
        }
    }


    /**
     * Determines the RSV value for the next compressed part. This extension's
     * RSV bit is added for the first compressed part of each message only.
     *
     * @param uncompressedMessagePart the part that provides the base RSV value
     *
     * @return the RSV value to use for the compressed part
     */
    private int getRsv(MessagePart uncompressedMessagePart) {
        int baseRsv = uncompressedMessagePart.getRsv();
        if (firstCompressedFrameWritten) {
            return baseRsv;
        }
        // Remember that the bit has been set for the current message
        firstCompressedFrameWritten = true;
        return baseRsv + RSV_BITMASK;
    }


    /**
     * Closes the next transformation and then ends the inflater and deflater
     * used by this transformation. The inflater and deflater are ended even if
     * closing the next transformation fails.
     */
    @Override
    public void close() {
        try {
            // There will always be a next transformation
            next.close();
        } finally {
            // Must happen regardless of the outcome above, otherwise the
            // inflater and deflater would never be ended
            inflater.end();
            deflater.end();
        }
    }
}