| /** |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.activemq.transport.nio; |
| |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.EOFException; |
| import java.io.IOException; |
| import java.net.Socket; |
| import java.net.SocketTimeoutException; |
| import java.net.URI; |
| import java.net.UnknownHostException; |
| import java.nio.ByteBuffer; |
| import java.nio.channels.SelectionKey; |
| import java.nio.channels.Selector; |
| import java.security.cert.X509Certificate; |
| import java.util.concurrent.CountDownLatch; |
| |
| import javax.net.SocketFactory; |
| import javax.net.ssl.SSLContext; |
| import javax.net.ssl.SSLEngine; |
| import javax.net.ssl.SSLEngineResult; |
| import javax.net.ssl.SSLEngineResult.HandshakeStatus; |
| import javax.net.ssl.SSLParameters; |
| import javax.net.ssl.SSLPeerUnverifiedException; |
| import javax.net.ssl.SSLSession; |
| |
| import org.apache.activemq.command.ConnectionInfo; |
| import org.apache.activemq.openwire.OpenWireFormat; |
| import org.apache.activemq.thread.TaskRunnerFactory; |
| import org.apache.activemq.util.IOExceptionSupport; |
| import org.apache.activemq.util.ServiceStopper; |
| import org.apache.activemq.wireformat.WireFormat; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| public class NIOSSLTransport extends NIOTransport { |
| |
| private static final Logger LOG = LoggerFactory.getLogger(NIOSSLTransport.class); |
| |
| protected boolean needClientAuth; |
| protected boolean wantClientAuth; |
| protected String[] enabledCipherSuites; |
| protected String[] enabledProtocols; |
| protected boolean verifyHostName = false; |
| |
| protected SSLContext sslContext; |
| protected SSLEngine sslEngine; |
| protected SSLSession sslSession; |
| |
| protected volatile boolean handshakeInProgress = false; |
| protected SSLEngineResult.Status status = null; |
| protected SSLEngineResult.HandshakeStatus handshakeStatus = null; |
| protected TaskRunnerFactory taskRunnerFactory; |
| |
| public NIOSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { |
| super(wireFormat, socketFactory, remoteLocation, localLocation); |
| } |
| |
| public NIOSSLTransport(WireFormat wireFormat, Socket socket, SSLEngine engine, InitBuffer initBuffer, |
| ByteBuffer inputBuffer) throws IOException { |
| super(wireFormat, socket, initBuffer); |
| this.sslEngine = engine; |
| if (engine != null) { |
| this.sslSession = engine.getSession(); |
| } |
| this.inputBuffer = inputBuffer; |
| } |
| |
| public void setSslContext(SSLContext sslContext) { |
| this.sslContext = sslContext; |
| } |
| |
| volatile boolean hasSslEngine = false; |
| |
| @Override |
| protected void initializeStreams() throws IOException { |
| if (sslEngine != null) { |
| hasSslEngine = true; |
| } |
| NIOOutputStream outputStream = null; |
| try { |
| channel = socket.getChannel(); |
| channel.configureBlocking(false); |
| |
| if (sslContext == null) { |
| sslContext = SSLContext.getDefault(); |
| } |
| |
| String remoteHost = null; |
| int remotePort = -1; |
| |
| try { |
| URI remoteAddress = new URI(this.getRemoteAddress()); |
| remoteHost = remoteAddress.getHost(); |
| remotePort = remoteAddress.getPort(); |
| } catch (Exception e) { |
| } |
| |
| // initialize engine, the initial sslSession we get will need to be |
| // updated once the ssl handshake process is completed. |
| if (!hasSslEngine) { |
| if (remoteHost != null && remotePort != -1) { |
| sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); |
| } else { |
| sslEngine = sslContext.createSSLEngine(); |
| } |
| |
| if (verifyHostName) { |
| SSLParameters sslParams = new SSLParameters(); |
| sslParams.setEndpointIdentificationAlgorithm("HTTPS"); |
| sslEngine.setSSLParameters(sslParams); |
| } |
| |
| sslEngine.setUseClientMode(false); |
| if (enabledCipherSuites != null) { |
| sslEngine.setEnabledCipherSuites(enabledCipherSuites); |
| } |
| |
| if (enabledProtocols != null) { |
| sslEngine.setEnabledProtocols(enabledProtocols); |
| } |
| |
| if (wantClientAuth) { |
| sslEngine.setWantClientAuth(wantClientAuth); |
| } |
| |
| if (needClientAuth) { |
| sslEngine.setNeedClientAuth(needClientAuth); |
| } |
| |
| sslSession = sslEngine.getSession(); |
| |
| inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); |
| inputBuffer.clear(); |
| } |
| |
| outputStream = new NIOOutputStream(channel); |
| outputStream.setEngine(sslEngine); |
| this.dataOut = new DataOutputStream(outputStream); |
| this.buffOut = outputStream; |
| |
| //If the sslEngine was not passed in, then handshake |
| if (!hasSslEngine) { |
| sslEngine.beginHandshake(); |
| } |
| handshakeStatus = sslEngine.getHandshakeStatus(); |
| if (!hasSslEngine) { |
| doHandshake(); |
| } |
| |
| selection = SelectorManager.getInstance().register(channel, new SelectorManager.Listener() { |
| @Override |
| public void onSelect(SelectorSelection selection) { |
| try { |
| initialized.await(); |
| } catch (InterruptedException error) { |
| onException(IOExceptionSupport.create(error)); |
| } |
| serviceRead(); |
| } |
| |
| @Override |
| public void onError(SelectorSelection selection, Throwable error) { |
| if (error instanceof IOException) { |
| onException((IOException) error); |
| } else { |
| onException(IOExceptionSupport.create(error)); |
| } |
| } |
| }); |
| doInit(); |
| |
| } catch (Exception e) { |
| try { |
| if(outputStream != null) { |
| outputStream.close(); |
| } |
| super.closeStreams(); |
| } catch (Exception ex) {} |
| throw new IOException(e); |
| } |
| } |
| |
| final protected CountDownLatch initialized = new CountDownLatch(1); |
| |
| protected void doInit() throws Exception { |
| taskRunnerFactory.execute(new Runnable() { |
| |
| @Override |
| public void run() { |
| //Need to start in new thread to let startup finish first |
| //We can trigger a read because we know the channel is ready since the SSL handshake |
| //already happened |
| serviceRead(); |
| initialized.countDown(); |
| } |
| }); |
| } |
| |
| //Only used for the auto transport to abort the openwire init method early if already initialized |
| boolean openWireInititialized = false; |
| |
| protected void doOpenWireInit() throws Exception { |
| //Do this later to let wire format negotiation happen |
| if (initBuffer != null && !openWireInititialized && this.wireFormat instanceof OpenWireFormat) { |
| initBuffer.buffer.flip(); |
| if (initBuffer.buffer.hasRemaining()) { |
| nextFrameSize = -1; |
| receiveCounter += initBuffer.readSize; |
| processCommand(initBuffer.buffer); |
| processCommand(initBuffer.buffer); |
| initBuffer.buffer.clear(); |
| openWireInititialized = true; |
| } |
| } |
| } |
| |
| protected void finishHandshake() throws Exception { |
| if (handshakeInProgress) { |
| handshakeInProgress = false; |
| nextFrameSize = -1; |
| |
| // Once handshake completes we need to ask for the now real sslSession |
| // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the |
| // cipher suite. |
| sslSession = sslEngine.getSession(); |
| } |
| } |
| |
| @Override |
| public void serviceRead() { |
| try { |
| if (handshakeInProgress) { |
| doHandshake(); |
| } |
| |
| doOpenWireInit(); |
| |
| ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); |
| plain.position(plain.limit()); |
| |
| while (true) { |
| //If the transport was already stopped then break |
| if (this.isStopped()) { |
| return; |
| } |
| |
| if (!plain.hasRemaining()) { |
| |
| int readCount = secureRead(plain); |
| |
| if (readCount == 0) { |
| break; |
| } |
| |
| // channel is closed, cleanup |
| if (readCount == -1) { |
| onException(new EOFException()); |
| selection.close(); |
| break; |
| } |
| |
| receiveCounter += readCount; |
| } |
| |
| if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { |
| processCommand(plain); |
| } |
| } |
| } catch (IOException e) { |
| onException(e); |
| } catch (Throwable e) { |
| onException(IOExceptionSupport.create(e)); |
| } |
| } |
| |
| protected void processCommand(ByteBuffer plain) throws Exception { |
| |
| // Are we waiting for the next Command or are we building on the current one |
| if (nextFrameSize == -1) { |
| |
| // We can get small packets that don't give us enough for the frame size |
| // so allocate enough for the initial size value and |
| if (plain.remaining() < Integer.SIZE) { |
| if (currentBuffer == null) { |
| currentBuffer = ByteBuffer.allocate(4); |
| } |
| |
| // Go until we fill the integer sized current buffer. |
| while (currentBuffer.hasRemaining() && plain.hasRemaining()) { |
| currentBuffer.put(plain.get()); |
| } |
| |
| // Didn't we get enough yet to figure out next frame size. |
| if (currentBuffer.hasRemaining()) { |
| return; |
| } else { |
| currentBuffer.flip(); |
| nextFrameSize = currentBuffer.getInt(); |
| } |
| |
| } else { |
| |
| // Either we are completing a previous read of the next frame size or its |
| // fully contained in plain already. |
| if (currentBuffer != null) { |
| |
| // Finish the frame size integer read and get from the current buffer. |
| while (currentBuffer.hasRemaining()) { |
| currentBuffer.put(plain.get()); |
| } |
| |
| currentBuffer.flip(); |
| nextFrameSize = currentBuffer.getInt(); |
| |
| } else { |
| nextFrameSize = plain.getInt(); |
| } |
| } |
| |
| if (wireFormat instanceof OpenWireFormat) { |
| long maxFrameSize = ((OpenWireFormat) wireFormat).getMaxFrameSize(); |
| if (nextFrameSize > maxFrameSize) { |
| throw new IOException("Frame size of " + (nextFrameSize / (1024 * 1024)) + |
| " MB larger than max allowed " + (maxFrameSize / (1024 * 1024)) + " MB"); |
| } |
| } |
| |
| // now we got the data, lets reallocate and store the size for the marshaler. |
| // if there's more data in plain, then the next call will start processing it. |
| currentBuffer = ByteBuffer.allocate(nextFrameSize + 4); |
| currentBuffer.putInt(nextFrameSize); |
| |
| } else { |
| // If its all in one read then we can just take it all, otherwise take only |
| // the current frame size and the next iteration starts a new command. |
| if (currentBuffer != null) { |
| if (currentBuffer.remaining() >= plain.remaining()) { |
| currentBuffer.put(plain); |
| } else { |
| byte[] fill = new byte[currentBuffer.remaining()]; |
| plain.get(fill); |
| currentBuffer.put(fill); |
| } |
| |
| // Either we have enough data for a new command or we have to wait for some more. |
| if (currentBuffer.hasRemaining()) { |
| return; |
| } else { |
| currentBuffer.flip(); |
| Object command = wireFormat.unmarshal(new DataInputStream(new NIOInputStream(currentBuffer))); |
| doConsume(command); |
| nextFrameSize = -1; |
| currentBuffer = null; |
| } |
| } |
| } |
| } |
| |
| //Prevent concurrent access while reading from the channel |
| protected synchronized int secureRead(ByteBuffer plain) throws Exception { |
| |
| if (!(inputBuffer.position() != 0 && inputBuffer.hasRemaining()) || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { |
| int bytesRead = channel.read(inputBuffer); |
| |
| if (bytesRead == 0 && !(sslEngine.getHandshakeStatus().equals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP))) { |
| return 0; |
| } |
| |
| if (bytesRead == -1) { |
| sslEngine.closeInbound(); |
| if (inputBuffer.position() == 0 || status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { |
| return -1; |
| } |
| } |
| } |
| |
| plain.clear(); |
| |
| inputBuffer.flip(); |
| SSLEngineResult res; |
| do { |
| res = sslEngine.unwrap(inputBuffer, plain); |
| } while (res.getStatus() == SSLEngineResult.Status.OK && res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP |
| && res.bytesProduced() == 0); |
| |
| if (res.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { |
| finishHandshake(); |
| } |
| |
| status = res.getStatus(); |
| handshakeStatus = res.getHandshakeStatus(); |
| |
| // TODO deal with BUFFER_OVERFLOW |
| |
| if (status == SSLEngineResult.Status.CLOSED) { |
| sslEngine.closeInbound(); |
| return -1; |
| } |
| |
| inputBuffer.compact(); |
| plain.flip(); |
| |
| return plain.remaining(); |
| } |
| |
| protected void doHandshake() throws Exception { |
| handshakeInProgress = true; |
| Selector selector = null; |
| SelectionKey key = null; |
| boolean readable = true; |
| try { |
| while (true) { |
| HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); |
| switch (handshakeStatus) { |
| case NEED_UNWRAP: |
| if (readable) { |
| secureRead(ByteBuffer.allocate(sslSession.getApplicationBufferSize())); |
| } |
| if (this.status == SSLEngineResult.Status.BUFFER_UNDERFLOW) { |
| long now = System.currentTimeMillis(); |
| if (selector == null) { |
| selector = Selector.open(); |
| key = channel.register(selector, SelectionKey.OP_READ); |
| } else { |
| key.interestOps(SelectionKey.OP_READ); |
| } |
| int keyCount = selector.select(this.getSoTimeout()); |
| if (keyCount == 0 && this.getSoTimeout() > 0 && ((System.currentTimeMillis() - now) >= this.getSoTimeout())) { |
| throw new SocketTimeoutException("Timeout during handshake"); |
| } |
| readable = key.isReadable(); |
| } |
| break; |
| case NEED_TASK: |
| Runnable task; |
| while ((task = sslEngine.getDelegatedTask()) != null) { |
| task.run(); |
| } |
| break; |
| case NEED_WRAP: |
| ((NIOOutputStream) buffOut).write(ByteBuffer.allocate(0)); |
| break; |
| case FINISHED: |
| case NOT_HANDSHAKING: |
| finishHandshake(); |
| return; |
| } |
| } |
| } finally { |
| if (key!=null) try {key.cancel();} catch (Exception ignore) {} |
| if (selector!=null) try {selector.close();} catch (Exception ignore) {} |
| } |
| } |
| |
| @Override |
| protected void doStart() throws Exception { |
| taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); |
| // no need to init as we can delay that until demand (eg in doHandshake) |
| super.doStart(); |
| } |
| |
| @Override |
| protected void doStop(ServiceStopper stopper) throws Exception { |
| initialized.countDown(); |
| |
| if (taskRunnerFactory != null) { |
| taskRunnerFactory.shutdownNow(); |
| taskRunnerFactory = null; |
| } |
| if (channel != null) { |
| channel.close(); |
| channel = null; |
| } |
| super.doStop(stopper); |
| } |
| |
| /** |
| * Overriding in order to add the client's certificates to ConnectionInfo Commands. |
| * |
| * @param command |
| * The Command coming in. |
| */ |
| @Override |
| public void doConsume(Object command) { |
| if (command instanceof ConnectionInfo) { |
| ConnectionInfo connectionInfo = (ConnectionInfo) command; |
| connectionInfo.setTransportContext(getPeerCertificates()); |
| } |
| super.doConsume(command); |
| } |
| |
| /** |
| * @return peer certificate chain associated with the ssl socket |
| */ |
| @Override |
| public X509Certificate[] getPeerCertificates() { |
| |
| X509Certificate[] clientCertChain = null; |
| try { |
| if (sslEngine.getSession() != null) { |
| clientCertChain = (X509Certificate[]) sslEngine.getSession().getPeerCertificates(); |
| } |
| } catch (SSLPeerUnverifiedException e) { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Failed to get peer certificates.", e); |
| } |
| } |
| |
| return clientCertChain; |
| } |
| |
| public boolean isNeedClientAuth() { |
| return needClientAuth; |
| } |
| |
| public void setNeedClientAuth(boolean needClientAuth) { |
| this.needClientAuth = needClientAuth; |
| } |
| |
| public boolean isWantClientAuth() { |
| return wantClientAuth; |
| } |
| |
| public void setWantClientAuth(boolean wantClientAuth) { |
| this.wantClientAuth = wantClientAuth; |
| } |
| |
| public String[] getEnabledCipherSuites() { |
| return enabledCipherSuites; |
| } |
| |
| public void setEnabledCipherSuites(String[] enabledCipherSuites) { |
| this.enabledCipherSuites = enabledCipherSuites; |
| } |
| |
| public String[] getEnabledProtocols() { |
| return enabledProtocols; |
| } |
| |
| public void setEnabledProtocols(String[] enabledProtocols) { |
| this.enabledProtocols = enabledProtocols; |
| } |
| |
| public boolean isVerifyHostName() { |
| return verifyHostName; |
| } |
| |
| public void setVerifyHostName(boolean verifyHostName) { |
| this.verifyHostName = verifyHostName; |
| } |
| } |