blob: 869775e2ac4b4177da56385c89bac94a56397ecb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.qpid.server.transport;
import java.io.IOException;
import java.security.Principal;
import java.security.cert.Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.qpid.server.bytebuffer.QpidByteBuffer;
import org.apache.qpid.server.model.port.AmqpPort;
import org.apache.qpid.server.transport.network.security.ssl.SSLUtil;
import org.apache.qpid.server.util.ConnectionScopedRuntimeException;
import org.apache.qpid.server.util.ServerScopedRuntimeException;
public class NonBlockingConnectionTLSDelegate implements NonBlockingConnectionDelegate
{
private static final Logger LOGGER = LoggerFactory.getLogger(NonBlockingConnectionTLSDelegate.class);
private final SSLEngine _sslEngine;
private final NonBlockingConnection _parent;
private final int _networkBufferSize;
private SSLEngineResult _status;
private final List<QpidByteBuffer> _encryptedOutput = new ArrayList<>();
private Principal _principal;
private Certificate _peerCertificate;
private boolean _principalChecked;
private volatile boolean _hostChecked;
private QpidByteBuffer _netInputBuffer;
private QpidByteBuffer _netOutputBuffer;
private QpidByteBuffer _applicationBuffer;
private final boolean _ignoreInvalidSni;
public NonBlockingConnectionTLSDelegate(NonBlockingConnection parent, AmqpPort port)
{
_parent = parent;
_sslEngine = createSSLEngine(port);
_networkBufferSize = port.getNetworkBufferSize();
final int tlsPacketBufferSize = _sslEngine.getSession().getPacketBufferSize();
if (tlsPacketBufferSize > _networkBufferSize)
{
throw new ServerScopedRuntimeException("TLS implementation packet buffer size (" + tlsPacketBufferSize
+ ") is greater then broker network buffer size (" + _networkBufferSize + ")");
}
_netInputBuffer = QpidByteBuffer.allocateDirect(_networkBufferSize);
_applicationBuffer = QpidByteBuffer.allocateDirect(_networkBufferSize);
_netOutputBuffer = QpidByteBuffer.allocateDirect(_networkBufferSize);
_ignoreInvalidSni = port.isIgnoreInvalidSni();
}
@Override
public boolean readyForRead()
{
return _sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_WRAP;
}
@Override
public boolean processData() throws IOException
{
if(!_hostChecked)
{
try (QpidByteBuffer buffer = _netInputBuffer.duplicate())
{
buffer.flip();
if (SSLUtil.isSufficientToDetermineClientSNIHost(buffer))
{
final SNIHostName hostName = getSNIHostName(buffer);
if (hostName != null)
{
_parent.setSelectedHost(hostName.getAsciiName());
SSLParameters sslParameters = _sslEngine.getSSLParameters();
sslParameters.setServerNames(Collections.singletonList(hostName));
_sslEngine.setSSLParameters(sslParameters);
}
_hostChecked = true;
}
else
{
return false;
}
}
}
_netInputBuffer.flip();
boolean readData = false;
boolean tasksRun;
int oldNetBufferPos;
do
{
int oldAppBufPos = _applicationBuffer.position();
oldNetBufferPos = _netInputBuffer.position();
_status = QpidByteBuffer.decryptSSL(_sslEngine, _netInputBuffer, _applicationBuffer);
if (_status.getStatus() == SSLEngineResult.Status.CLOSED)
{
int remaining = _netInputBuffer.remaining();
_netInputBuffer.position(_netInputBuffer.limit());
// We'd usually expect no more bytes to be sent following a close_notify
LOGGER.debug("SSLEngine closed, discarded {} byte(s)", remaining);
}
tasksRun = runSSLEngineTasks(_status);
_applicationBuffer.flip();
if(_applicationBuffer.position() > oldAppBufPos)
{
readData = true;
}
_parent.processAmqpData(_applicationBuffer);
restoreApplicationBufferForWrite();
}
while((_netInputBuffer.hasRemaining() && (_netInputBuffer.position()>oldNetBufferPos)) || tasksRun);
if(_netInputBuffer.hasRemaining())
{
_netInputBuffer.compact();
}
else
{
_netInputBuffer.clear();
}
return readData;
}
private SNIHostName getSNIHostName(final QpidByteBuffer buffer)
{
try
{
final String name = SSLUtil.getServerNameFromTLSClientHello(buffer);
if (name != null)
{
return SSLUtil.createSNIHostName(name);
}
}
catch (ConnectionScopedRuntimeException e)
{
if (!_ignoreInvalidSni)
{
throw e;
}
}
return null;
}
@Override
public WriteResult doWrite(Collection<QpidByteBuffer> buffers) throws IOException
{
final int bufCount = buffers.size();
int totalConsumed = wrapBufferArray(buffers);
boolean bufsSent = true;
final Iterator<QpidByteBuffer> itr = buffers.iterator();
int bufIndex = 0;
while(itr.hasNext() && bufsSent && bufIndex++ < bufCount)
{
QpidByteBuffer buf = itr.next();
bufsSent = !buf.hasRemaining();
}
if(!_encryptedOutput.isEmpty())
{
_parent.writeToTransport(_encryptedOutput);
ListIterator<QpidByteBuffer> iter = _encryptedOutput.listIterator();
while (iter.hasNext())
{
QpidByteBuffer buf = iter.next();
if (!buf.hasRemaining())
{
buf.dispose();
iter.remove();
}
else
{
break;
}
}
}
return new WriteResult(bufsSent && _encryptedOutput.isEmpty(), totalConsumed);
}
protected void restoreApplicationBufferForWrite()
{
try (QpidByteBuffer oldApplicationBuffer = _applicationBuffer)
{
int unprocessedDataLength = _applicationBuffer.remaining();
_applicationBuffer.limit(_applicationBuffer.capacity());
_applicationBuffer = _applicationBuffer.slice();
_applicationBuffer.limit(unprocessedDataLength);
}
if (_applicationBuffer.limit() <= _applicationBuffer.capacity() - _sslEngine.getSession().getApplicationBufferSize())
{
_applicationBuffer.position(_applicationBuffer.limit());
_applicationBuffer.limit(_applicationBuffer.capacity());
}
else
{
try (QpidByteBuffer currentBuffer = _applicationBuffer)
{
int newBufSize;
if (currentBuffer.capacity() < _networkBufferSize)
{
newBufSize = _networkBufferSize;
}
else
{
newBufSize = currentBuffer.capacity() + _networkBufferSize;
_parent.reportUnexpectedByteBufferSizeUsage();
}
_applicationBuffer = QpidByteBuffer.allocateDirect(newBufSize);
_applicationBuffer.put(currentBuffer);
}
}
}
private int wrapBufferArray(Collection<QpidByteBuffer> buffers) throws SSLException
{
int totalConsumed = 0;
boolean encrypted;
do
{
if(_sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_UNWRAP)
{
if(_netOutputBuffer.remaining() < _sslEngine.getSession().getPacketBufferSize())
{
if(_netOutputBuffer.position() != 0)
{
_netOutputBuffer.flip();
_encryptedOutput.add(_netOutputBuffer);
}
else
{
_netOutputBuffer.dispose();
}
_netOutputBuffer = QpidByteBuffer.allocateDirect(_networkBufferSize);
}
_status = QpidByteBuffer.encryptSSL(_sslEngine, buffers, _netOutputBuffer);
if(_status.getStatus() == SSLEngineResult.Status.CLOSED)
{
throw new SSLException(String.format("SSLEngine.wrap operation could not be completed because"
+ " it was already closed (status %s, handshake status %s)",
_status.getStatus(), _status.getHandshakeStatus()));
}
// QPID-8489: workaround for JDK 8 bug to avoid tight looping for half closed connections
// Additional info: https://bugs.openjdk.java.net/browse/JDK-8240071,
// http://mail.openjdk.java.net/pipermail/security-dev/2019-January/019142.html
if(_status.bytesProduced() < 1 && _status.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP
&& !_sslEngine.isOutboundDone() && _sslEngine.isInboundDone())
{
throw new SSLException(String.format("SSLEngine.wrap produced 0 bytes (status %s, handshake status %s)",
_status.getStatus(), _status.getHandshakeStatus()));
}
encrypted = _status.bytesProduced() > 0;
totalConsumed += _status.bytesConsumed();
runSSLEngineTasks(_status);
if(encrypted && _netOutputBuffer.remaining() < _sslEngine.getSession().getPacketBufferSize())
{
_netOutputBuffer.flip();
_encryptedOutput.add(_netOutputBuffer);
_netOutputBuffer = QpidByteBuffer.allocateDirect(_networkBufferSize);
}
}
else
{
encrypted = false;
}
}
while(encrypted && _sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_UNWRAP);
if(_netOutputBuffer.position() != 0)
{
final QpidByteBuffer outputBuffer = _netOutputBuffer;
_netOutputBuffer = _netOutputBuffer.slice();
outputBuffer.flip();
_encryptedOutput.add(outputBuffer);
}
return totalConsumed;
}
private boolean runSSLEngineTasks(final SSLEngineResult status)
{
if(status.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK)
{
Runnable task;
while((task = _sslEngine.getDelegatedTask()) != null)
{
task.run();
}
return true;
}
return false;
}
@Override
public Principal getPeerPrincipal()
{
checkPeerPrincipal();
return _principal;
}
@Override
public Certificate getPeerCertificate()
{
checkPeerPrincipal();
return _peerCertificate;
}
@Override
public boolean needsWork()
{
return _sslEngine.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
}
private synchronized void checkPeerPrincipal()
{
if (!_principalChecked)
{
try
{
_principal = _sslEngine.getSession().getPeerPrincipal();
final Certificate[] peerCertificates =
_sslEngine.getSession().getPeerCertificates();
if (peerCertificates != null && peerCertificates.length > 0)
{
_peerCertificate = peerCertificates[0];
}
}
catch (SSLPeerUnverifiedException e)
{
_principal = null;
_peerCertificate = null;
}
_principalChecked = true;
}
}
private SSLEngine createSSLEngine(AmqpPort<?> port)
{
SSLEngine sslEngine = port.getSSLContext().createSSLEngine();
sslEngine.setUseClientMode(false);
SSLUtil.updateEnabledTlsProtocols(sslEngine, port.getTlsProtocolAllowList(), port.getTlsProtocolDenyList());
SSLUtil.updateEnabledCipherSuites(sslEngine, port.getTlsCipherSuiteAllowList(), port.getTlsCipherSuiteDenyList());
if(port.getTlsCipherSuiteAllowList() != null && !port.getTlsCipherSuiteAllowList().isEmpty())
{
SSLParameters sslParameters = sslEngine.getSSLParameters();
sslParameters.setUseCipherSuitesOrder(true);
sslEngine.setSSLParameters(sslParameters);
}
if(port.getNeedClientAuth())
{
sslEngine.setNeedClientAuth(true);
}
else if(port.getWantClientAuth())
{
sslEngine.setWantClientAuth(true);
}
return sslEngine;
}
@Override
public QpidByteBuffer getNetInputBuffer()
{
return _netInputBuffer;
}
@Override
public void shutdownInput()
{
if (_netInputBuffer != null)
{
_netInputBuffer.dispose();
_netInputBuffer = null;
}
if (_applicationBuffer != null)
{
_applicationBuffer.dispose();
_applicationBuffer = null;
}
}
@Override
public void shutdownOutput()
{
if (_netOutputBuffer != null)
{
_netOutputBuffer.dispose();
_netOutputBuffer = null;
}
try
{
_sslEngine.closeOutbound();
_sslEngine.closeInbound();
}
catch (SSLException e)
{
LOGGER.debug("Exception when closing SSLEngine", e);
}
}
@Override
public String getTransportInfo()
{
SSLSession session = _sslEngine.getSession();
return session.getProtocol() + " ; " + session.getCipherSuite() ;
}
}