// Http2AsyncParser.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.coyote.http2;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.TimeUnit;
import jakarta.servlet.http.WebConnection;
import org.apache.coyote.ProtocolException;
import org.apache.tomcat.util.net.SocketEvent;
import org.apache.tomcat.util.net.SocketWrapperBase;
import org.apache.tomcat.util.net.SocketWrapperBase.BlockingMode;
import org.apache.tomcat.util.net.SocketWrapperBase.CompletionCheck;
import org.apache.tomcat.util.net.SocketWrapperBase.CompletionHandlerCall;
import org.apache.tomcat.util.net.SocketWrapperBase.CompletionState;
class Http2AsyncParser extends Http2Parser {
// Socket abstraction used to issue the asynchronous reads and to push back
// (unRead) bytes that were read but not consumed by frame processing.
private final SocketWrapperBase<?> socketWrapper;
// Upgrade handler that is called back once the preface has been processed and
// that is asked to dispatch further read / error events.
private final Http2AsyncUpgradeHandler upgradeHandler;
// Last failure recorded by a completion handler. It is written from the
// completion callbacks and cleared and rethrown by handleAsyncException().
private volatile Throwable error = null;
/**
 * Creates a parser that reads HTTP/2 frames through the asynchronous read
 * API of the given socket wrapper.
 *
 * @param connectionId   identifier of the connection, passed to the parent parser
 * @param input          parser input, passed to the parent parser and queried
 *                       for the maximum frame size
 * @param output         parser output, passed to the parent parser
 * @param socketWrapper  socket wrapper the frames are read from
 * @param upgradeHandler upgrade handler notified by the completion handlers
 */
Http2AsyncParser(String connectionId, Input input, Output output, SocketWrapperBase<?> socketWrapper,
        Http2AsyncUpgradeHandler upgradeHandler) {
    super(connectionId, input, output);
    this.upgradeHandler = upgradeHandler;
    this.socketWrapper = socketWrapper;
    // Ask the socket buffer handler to expand using the maximum frame size
    // reported by the input.
    this.socketWrapper.getSocketBufferHandler().expand(input.getMaxFrameSize());
}
/**
 * Starts a non-blocking read of the client connection preface followed by
 * the first frame. Three buffers are handed to the socket wrapper, in
 * order: the preface bytes, a 9 byte frame header and a payload buffer
 * sized to the maximum frame size. A {@link PrefaceCompletionHandler}
 * decides when the read is complete and processes the result.
 *
 * @param webConnection connection passed on to the upgrade handler callback
 * @param stream        stream passed on to the upgrade handler callback
 */
@Override
void readConnectionPreface(WebConnection webConnection, Stream stream) throws Http2Exception {
    // The handler validates the raw array, the socket fills it through the wrapping buffer
    byte[] prefaceBytes = new byte[CLIENT_PREFACE_START.length];
    ByteBuffer prefaceBuffer = ByteBuffer.wrap(prefaceBytes);
    ByteBuffer headerBuffer = ByteBuffer.allocate(9);
    ByteBuffer payloadBuffer = ByteBuffer.allocate(input.getMaxFrameSize());

    PrefaceCompletionHandler prefaceHandler = new PrefaceCompletionHandler(webConnection, stream, prefaceBytes,
            prefaceBuffer, headerBuffer, payloadBuffer);

    long readTimeout = socketWrapper.getReadTimeout();
    // The same object acts as completion check and as completion handler
    socketWrapper.read(BlockingMode.NON_BLOCK, readTimeout, TimeUnit.MILLISECONDS, null, prefaceHandler,
            prefaceHandler, prefaceBuffer, headerBuffer, payloadBuffer);
}
private class PrefaceCompletionHandler extends FrameCompletionHandler {
private final WebConnection webConnection;
private final Stream stream;
private final byte[] prefaceData;
private volatile boolean prefaceValidated = false;
private PrefaceCompletionHandler(WebConnection webConnection, Stream stream, byte[] prefaceData,
ByteBuffer... buffers) {
super(FrameType.SETTINGS, buffers);
this.webConnection = webConnection;
this.stream = stream;
this.prefaceData = prefaceData;
}
@Override
public CompletionHandlerCall callHandler(CompletionState state, ByteBuffer[] buffers, int offset, int length) {
if (offset != 0 || length != 3) {
try {
throw new IllegalArgumentException(sm.getString("http2Parser.invalidBuffers"));
} catch (IllegalArgumentException e) {
error = e;
return CompletionHandlerCall.DONE;
}
}
if (!prefaceValidated) {
if (buffers[0].hasRemaining()) {
// The preface must be fully read before being validated
return CompletionHandlerCall.CONTINUE;
}
// Validate preface content
for (int i = 0; i < CLIENT_PREFACE_START.length; i++) {
if (CLIENT_PREFACE_START[i] != prefaceData[i]) {
error = new ProtocolException(sm.getString("http2Parser.preface.invalid"));
return CompletionHandlerCall.DONE;
}
}
prefaceValidated = true;
}
return validate(state, buffers[1], buffers[2]);
}
@Override
public void completed(Long result, Void attachment) {
if (streamException || error == null) {
ByteBuffer payload = buffers[2];
payload.flip();
try {
if (streamException) {
swallowPayload(streamId, frameTypeId, payloadSize, false, payload);
} else {
readSettingsFrame(flags, payloadSize, payload);
}
} catch (RuntimeException | IOException | Http2Exception e) {
error = e;
}
// Any extra frame is not processed yet, so put back any leftover data
if (payload.hasRemaining()) {
socketWrapper.unRead(payload);
}
// Finish processing the connection
upgradeHandler.processConnectionCallback(webConnection, stream);
} else {
upgradeHandler
.closeConnection(new ConnectionException(error.getMessage(), Http2Error.PROTOCOL_ERROR, error));
}
// Continue reading frames
upgradeHandler.upgradeDispatch(SocketEvent.OPEN_READ);
}
}
/**
 * Issues a read for one frame (9 byte header plus payload) using a
 * {@link FrameCompletionHandler}. Any error recorded by an earlier
 * asynchronous read is thrown first.
 *
 * @param block    {@code true} for a blocking read using the socket read
 *                 timeout, {@code false} for a non-blocking read with a
 *                 timeout of zero
 * @param expected frame type passed to the frame validation
 *
 * @return {@code true} if the read ended in the {@code ERROR} or
 *         {@code INLINE} state, {@code false} otherwise
 *
 * @throws IOException    if an I/O error was recorded by a completion handler
 * @throws Http2Exception if an HTTP/2 error was recorded by a completion handler
 */
@Override
protected boolean readFrame(boolean block, FrameType expected) throws IOException, Http2Exception {
    // Report a failure left behind by a previous read before starting a new one
    handleAsyncException();

    ByteBuffer frameHeader = ByteBuffer.allocate(9);
    ByteBuffer payloadBuffer = ByteBuffer.allocate(input.getMaxFrameSize());
    FrameCompletionHandler frameHandler = new FrameCompletionHandler(expected, frameHeader, payloadBuffer);

    BlockingMode mode;
    long timeout;
    if (block) {
        mode = BlockingMode.BLOCK;
        timeout = socketWrapper.getReadTimeout();
    } else {
        mode = BlockingMode.NON_BLOCK;
        timeout = 0;
    }

    CompletionState outcome = socketWrapper.read(mode, timeout, TimeUnit.MILLISECONDS, null, frameHandler,
            frameHandler, frameHeader, payloadBuffer);

    boolean finished = outcome == CompletionState.ERROR || outcome == CompletionState.INLINE;
    if (finished) {
        // The handler has already run, so any error it recorded is thrown now
        handleAsyncException();
    }
    return finished;
}
/**
 * Throws the error recorded by a completion handler, if there is one, and
 * clears it. {@link Http2Exception}, {@link IOException} and
 * {@link RuntimeException} instances are thrown as they are; any other
 * throwable is wrapped in a {@link RuntimeException}.
 *
 * @throws IOException    if the recorded error is an I/O error
 * @throws Http2Exception if the recorded error is an HTTP/2 error
 */
private void handleAsyncException() throws IOException, Http2Exception {
    if (error == null) {
        return;
    }

    // Take the error and reset the field so that it is only reported once
    Throwable pending = this.error;
    this.error = null;

    if (pending instanceof Http2Exception) {
        throw (Http2Exception) pending;
    }
    if (pending instanceof IOException) {
        throw (IOException) pending;
    }
    if (pending instanceof RuntimeException) {
        throw (RuntimeException) pending;
    }
    throw new RuntimeException(pending);
}
private class FrameCompletionHandler implements CompletionCheck, CompletionHandler<Long,Void> {
private final FrameType expected;
protected final ByteBuffer[] buffers;
private volatile boolean parsedFrameHeader = false;
private volatile boolean validated = false;
private volatile CompletionState state = null;
protected volatile int payloadSize;
protected volatile int frameTypeId;
protected volatile FrameType frameType;
protected volatile int flags;
protected volatile int streamId;
protected volatile boolean streamException = false;
private FrameCompletionHandler(FrameType expected, ByteBuffer... buffers) {
this.expected = expected;
this.buffers = buffers;
}
@Override
public CompletionHandlerCall callHandler(CompletionState state, ByteBuffer[] buffers, int offset, int length) {
if (offset != 0 || length != 2) {
try {
throw new IllegalArgumentException(sm.getString("http2Parser.invalidBuffers"));
} catch (IllegalArgumentException e) {
error = e;
return CompletionHandlerCall.DONE;
}
}
return validate(state, buffers[0], buffers[1]);
}
protected CompletionHandlerCall validate(CompletionState state, ByteBuffer frameHeaderBuffer,
ByteBuffer payload) {
if (!parsedFrameHeader) {
// The first buffer should be 9 bytes long
if (frameHeaderBuffer.position() < 9) {
return CompletionHandlerCall.CONTINUE;
}
parsedFrameHeader = true;
payloadSize = ByteUtil.getThreeBytes(frameHeaderBuffer, 0);
frameTypeId = ByteUtil.getOneByte(frameHeaderBuffer, 3);
frameType = FrameType.valueOf(frameTypeId);
flags = ByteUtil.getOneByte(frameHeaderBuffer, 4);
streamId = ByteUtil.get31Bits(frameHeaderBuffer, 5);
}
this.state = state;
if (!validated) {
validated = true;
try {
validateFrame(expected, frameType, streamId, flags, payloadSize);
} catch (StreamException e) {
error = e;
streamException = true;
} catch (Http2Exception e) {
error = e;
// The problem will be handled later, consider the frame read is done
return CompletionHandlerCall.DONE;
}
}
if (payload.position() < payloadSize) {
return CompletionHandlerCall.CONTINUE;
}
return CompletionHandlerCall.DONE;
}
@Override
public void completed(Long result, Void attachment) {
if (streamException || error == null) {
ByteBuffer payload = buffers[1];
payload.flip();
try {
boolean continueParsing;
do {
continueParsing = false;
if (streamException) {
swallowPayload(streamId, frameTypeId, payloadSize, false, payload);
} else {
switch (frameType) {
case DATA:
readDataFrame(streamId, flags, payloadSize, payload);
break;
case HEADERS:
readHeadersFrame(streamId, flags, payloadSize, payload);
break;
case PRIORITY:
readPriorityFrame(streamId, payload);
break;
case RST:
readRstFrame(streamId, payload);
break;
case SETTINGS:
readSettingsFrame(flags, payloadSize, payload);
break;
case PUSH_PROMISE:
readPushPromiseFrame(streamId, flags, payloadSize, payload);
break;
case PING:
readPingFrame(flags, payload);
break;
case GOAWAY:
readGoawayFrame(payloadSize, payload);
break;
case WINDOW_UPDATE:
readWindowUpdateFrame(streamId, payload);
break;
case CONTINUATION:
readContinuationFrame(streamId, flags, payloadSize, payload);
break;
case PRIORITY_UPDATE:
readPriorityUpdateFrame(payloadSize, payload);
break;
case UNKNOWN:
readUnknownFrame(streamId, frameTypeId, flags, payloadSize, payload);
}
}
if (!upgradeHandler.isOverheadLimitExceeded()) {
// See if there is a new 9 byte header and continue parsing if possible
if (payload.remaining() >= 9) {
int position = payload.position();
payloadSize = ByteUtil.getThreeBytes(payload, position);
frameTypeId = ByteUtil.getOneByte(payload, position + 3);
frameType = FrameType.valueOf(frameTypeId);
flags = ByteUtil.getOneByte(payload, position + 4);
streamId = ByteUtil.get31Bits(payload, position + 5);
streamException = false;
if (payload.remaining() - 9 >= payloadSize) {
continueParsing = true;
// Now go over frame header
payload.position(payload.position() + 9);
try {
validateFrame(null, frameType, streamId, flags, payloadSize);
} catch (StreamException e) {
error = e;
streamException = true;
} catch (Http2Exception e) {
error = e;
continueParsing = false;
}
}
}
}
} while (continueParsing);
} catch (RuntimeException | IOException | Http2Exception e) {
error = e;
} finally {
if (payload.hasRemaining()) {
socketWrapper.unRead(payload);
}
}
}
if (state == CompletionState.DONE) {
// The call was not completed inline, so must start reading new frames
// or process the stream exception
upgradeHandler.upgradeDispatch(SocketEvent.OPEN_READ);
}
}
@Override
public void failed(Throwable e, Void attachment) {
// Always a fatal IO error
error = e;
if (log.isDebugEnabled()) {
log.debug(sm.getString("http2Parser.error", connectionId, Integer.valueOf(streamId), frameType), e);
}
if (state == null || state == CompletionState.DONE) {
upgradeHandler.upgradeDispatch(SocketEvent.ERROR);
}
}
}
}