blob: d0e2fc8e0993ae91afa95cbe103c3cea1413a55e [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.transport.nio;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.security.cert.X509Certificate;
import java.util.concurrent.CountDownLatch;
import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import org.apache.activemq.command.ConnectionInfo;
import org.apache.activemq.openwire.OpenWireFormat;
import org.apache.activemq.thread.TaskRunnerFactory;
import org.apache.activemq.util.IOExceptionSupport;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.wireformat.WireFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class NIOSSLTransport extends NIOTransport {
// Class logger.
private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class);
// When true, initializeStreams() calls setNeedClientAuth(true) on an engine it creates.
protected boolean needClientAuth;
// When true, initializeStreams() calls setWantClientAuth(true) on an engine it creates.
protected boolean wantClientAuth;
// Optional cipher suites applied to an engine created by initializeStreams(); null leaves the engine's setting untouched.
protected String[] enabledCipherSuites;
// Optional protocols applied to an engine created by initializeStreams(); null leaves the engine's setting untouched.
protected String[] enabledProtocols;
// When true, initializeStreams() sets the "HTTPS" endpoint identification algorithm on an engine it creates.
protected boolean verifyHostName = false;
// Context used to create the SSLEngine; initializeStreams() falls back to SSLContext.getDefault() when this is null.
protected SSLContext sslContext;
// Engine either supplied through the constructor or created in initializeStreams().
protected SSLEngine sslEngine;
// Session obtained from the engine; refreshed by finishHandshake() once the handshake completes.
protected SSLSession sslSession;
// Set to true by doHandshake() and cleared by finishHandshake().
protected volatile boolean handshakeInProgress = false;
// Status of the most recent unwrap performed by secureRead().
protected SSLEngineResult.Status status = null;
// Handshake status recorded by initializeStreams() and updated after each unwrap in secureRead().
protected SSLEngineResult.HandshakeStatus handshakeStatus = null;
// Created in doStart(), used by doInit() to run the initial read, shut down in doStop().
protected TaskRunnerFactory taskRunnerFactory;
/**
 * Creates a transport from a socket factory and the remote/local locations.
 * All arguments are handed unchanged to the {@link NIOTransport} constructor.
 *
 * @param wireFormat     wire format passed to the superclass
 * @param socketFactory  socket factory passed to the superclass
 * @param remoteLocation remote location passed to the superclass
 * @param localLocation  local location passed to the superclass
 * @throws UnknownHostException if raised by the superclass constructor
 * @throws IOException          if raised by the superclass constructor
 */
public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation)
        throws UnknownHostException, IOException {
    super(wireFormat, socketFactory, remoteLocation, localLocation);
}
/**
 * Creates a transport around an existing socket, optionally with an SSL engine
 * and input buffer that were prepared by the caller.
 *
 * @param wireFormat  wire format passed to the superclass
 * @param socket      socket passed to the superclass
 * @param engine      SSL engine to use, or {@code null} to have one created in
 *                    {@link #initializeStreams()}
 * @param initBuffer  initial buffer passed to the superclass
 * @param inputBuffer buffer stored as this transport's network input buffer
 * @throws IOException if raised by the superclass constructor
 */
public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer,
        ByteBuffer inputBuffer) throws IOException {
    super(wireFormat, socket, initBuffer);
    this.inputBuffer = inputBuffer;
    this.sslEngine = engine;
    // The session can only be captured when an engine was actually supplied.
    if (this.sslEngine != null) {
        this.sslSession = this.sslEngine.getSession();
    }
}
/**
 * Sets the SSL context used by {@link #initializeStreams()} to create the SSL engine.
 *
 * @param sslContext the context to use; when {@code null}, stream initialization
 *                   falls back to {@code SSLContext.getDefault()}
 */
public void setSslContext(SSLContext sslContext) {
    this.sslContext = sslContext;
}
// Set to true by initializeStreams() when an SSLEngine is already present at that point;
// in that case initializeStreams() neither creates/configures an engine nor runs the handshake.
volatile boolean hasSslEngine = false;
/**
 * Prepares the non-blocking channel, the SSL engine and the output streams, runs the SSL
 * handshake when the engine was not supplied by the caller, registers the channel with the
 * {@link SelectorManager} and finally triggers {@link #doInit()}.
 *
 * @throws IOException wrapping any failure raised during setup; the output stream and the
 *                     superclass streams are closed (best effort) before it is thrown
 */
@Override
protected void initializeStreams() throws IOException {
    if (sslEngine != null) {
        hasSslEngine = true;
    }
    NIOOutputStream outputStream = null;
    try {
        channel = socket.getChannel();
        channel.configureBlocking(false);

        if (sslContext == null) {
            sslContext = SSLContext.getDefault();
        }

        String remoteHost = null;
        int remotePort = -1;
        try {
            URI remoteAddress = new URI(this.getRemoteAddress());
            remoteHost = remoteAddress.getHost();
            remotePort = remoteAddress.getPort();
        } catch (Exception e) {
            // Best effort only: without host/port the engine is created without peer hints below.
            if (LOG.isTraceEnabled()) {
                LOG.trace("Unable to determine remote host and port for the SSL engine.", e);
            }
        }

        // initialize engine, the initial sslSession we get will need to be
        // updated once the ssl handshake process is completed.
        if (!hasSslEngine) {
            if (remoteHost != null && remotePort != -1) {
                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
            } else {
                sslEngine = sslContext.createSSLEngine();
            }

            if (verifyHostName) {
                SSLParameters sslParams = new SSLParameters();
                sslParams.setEndpointIdentificationAlgorithm("HTTPS");
                sslEngine.setSSLParameters(sslParams);
            }

            sslEngine.setUseClientMode(false);
            if (enabledCipherSuites != null) {
                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
            }
            if (enabledProtocols != null) {
                sslEngine.setEnabledProtocols(enabledProtocols);
            }
            if (wantClientAuth) {
                sslEngine.setWantClientAuth(wantClientAuth);
            }
            if (needClientAuth) {
                sslEngine.setNeedClientAuth(needClientAuth);
            }

            sslSession = sslEngine.getSession();
            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
            inputBuffer.clear();
        }

        outputStream = new NIOOutputStream(channel);
        outputStream.setEngine(sslEngine);
        this.dataOut = new DataOutputStream(outputStream);
        this.buffOut = outputStream;

        // If the sslEngine was not passed in, then handshake
        if (!hasSslEngine) {
            sslEngine.beginHandshake();
        }
        handshakeStatus = sslEngine.getHandshakeStatus();
        if (!hasSslEngine) {
            doHandshake();
        }

        selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() {
            @Override
            public void onSelect(SelectorSelection selection) {
                try {
                    // Reads must not start before the initial read scheduled by doInit() has run.
                    initialized.await();
                } catch (InterruptedException error) {
                    onException(IOExceptionSupport.create(error));
                    // Preserve the interrupt for the owning thread and do not touch the
                    // channel: reading on an interrupted thread would close it.
                    Thread.currentThread().interrupt();
                    return;
                }
                serviceRead();
            }

            @Override
            public void onError(SelectorSelection selection, Throwable error) {
                if (error instanceof IOException) {
                    onException((IOException) error);
                } else {
                    onException(IOExceptionSupport.create(error));
                }
            }
        });
        doInit();
    } catch (Exception e) {
        // Release each resource independently so that a failure while closing the output
        // stream cannot prevent the remaining streams from being closed.
        if (outputStream != null) {
            try {
                outputStream.close();
            } catch (Exception ignored) {
                // best effort; the original failure is reported below
            }
        }
        try {
            super.closeStreams();
        } catch (Exception ignored) {
            // best effort; the original failure is reported below
        }
        throw new IOException(e);
    }
}
// Latch awaited by the selector listener's onSelect() before it services reads;
// counted down by the task started in doInit() and by doStop().
final protected CountDownLatch initialized = new CountDownLatch(1);
/**
 * Schedules the first read on the task runner and then releases the {@code initialized}
 * latch so that selector-driven reads may proceed.
 *
 * @throws Exception if the task cannot be handed to the task runner
 */
protected void doInit() throws Exception {
    taskRunnerFactory.execute(new Runnable() {
        @Override
        public void run() {
            // Need to start in new thread to let startup finish first.
            // We can trigger a read because we know the channel is ready since the SSL handshake
            // already happened.
            try {
                serviceRead();
            } finally {
                // Always release the latch, even if the initial read terminates abnormally,
                // so that selector threads waiting in onSelect() are not left blocked.
                initialized.countDown();
            }
        }
    });
}
// Only used for the auto transport to abort the OpenWire init method early if already initialized.
// Set to true by doOpenWireInit() once the init buffer has been processed.
// NOTE(review): the name is misspelt ("Inititialized"); kept as-is since code outside this file may reference it.
boolean openWireInititialized = false;
/**
 * Replays the bytes captured in {@code initBuffer} through {@link #processCommand(ByteBuffer)}.
 * Does nothing unless an init buffer exists, it has not been replayed yet and the wire
 * format is OpenWire.
 *
 * @throws Exception if processing the buffered bytes fails
 */
protected void doOpenWireInit() throws Exception {
    // Do this later to let wire format negotiation happen
    final boolean replayNeeded = initBuffer != null && !openWireInititialized
            && this.wireFormat instanceof OpenWireFormat;
    if (!replayNeeded) {
        return;
    }

    initBuffer.buffer.flip();
    if (!initBuffer.buffer.hasRemaining()) {
        return;
    }

    nextFrameSize = -1;
    receiveCounter += initBuffer.readSize;
    // Two passes on purpose: with nextFrameSize == -1 the first call consumes the frame-size
    // prefix, the second call consumes the frame body.
    processCommand(initBuffer.buffer);
    processCommand(initBuffer.buffer);
    initBuffer.buffer.clear();
    openWireInititialized = true;
}
/**
 * Completes handshake bookkeeping: clears the in-progress flag, resets frame parsing and
 * refreshes the cached SSL session. Does nothing when no handshake is in progress.
 *
 * @throws Exception declared for subclasses; the visible code throws none explicitly
 */
protected void finishHandshake() throws Exception {
    if (!handshakeInProgress) {
        return;
    }
    handshakeInProgress = false;
    nextFrameSize = -1;

    // Once handshake completes we need to ask for the now real sslSession
    // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
    // cipher suite.
    sslSession = sslEngine.getSession();
}
/**
 * Reads and decrypts whatever is available on the channel and feeds the plain bytes to
 * {@link #processCommand(ByteBuffer)} until no more data is available, the peer closes the
 * connection or the transport is stopped. Any failure is reported through
 * {@code onException}.
 */
@Override
public void serviceRead() {
    try {
        if (handshakeInProgress) {
            doHandshake();
        }
        doOpenWireInit();

        final ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
        // Start "empty" so that the first loop iteration performs a read.
        plain.position(plain.limit());

        // Stop as soon as the transport has been stopped.
        while (!this.isStopped()) {
            if (!plain.hasRemaining()) {
                final int readCount = secureRead(plain);

                // Nothing more to read right now.
                if (readCount == 0) {
                    break;
                }

                // channel is closed, cleanup
                if (readCount == -1) {
                    onException(new EOFException());
                    selection.close();
                    break;
                }

                receiveCounter += readCount;
            }

            final boolean unwrapSucceeded = status == SSLEngineResult.Status.OK;
            final boolean engineWantsMoreInput = handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
            if (unwrapSucceeded && !engineWantsMoreInput) {
                processCommand(plain);
            }
        }
    } catch (IOException e) {
        onException(e);
    } catch (Throwable e) {
        onException(IOExceptionSupport.create(e));
    }
}
/**
 * Consumes decrypted bytes and assembles length-prefixed frames. Each call performs one step:
 * while {@code nextFrameSize} is {@code -1} it reads (possibly across several calls) the
 * 4-byte frame size and allocates the frame buffer; otherwise it copies body bytes into the
 * frame buffer and, once the frame is complete, unmarshals it and passes the command to
 * {@link #doConsume(Object)}.
 *
 * @param plain decrypted bytes, positioned at the next unread byte
 * @throws IOException if the announced frame size is negative, too large to buffer, or exceeds
 *                     the OpenWire maximum frame size
 * @throws Exception   if unmarshalling or consuming the command fails
 */
protected void processCommand(ByteBuffer plain) throws Exception {
    // Are we waiting for the next Command or are we building on the current one
    if (nextFrameSize == -1) {
        // The frame-size prefix is an int: 4 bytes (Integer.SIZE alone is a bit count).
        final int frameSizeBytes = Integer.SIZE / Byte.SIZE;

        // We can get small packets that don't give us enough for the frame size
        // so allocate enough for the initial size value and accumulate it.
        if (plain.remaining() < frameSizeBytes) {
            if (currentBuffer == null) {
                currentBuffer = ByteBuffer.allocate(frameSizeBytes);
            }

            // Go until we fill the integer sized current buffer.
            while (currentBuffer.hasRemaining() && plain.hasRemaining()) {
                currentBuffer.put(plain.get());
            }

            // Didn't we get enough yet to figure out next frame size.
            if (currentBuffer.hasRemaining()) {
                return;
            } else {
                currentBuffer.flip();
                nextFrameSize = currentBuffer.getInt();
            }
        } else {
            // Either we are completing a previous read of the next frame size or its
            // fully contained in plain already.
            if (currentBuffer != null) {
                // Finish the frame size integer read and get from the current buffer.
                while (currentBuffer.hasRemaining()) {
                    currentBuffer.put(plain.get());
                }
                currentBuffer.flip();
                nextFrameSize = currentBuffer.getInt();
            } else {
                nextFrameSize = plain.getInt();
            }
        }

        // The size comes straight from the peer: reject values that cannot describe a frame
        // instead of letting buffer allocation fail with an unchecked exception.
        if (nextFrameSize < 0) {
            throw new IOException("Invalid negative frame size: " + nextFrameSize);
        }

        if (wireFormat instanceof OpenWireFormat) {
            long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize();
            if (nextFrameSize > maxFrameSize) {
                throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) +
                                      " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB");
            }
        }

        // Guard the "+ prefix" below against int overflow.
        if (nextFrameSize > Integer.MAX_VALUE - frameSizeBytes) {
            throw new IOException("Frame size of " + nextFrameSize + " bytes is too large to buffer");
        }

        // now we got the data, lets reallocate and store the size for the marshaler.
        // if there's more data in plain, then the next call will start processing it.
        currentBuffer = ByteBuffer.allocate(nextFrameSize + frameSizeBytes);
        currentBuffer.putInt(nextFrameSize);
    } else {
        // If its all in one read then we can just take it all, otherwise take only
        // the current frame size and the next iteration starts a new command.
        if (currentBuffer != null) {
            if (currentBuffer.remaining() >= plain.remaining()) {
                currentBuffer.put(plain);
            } else {
                byte[] fill = new byte[currentBuffer.remaining()];
                plain.get(fill);
                currentBuffer.put(fill);
            }

            // Either we have enough data for a new command or we have to wait for some more.
            if (currentBuffer.hasRemaining()) {
                return;
            } else {
                currentBuffer.flip();
                Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer)));
                doConsume(command);
                // Ready for the next frame.
                nextFrameSize = -1;
                currentBuffer = null;
            }
        }
    }
}
/**
 * Reads encrypted bytes from the channel into {@code inputBuffer} (when needed) and unwraps
 * them into {@code plain}. Records the resulting engine status and handshake status in the
 * {@code status} and {@code handshakeStatus} fields.
 *
 * @param plain destination for decrypted bytes; it is cleared before unwrapping and flipped
 *              for reading before returning normally
 * @return the number of decrypted bytes available in {@code plain}, {@code 0} when the channel
 *         had nothing to read and the engine is not waiting for input, or {@code -1} when the
 *         channel or the SSL engine has been closed
 * @throws Exception if reading from the channel or the SSL engine fails
 */
//Prevent concurrent access while reading from the channel
protected synchronized int secureRead(ByteBuffer plain) throws Exception {
// Read from the channel unless inputBuffer is at a non-zero position and still has room,
// or when the previous unwrap reported BUFFER_UNDERFLOW.
if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
int bytesRead = channel.read(inputBuffer);
// Nothing arrived and the engine is not waiting for handshake input: nothing to do.
if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) {
return 0;
}
// End of stream: tell the engine, and report closure unless buffered bytes can still be unwrapped.
if (bytesRead == -1) {
sslEngine.closeInbound();
if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
return -1;
}
}
}
plain.clear();
// Switch inputBuffer to read mode for unwrap.
inputBuffer.flip();
SSLEngineResult res;
// Keep unwrapping while the engine is OK, still wants input and produced no application bytes.
do {
res = sslEngine.unwrap(inputBuffer, plain);
} while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP
&& res.bytesProduced() == 0);
if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) {
finishHandshake();
}
status = res.getStatus();
handshakeStatus = res.getHandshakeStatus();
// TODO deal with BUFFER_OVERFLOW
if (status == SSLEngineResult.Status.CLOSED) {
sslEngine.closeInbound();
return -1;
}
// Keep any bytes not consumed by unwrap for the next call and switch back to write mode.
inputBuffer.compact();
plain.flip();
return plain.remaining();
}
/**
 * Drives the SSL handshake on the calling thread until the engine reports FINISHED or
 * NOT_HANDSHAKING. While the engine needs more input than is available, waits on a temporary
 * selector for the channel to become readable, using {@code getSoTimeout()} as the select timeout.
 *
 * @throws SocketTimeoutException if a positive socket timeout elapses while waiting for data
 * @throws Exception              if reading, writing or the SSL engine fails
 */
protected void doHandshake() throws Exception {
handshakeInProgress = true;
// Temporary selector/key, created lazily and released in the finally block.
Selector selector = null;
SelectionKey key = null;
boolean readable = true;
try {
while (true) {
// Local status queried fresh from the engine on every iteration (shadows the field of the same name).
HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
switch (handshakeStatus) {
case NEED_UNWRAP:
if (readable) {
secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize()));
}
// Not enough data for the engine yet: wait until the channel is readable or the timeout elapses.
if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) {
long now = System.currentTimeMillis();
if (selector == null) {
selector = Selector.open();
key = channel.register(selector, SelectionKey.OP_READ);
} else {
key.interestOps(SelectionKey.OP_READ);
}
int keyCount = selector.select(this.getSoTimeout());
if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) {
throw new SocketTimeoutException("Timeout during handshake");
}
readable = key.isReadable();
}
break;
case NEED_TASK:
// Run all delegated tasks inline on this thread.
Runnable task;
while ((task = sslEngine.getDelegatedTask()) != null) {
task.run();
}
break;
case NEED_WRAP:
// Writes an empty buffer through the SSL-aware output stream; presumably this makes the
// stream wrap and send pending handshake data (NIOOutputStream is not shown here).
((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0));
break;
case FINISHED:
case NOT_HANDSHAKING:
finishHandshake();
return;
}
}
} finally {
// Best-effort release of the temporary selector resources.
if (key!=null) try {key.cancel();} catch (Exception ignore) {}
if (selector!=null) try {selector.close();} catch (Exception ignore) {}
}
}
/**
 * Creates the task runner used by {@link #doInit()} and then delegates to the superclass start.
 *
 * @throws Exception if the superclass start fails
 */
@Override
protected void doStart() throws Exception {
    this.taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
    // no need to init as we can delay that until demand (eg in doHandshake)
    super.doStart();
}
/**
 * Releases the initialization latch, shuts down the task runner, closes the channel and
 * then delegates to the superclass stop.
 *
 * @param stopper passed through to the superclass
 * @throws Exception if closing the channel or the superclass stop fails
 */
@Override
protected void doStop(ServiceStopper stopper) throws Exception {
    // Unblock anything still waiting for initialization to finish.
    this.initialized.countDown();

    if (this.taskRunnerFactory != null) {
        this.taskRunnerFactory.shutdownNow();
        this.taskRunnerFactory = null;
    }

    if (this.channel != null) {
        this.channel.close();
        this.channel = null;
    }

    super.doStop(stopper);
}
/**
 * Attaches the peer's certificate chain to incoming {@link ConnectionInfo} commands as their
 * transport context, then hands the command to the superclass.
 *
 * @param command
 *            The Command coming in.
 */
@Override
public void doConsume(Object command) {
    if (command instanceof ConnectionInfo) {
        ((ConnectionInfo) command).setTransportContext(getPeerCertificates());
    }
    super.doConsume(command);
}
/**
 * Returns the certificate chain presented by the peer.
 *
 * @return peer certificate chain associated with the ssl socket, or {@code null} when no SSL
 *         engine or session is available yet or the peer has not been verified
 */
@Override
public X509Certificate[] getPeerCertificates() {
    // The engine is only assigned by the constructor or during stream initialization,
    // so it may legitimately be absent when this is called early.
    final SSLEngine engine = sslEngine;
    if (engine == null) {
        return null;
    }

    final SSLSession session = engine.getSession();
    if (session == null) {
        return null;
    }

    try {
        return (X509Certificate[]) session.getPeerCertificates();
    } catch (SSLPeerUnverifiedException e) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("Failed to get peer certificates.", e);
        }
        return null;
    }
}
/**
 * @return whether client authentication is required on engines created by this transport
 */
public boolean isNeedClientAuth() {
    return this.needClientAuth;
}
/**
 * @param needClientAuth {@code true} to have stream initialization call
 *                       {@code setNeedClientAuth(true)} on an engine it creates
 */
public void setNeedClientAuth(boolean needClientAuth) {
    this.needClientAuth = needClientAuth;
}
/**
 * @return whether client authentication is requested on engines created by this transport
 */
public boolean isWantClientAuth() {
    return this.wantClientAuth;
}
/**
 * @param wantClientAuth {@code true} to have stream initialization call
 *                       {@code setWantClientAuth(true)} on an engine it creates
 */
public void setWantClientAuth(boolean wantClientAuth) {
    this.wantClientAuth = wantClientAuth;
}
/**
 * @return the configured cipher suites (the internal array, not a copy), or {@code null}
 *         when none were set
 */
public String[] getEnabledCipherSuites() {
    return this.enabledCipherSuites;
}
/**
 * @param enabledCipherSuites cipher suites applied to an engine created during stream
 *                            initialization; {@code null} leaves the engine's setting untouched
 */
public void setEnabledCipherSuites(String[] enabledCipherSuites) {
    this.enabledCipherSuites = enabledCipherSuites;
}
/**
 * @return the configured protocols (the internal array, not a copy), or {@code null}
 *         when none were set
 */
public String[] getEnabledProtocols() {
    return this.enabledProtocols;
}
/**
 * @param enabledProtocols protocols applied to an engine created during stream
 *                         initialization; {@code null} leaves the engine's setting untouched
 */
public void setEnabledProtocols(String[] enabledProtocols) {
    this.enabledProtocols = enabledProtocols;
}
/**
 * @return whether the "HTTPS" endpoint identification algorithm is set on engines created
 *         by this transport
 */
public boolean isVerifyHostName() {
    return this.verifyHostName;
}
/**
 * @param verifyHostName {@code true} to have stream initialization set the "HTTPS" endpoint
 *                       identification algorithm on an engine it creates
 */
public void setVerifyHostName(boolean verifyHostName) {
    this.verifyHostName = verifyHostName;
}
}