blob: 59902d05f538abf8d92a6943efa4cc7113e20c59 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.remote.io.socket.ssl;
import org.apache.nifi.remote.exception.TransmissionDisabledException;
import org.apache.nifi.remote.io.socket.BufferStateManager;
import org.apache.nifi.remote.io.socket.BufferStateManager.Direction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLPeerUnverifiedException;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedByInterruptException;
import java.nio.channels.SocketChannel;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.TimeUnit;
/**
 * A TLS-secured socket wrapper that drives an {@link SSLEngine} over a {@link SocketChannel}.
 * <p>
 * {@link #connect()} switches the channel to non-blocking mode. The read and write methods then emulate
 * blocking I/O by polling the channel and sleeping briefly between polls that transfer no bytes. They fail
 * with a {@link SocketTimeoutException} once {@link #getTimeout()} milliseconds pass without progress.
 * {@link #read(byte[], int, int)} and {@link #write(byte[], int, int)} call {@link #connect()} lazily if it
 * has not completed yet. After {@link #interrupt()} is called, connect, read and write polling fail with a
 * {@link TransmissionDisabledException}.
 * <p>
 * NOTE(review): apart from the {@code volatile} {@code connected} and {@code interrupted} flags, no
 * synchronization is performed. Presumably an instance is meant to be used by one thread at a time.
 */
public class SSLSocketChannel implements Closeable {
    /** Maximum number of plaintext bytes encrypted per iteration of {@link #write(byte[], int, int)}. */
    public static final int MAX_WRITE_SIZE = 65536;

    private static final Logger logger = LoggerFactory.getLogger(SSLSocketChannel.class);

    /** Upper bound (1 ms) for the back-off sleep between polls that read or wrote no bytes. */
    private static final long BUFFER_FULL_EMPTY_WAIT_NANOS = TimeUnit.NANOSECONDS.convert(1, TimeUnit.MILLISECONDS);

    private final String remoteAddress;
    private final int port;
    private final SSLEngine engine;
    private final SocketAddress socketAddress;

    /** Encrypted bytes read from the channel that have not yet been unwrapped. */
    private BufferStateManager streamInManager;
    /** Encrypted bytes produced by the engine that have not yet been written to the channel. */
    private BufferStateManager streamOutManager;
    /** Decrypted application data that has not yet been handed to the caller. */
    private BufferStateManager appDataManager;

    private SocketChannel channel;
    /** Scratch buffer backing {@link #read()}. */
    private final byte[] oneByteBuffer = new byte[1];
    private int timeoutMillis = 30000;
    private volatile boolean connected = false;
    /** True only while the handshake is running; reported in trace logging. */
    private boolean handshaking = false;
    private boolean closed = false;
    private volatile boolean interrupted = false;

    /**
     * Creates a channel that will connect to {@code hostname:port} when {@link #connect()} is called.
     * <p>
     * The engine is created from {@code sslContext} with the given client mode and with
     * {@code needClientAuth} enabled.
     *
     * @param sslContext   context used to create the {@link SSLEngine}
     * @param hostname     remote host to connect to
     * @param port         remote port to connect to
     * @param localAddress local address to bind to (on an ephemeral port), or {@code null} to skip binding
     * @param client       whether the engine operates in client mode
     * @throws IOException if the socket channel cannot be opened or bound
     */
    public SSLSocketChannel(final SSLContext sslContext, final String hostname, final int port, final InetAddress localAddress, final boolean client) throws IOException {
        this.socketAddress = new InetSocketAddress(hostname, port);
        this.remoteAddress = hostname;
        this.port = port;
        this.engine = sslContext.createSSLEngine();
        this.engine.setUseClientMode(client);
        this.engine.setNeedClientAuth(true);
        initializeBuffers();

        // Open the socket only after everything else is set up, and close it again if binding fails,
        // so a failing constructor does not leak an open SocketChannel.
        this.channel = SocketChannel.open();
        if (localAddress != null) {
            try {
                this.channel.bind(new InetSocketAddress(localAddress, 0));
            } catch (final IOException | RuntimeException e) {
                closeQuietly(this.channel);
                throw e;
            }
        }
    }

    /**
     * Wraps an already-connected socket channel.
     * <p>
     * The engine is created from {@code sslContext} with the given client mode and with
     * {@code needClientAuth} enabled.
     *
     * @param sslContext    context used to create the {@link SSLEngine}
     * @param socketChannel a connected socket channel
     * @param client        whether the engine operates in client mode
     * @throws IllegalArgumentException if {@code socketChannel} is not connected
     * @throws IOException              if the remote address of the channel cannot be determined
     */
    public SSLSocketChannel(final SSLContext sslContext, final SocketChannel socketChannel, final boolean client) throws IOException {
        if (!socketChannel.isConnected()) {
            throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel");
        }
        this.channel = socketChannel;
        this.socketAddress = socketChannel.getRemoteAddress();
        final Socket socket = socketChannel.socket();
        this.remoteAddress = socket.getInetAddress().toString();
        this.port = socket.getPort();
        this.engine = sslContext.createSSLEngine();
        this.engine.setUseClientMode(client);
        this.engine.setNeedClientAuth(true);
        initializeBuffers();
    }

    /**
     * Wraps an already-connected socket channel using a caller-configured engine.
     * <p>
     * Client mode and client authentication are NOT modified; the engine is used as given.
     *
     * @param sslEngine     the engine to use, already configured by the caller
     * @param socketChannel a connected socket channel
     * @throws IllegalArgumentException if {@code socketChannel} is not connected
     * @throws IOException              if the remote address of the channel cannot be determined
     */
    public SSLSocketChannel(final SSLEngine sslEngine, final SocketChannel socketChannel) throws IOException {
        if (!socketChannel.isConnected()) {
            throw new IllegalArgumentException("Cannot pass an un-connected SocketChannel");
        }
        this.channel = socketChannel;
        this.socketAddress = socketChannel.getRemoteAddress();
        final Socket socket = socketChannel.socket();
        this.remoteAddress = socket.getInetAddress().toString();
        this.port = socket.getPort();
        // don't set useClientMode or needClientAuth, use the engine as is and let the caller configure it
        this.engine = sslEngine;
        initializeBuffers();
    }

    /**
     * Allocates the inbound and outbound packet buffers and the application-data buffer, sized from the
     * engine's current session.
     */
    private void initializeBuffers() {
        final int packetBufferSize = engine.getSession().getPacketBufferSize();
        final int applicationBufferSize = engine.getSession().getApplicationBufferSize();
        streamInManager = new BufferStateManager(ByteBuffer.allocate(packetBufferSize));
        streamOutManager = new BufferStateManager(ByteBuffer.allocate(packetBufferSize));
        appDataManager = new BufferStateManager(ByteBuffer.allocate(applicationBufferSize));
    }

    /**
     * Sets the timeout used for connecting and for each polling read/write.
     *
     * @param millis timeout in milliseconds
     */
    public void setTimeout(final int millis) {
        this.timeoutMillis = millis;
    }

    /**
     * @return the connect/read/write timeout in milliseconds (defaults to 30000)
     */
    public int getTimeout() {
        return timeoutMillis;
    }

    /**
     * Connects the underlying channel (if it is not already connected) and performs the TLS handshake.
     * <p>
     * The channel is put into non-blocking mode. Connection establishment is polled every 50 ms until it
     * completes, {@link #interrupt()} is called, or the timeout elapses. On any failure the channel is
     * closed, both directions of the engine are closed, and the original exception is rethrown. Any
     * failure while closing the engine's inbound side is attached to it as a suppressed exception.
     *
     * @throws SocketTimeoutException     if the connection is not established within the timeout
     * @throws ClosedByInterruptException if the calling thread is interrupted while waiting to connect
     * @throws IOException                if connecting or the handshake fails
     */
    public void connect() throws IOException {
        try {
            channel.configureBlocking(false);
            if (!channel.isConnected()) {
                final long startTime = System.currentTimeMillis();

                if (!channel.connect(socketAddress)) {
                    while (!channel.finishConnect()) {
                        if (interrupted) {
                            throw new TransmissionDisabledException();
                        }
                        if (System.currentTimeMillis() > startTime + timeoutMillis) {
                            throw new SocketTimeoutException("Timed out connecting to " + remoteAddress + ":" + port);
                        }

                        try {
                            Thread.sleep(50L);
                        } catch (final InterruptedException e) {
                            // Preserve the interrupt and abort instead of spinning until the timeout.
                            Thread.currentThread().interrupt();
                            throw new ClosedByInterruptException();
                        }
                    }
                }
            }
            engine.beginHandshake();

            performHandshake();
            logger.debug("{} Successfully completed SSL handshake", this);

            streamInManager.clear();
            streamOutManager.clear();
            appDataManager.clear();

            connected = true;
        } catch (final Exception e) {
            closeQuietly(channel);
            // closeInbound() throws SSLException when the peer's close_notify was never received, which is
            // the normal situation here. Keep that from replacing the exception that caused the failure.
            try {
                engine.closeInbound();
            } catch (final IOException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            engine.closeOutbound();
            logger.error("{} failed to connect", this, e);
            throw e;
        }
    }

    /**
     * Returns the subject distinguished name of the peer's own (first) certificate.
     *
     * @return the trimmed subject DN of the peer certificate
     * @throws SSLPeerUnverifiedException if the peer is unverified or presented no certificates
     * @throws CertificateException       if the first certificate is not an X.509 certificate or is not
     *                                    currently within its validity period
     */
    public String getDn() throws CertificateException, SSLPeerUnverifiedException {
        final Certificate[] certs = engine.getSession().getPeerCertificates();
        if (certs == null || certs.length == 0) {
            throw new SSLPeerUnverifiedException("No certificates found");
        }

        final Certificate certificate = certs[0];
        if (certificate instanceof X509Certificate) {
            final X509Certificate peerCertificate = (X509Certificate) certificate;
            peerCertificate.checkValidity();
            return peerCertificate.getSubjectDN().getName().trim();
        } else {
            throw new CertificateException(String.format("X.509 Certificate class not found [%s]", certificate.getClass()));
        }
    }

    /**
     * Drives the engine's handshake state machine until the engine reports it is no longer handshaking.
     * <p>
     * Each step does one of three things, as the engine requests: wraps and sends an outbound handshake
     * message, reads and unwraps an inbound one, or runs delegated tasks.
     *
     * @throws SSLHandshakeException if a handshake message cannot be generated, the peer closes the stream
     *                               mid-handshake, or the engine reports an unsupported handshake status
     * @throws IOException           if the channel is closed by the peer or an I/O error occurs
     */
    private void performHandshake() throws IOException {
        // Handshake wraps carry no application data.
        final byte[] emptyMessage = new byte[0];
        handshaking = true;
        logger.debug("{} Performing Handshake", this);

        try {
            while (true) {
                final SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
                switch (handshakeStatus) {
                    case FINISHED:
                        return;
                    case NEED_WRAP: {
                        final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage);
                        // Wrapped (network) output must fit in a buffer of the session's packet size.
                        final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());

                        final SSLEngineResult wrapHelloResult = engine.wrap(appDataOut, outboundBuffer);
                        if (wrapHelloResult.getStatus() == Status.BUFFER_OVERFLOW) {
                            // Grow to the current packet size and retry the wrap.
                            streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());
                            continue;
                        }

                        if (wrapHelloResult.getStatus() != Status.OK) {
                            throw new SSLHandshakeException("Could not generate SSL Handshake information: SSLEngineResult: "
                                    + wrapHelloResult.toString());
                        }

                        logger.trace("{} Handshake response after wrapping: {}", this, wrapHelloResult);

                        final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1);
                        final int bytesToSend = readableStreamOut.remaining();
                        writeFully(readableStreamOut);
                        logger.trace("{} Sent {} bytes of wrapped data for handshake", this, bytesToSend);

                        streamOutManager.clear();
                    }
                    continue;
                    case NEED_UNWRAP: {
                        final ByteBuffer readableDataIn = streamInManager.prepareForRead(0);
                        final ByteBuffer appData = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());

                        // Read handshake response from other side
                        logger.trace("{} Unwrapping: {} to {}", this, readableDataIn, appData);
                        SSLEngineResult handshakeResponseResult = engine.unwrap(readableDataIn, appData);
                        logger.trace("{} Handshake response after unwrapping: {}", this, handshakeResponseResult);

                        if (handshakeResponseResult.getStatus() == Status.BUFFER_UNDERFLOW) {
                            // Not enough bytes for a full record yet: read more from the channel.
                            final ByteBuffer writableDataIn = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
                            final int bytesRead = readData(writableDataIn);
                            if (bytesRead > 0) {
                                logger.trace("{} Read {} bytes for handshake", this, bytesRead);
                            }

                            if (bytesRead < 0) {
                                throw new SSLHandshakeException("Reached End-of-File marker while performing handshake");
                            }
                        } else if (handshakeResponseResult.getStatus() == Status.CLOSED) {
                            throw new IOException("Channel was closed by peer during handshake");
                        } else {
                            streamInManager.compact();
                            appDataManager.clear();
                        }
                    }
                    break;
                    case NEED_TASK:
                        performTasks();
                        continue;
                    case NOT_HANDSHAKING:
                        return;
                    default:
                        // No handler for this status (e.g. one only DTLS engines report). The status would
                        // never change, so looping again would spin forever.
                        throw new SSLHandshakeException("Unexpected handshake status: " + handshakeStatus);
                }
            }
        } finally {
            handshaking = false;
        }
    }

    /**
     * Runs, on the calling thread, every delegated task the engine currently has pending.
     */
    private void performTasks() {
        Runnable runnable;
        while ((runnable = engine.getDelegatedTask()) != null) {
            runnable.run();
        }
    }

    /**
     * Closes the given resource, ignoring any exception (best-effort cleanup).
     */
    private void closeQuietly(final Closeable closeable) {
        try {
            closeable.close();
        } catch (final Exception ignored) {
            // best effort: callers are already tearing the connection down
        }
    }

    /**
     * Shuts down the input side of the channel, then reads and discards data until a read returns no
     * bytes or end of stream.
     *
     * @throws IOException if shutting down or reading the channel fails
     */
    public void consume() throws IOException {
        channel.shutdownInput();

        final byte[] b = new byte[4096];
        final ByteBuffer buffer = ByteBuffer.wrap(b);
        int readCount;
        do {
            readCount = channel.read(buffer);
            // Reset to full capacity for the next read; the data is simply discarded.
            buffer.clear();
        } while (readCount > 0);
    }

    /**
     * Polls the channel until at least one byte is read into {@code dest}.
     * <p>
     * Between empty polls the thread sleeps with an exponential back-off capped at
     * {@link #BUFFER_FULL_EMPTY_WAIT_NANOS}.
     *
     * @param dest buffer to read into
     * @return the number of bytes read; 0 if {@code dest} has no remaining space; -1 at end of stream
     * @throws SocketTimeoutException      if nothing is read within the timeout
     * @throws ClosedByInterruptException  if the thread is interrupted while sleeping (the channel is closed)
     * @throws IOException                 if reading fails
     */
    private int readData(final ByteBuffer dest) throws IOException {
        final long startTime = System.currentTimeMillis();
        // Must be declared outside the loop so the back-off actually grows between empty polls.
        long sleepNanos = 1L;

        while (true) {
            if (interrupted) {
                throw new TransmissionDisabledException();
            }

            if (dest.remaining() == 0) {
                return 0;
            }

            final int readCount = channel.read(dest);

            if (readCount == 0) {
                if (System.currentTimeMillis() > startTime + timeoutMillis) {
                    throw new SocketTimeoutException("Timed out reading from socket connected to " + remoteAddress + ":" + port);
                }
                try {
                    TimeUnit.NANOSECONDS.sleep(sleepNanos);
                } catch (InterruptedException e) {
                    close();
                    Thread.currentThread().interrupt(); // set the interrupt status
                    throw new ClosedByInterruptException();
                }

                sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);

                continue;
            }

            logger.trace("{} Read {} bytes", this, readCount);
            return readCount;
        }
    }

    /**
     * Encrypts all remaining plaintext in {@code src} and writes each resulting TLS record to the channel.
     *
     * @param src plaintext source; consumed bytes advance its read position, so a retry after a
     *            non-OK status continues where this call stopped
     * @return {@link Status#OK} if everything was encrypted and written (including when {@code src} was
     *         empty); otherwise the first non-OK status reported by the engine
     * @throws IOException if writing to the channel fails
     */
    private Status encryptAndWriteFully(final BufferStateManager src) throws IOException {
        final ByteBuffer buff = src.prepareForRead(0);

        logger.trace("{} Encrypting {} bytes", this, buff.remaining());
        Status status = Status.OK;
        while (buff.hasRemaining()) {
            // Wrapped (network) output must fit in a buffer of the session's packet size.
            final ByteBuffer outBuff = streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());
            final SSLEngineResult result = engine.wrap(buff, outBuff);
            status = result.getStatus();
            if (status != Status.OK) {
                return status;
            }

            final ByteBuffer readableOutBuff = streamOutManager.prepareForRead(0);
            writeFully(readableOutBuff);
            streamOutManager.clear();
        }

        return status;
    }

    /**
     * Writes all remaining bytes of {@code src} to the channel, polling until done.
     * <p>
     * Between writes that make no progress the thread sleeps with an exponential back-off capped at
     * {@link #BUFFER_FULL_EMPTY_WAIT_NANOS}. The back-off resets after every successful write.
     *
     * @param src bytes to write
     * @throws SocketTimeoutException     if no byte is written for longer than the timeout
     * @throws ClosedByInterruptException if the thread is interrupted while sleeping (the channel is closed)
     * @throws IOException                if writing fails
     */
    private void writeFully(final ByteBuffer src) throws IOException {
        long lastByteWrittenTime = System.currentTimeMillis();
        // Must be declared outside the loop so the back-off actually grows between stalled writes.
        long sleepNanos = 1L;

        int bytesWritten = 0;
        while (src.hasRemaining()) {
            if (interrupted) {
                throw new TransmissionDisabledException();
            }

            final int written = channel.write(src);
            bytesWritten += written;
            final long now = System.currentTimeMillis();

            if (written > 0) {
                lastByteWrittenTime = now;
                sleepNanos = 1L;
            } else {
                if (now > lastByteWrittenTime + timeoutMillis) {
                    throw new SocketTimeoutException("Timed out writing to socket connected to " + remoteAddress + ":" + port);
                }
                try {
                    TimeUnit.NANOSECONDS.sleep(sleepNanos);
                } catch (final InterruptedException e) {
                    close();
                    Thread.currentThread().interrupt(); // set the interrupt status
                    throw new ClosedByInterruptException();
                }

                sleepNanos = Math.min(sleepNanos * 2, BUFFER_FULL_EMPTY_WAIT_NANOS);
            }
        }

        logger.trace("{} Wrote {} bytes", this, bytesWritten);
    }

    /**
     * Reports whether the connection is closed.
     * <p>
     * If this object has not already been marked closed, performs one read from the channel to detect
     * peer closure: end of stream, a read error, or a TLS record that unwraps to {@link Status#CLOSED}.
     * When closure is detected, the socket and channel are closed and this object is marked closed.
     * Any bytes read that do not signal closure stay buffered for later reads.
     *
     * @return {@code true} if closed, otherwise {@code false}
     */
    public boolean isClosed() {
        if (closed) {
            return true;
        }
        // need to detect if peer has sent closure handshake...if so the answer is true
        final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
        int readCount = 0;
        try {
            readCount = channel.read(writableInBuffer);
        } catch (IOException e) {
            logger.error("{} failed to read data", this, e);
            readCount = -1; // treat the condition same as if End of Stream
        }
        if (readCount == 0) {
            return false;
        }
        if (readCount > 0) {
            logger.trace("{} Read {} bytes", this, readCount);

            final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1);
            final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
            try {
                SSLEngineResult unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer);
                logger.trace("{} When checking if closed, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse);
                if (unwrapResponse.getStatus().equals(Status.CLOSED)) {
                    // Drain the incoming TCP buffer
                    final ByteBuffer discardBuffer = ByteBuffer.allocate(8192);
                    int bytesDiscarded = channel.read(discardBuffer);
                    while (bytesDiscarded > 0) {
                        discardBuffer.clear();
                        bytesDiscarded = channel.read(discardBuffer);
                    }
                    engine.closeInbound();
                } else {
                    streamInManager.compact();
                    return false;
                }
            } catch (IOException e) {
                logger.error("{} failed to check if closed. Closing channel.", this, e);
            }
        }
        // either readCount is -1, indicating an end of stream, or the peer sent a closure handshake
        // so go ahead and close down the channel
        closeQuietly(channel.socket());
        closeQuietly(channel);
        closed = true;
        return true;
    }

    /**
     * Closes the engine's outbound side and writes the resulting close record to the peer. Then drains
     * any pending inbound bytes (best effort) and closes the socket and channel.
     * <p>
     * Does nothing if there is no channel or this object is already marked closed. The socket and
     * channel are closed even if sending the close record fails.
     *
     * @throws IOException if wrapping the close message does not yield {@link Status#CLOSED} or writing
     *                     it fails
     */
    @Override
    public void close() throws IOException {
        logger.debug("{} Closing Connection", this);
        if (channel == null) {
            return;
        }

        if (closed) {
            return;
        }

        try {
            engine.closeOutbound();

            final byte[] emptyMessage = new byte[0];

            final ByteBuffer appDataOut = ByteBuffer.wrap(emptyMessage);
            // Wrapped (network) output must fit in a buffer of the session's packet size.
            final ByteBuffer outboundBuffer = streamOutManager.prepareForWrite(engine.getSession().getPacketBufferSize());
            final SSLEngineResult handshakeResult = engine.wrap(appDataOut, outboundBuffer);

            if (handshakeResult.getStatus() != Status.CLOSED) {
                throw new IOException("Invalid close state - will not send network data");
            }

            final ByteBuffer readableStreamOut = streamOutManager.prepareForRead(1);
            writeFully(readableStreamOut);
        } finally {
            // Drain the incoming TCP buffer
            final ByteBuffer discardBuffer = ByteBuffer.allocate(8192);
            try {
                int bytesDiscarded = channel.read(discardBuffer);
                while (bytesDiscarded > 0) {
                    discardBuffer.clear();
                    bytesDiscarded = channel.read(discardBuffer);
                }
            } catch (Exception ignored) {
                // best effort: the channel is closed below regardless
            }

            closeQuietly(channel.socket());
            closeQuietly(channel);
            closed = true;
        }
    }

    /**
     * Copies already-decrypted application data into {@code buffer}.
     *
     * @param buffer destination array
     * @param offset index in {@code buffer} at which to start writing
     * @param len    maximum number of bytes to copy
     * @return the number of bytes copied; 0 if no decrypted data is buffered
     */
    private int copyFromAppDataBuffer(final byte[] buffer, final int offset, final int len) {
        // If any data already exists in the application data buffer, copy it to the buffer.
        final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);

        final int appDataRemaining = appDataBuffer.remaining();
        if (appDataRemaining > 0) {
            final int bytesToCopy = Math.min(len, appDataBuffer.remaining());
            appDataBuffer.get(buffer, offset, bytesToCopy);

            final int bytesCopied = appDataRemaining - appDataBuffer.remaining();
            logger.trace("{} Copied {} ({}) bytes from unencrypted application buffer to user space",
                    this, bytesToCopy, bytesCopied);
            return bytesCopied;
        }
        return 0;
    }

    /**
     * Returns the number of bytes currently buffered.
     * <p>
     * If nothing is buffered, one read from the channel is attempted first. Note that the count adds
     * decrypted bytes to still-encrypted bytes, so it is not an exact count of readable plaintext.
     *
     * @return the number of buffered decrypted plus encrypted bytes
     * @throws IOException if reading from the channel fails
     */
    public int available() throws IOException {
        ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
        ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);
        final int buffered = appDataBuffer.remaining() + streamDataBuffer.remaining();
        if (buffered > 0) {
            return buffered;
        }

        final boolean wasAbleToRead = isDataAvailable();
        if (!wasAbleToRead) {
            return 0;
        }

        appDataBuffer = appDataManager.prepareForRead(1);
        streamDataBuffer = streamInManager.prepareForRead(1);
        return appDataBuffer.remaining() + streamDataBuffer.remaining();
    }

    /**
     * Reports whether any data is buffered or immediately readable.
     * <p>
     * If nothing is buffered, performs a single read from the channel into the encrypted input buffer.
     * That read is non-blocking once {@link #connect()} has configured the channel.
     *
     * @return {@code true} if decrypted or encrypted bytes are buffered, or the read returned at least one byte
     * @throws IOException if reading from the channel fails
     */
    public boolean isDataAvailable() throws IOException {
        final ByteBuffer appDataBuffer = appDataManager.prepareForRead(1);
        final ByteBuffer streamDataBuffer = streamInManager.prepareForRead(1);

        if (appDataBuffer.remaining() > 0 || streamDataBuffer.remaining() > 0) {
            return true;
        }

        final ByteBuffer writableBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
        final int bytesRead = channel.read(writableBuffer);
        return (bytesRead > 0);
    }

    /**
     * Reads a single byte of decrypted data.
     *
     * @return the byte as an unsigned value in 0..255, or -1 at end of stream
     * @throws IOException if reading or decryption fails
     */
    public int read() throws IOException {
        final int bytesRead = read(oneByteBuffer);
        if (bytesRead == -1) {
            return -1;
        }
        return oneByteBuffer[0] & 0xFF;
    }

    /**
     * Reads decrypted data into the whole of {@code buffer}.
     *
     * @param buffer destination array
     * @return the number of bytes read, or -1 at end of stream
     * @throws IOException if reading or decryption fails
     */
    public int read(final byte[] buffer) throws IOException {
        return read(buffer, 0, buffer.length);
    }

    /**
     * Reads up to {@code len} bytes of decrypted data into {@code buffer} starting at {@code offset}.
     * <p>
     * Connects first if {@link #connect()} has not completed. Already-decrypted data is returned
     * immediately. Otherwise encrypted records are read from the channel and unwrapped until some
     * application data is produced, the peer closes the TLS session, or the stream ends.
     *
     * @param buffer destination array
     * @param offset index in {@code buffer} at which to start writing
     * @param len    maximum number of bytes to read
     * @return the number of bytes read (0 only when {@code len} is 0), or -1 at end of stream
     * @throws IOException if reading fails, the engine reports an overflow, or a record that made no
     *                     progress produced no data
     */
    public int read(final byte[] buffer, final int offset, final int len) throws IOException {
        logger.debug("{} Reading up to {} bytes of data", this, len);

        if (!connected) {
            connect();
        }

        // Nothing requested: return without touching (and discarding) buffered plaintext.
        if (len == 0) {
            return 0;
        }

        int copied = copyFromAppDataBuffer(buffer, offset, len);
        if (copied > 0) {
            return copied;
        }

        appDataManager.clear();

        while (true) {
            // prepare buffers and call unwrap
            final ByteBuffer streamInBuffer = streamInManager.prepareForRead(1);
            final ByteBuffer appDataBuffer = appDataManager.prepareForWrite(engine.getSession().getApplicationBufferSize());
            final SSLEngineResult unwrapResponse = engine.unwrap(streamInBuffer, appDataBuffer);
            logger.trace("{} When reading data, (handshake={}) Unwrap response: {}", this, handshaking, unwrapResponse);

            switch (unwrapResponse.getStatus()) {
                case BUFFER_OVERFLOW:
                    throw new SSLHandshakeException("Buffer Overflow, which is not allowed to happen from an unwrap");
                case BUFFER_UNDERFLOW: {
                    // Not enough bytes for a full record yet: read more from the channel.
                    final ByteBuffer writableInBuffer = streamInManager.prepareForWrite(engine.getSession().getPacketBufferSize());
                    final int bytesRead = readData(writableInBuffer);
                    if (bytesRead < 0) {
                        return -1;
                    }

                    continue;
                }
                case CLOSED:
                    copied = copyFromAppDataBuffer(buffer, offset, len);
                    if (copied == 0) {
                        return -1;
                    }
                    streamInManager.compact();
                    return copied;
                case OK: {
                    if (unwrapResponse.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
                        performTasks();
                    }

                    copied = copyFromAppDataBuffer(buffer, offset, len);
                    streamInManager.compact();
                    if (copied > 0) {
                        return copied;
                    }

                    if (unwrapResponse.bytesConsumed() == 0) {
                        // Neither input consumed nor output produced: retrying would not make progress.
                        throw new IOException("Failed to decrypt data");
                    }

                    // A record was consumed without producing application data (for example a TLS 1.3
                    // post-handshake message such as NewSessionTicket); keep reading.
                    appDataManager.clear();
                    continue;
                }
            }
        }
    }

    /**
     * Encrypts and writes a single byte (the low-order 8 bits of {@code data}).
     *
     * @param data the byte to write
     * @throws IOException if encryption or writing fails
     */
    public void write(final int data) throws IOException {
        write(new byte[]{(byte) data}, 0, 1);
    }

    /**
     * Encrypts and writes the whole of {@code data}.
     *
     * @param data bytes to write
     * @throws IOException if encryption or writing fails
     */
    public void write(final byte[] data) throws IOException {
        write(data, 0, data.length);
    }

    /**
     * Encrypts and writes {@code len} bytes of {@code data} starting at {@code offset}.
     * <p>
     * Connects first if {@link #connect()} has not completed. The data is processed in chunks of at most
     * {@link #MAX_WRITE_SIZE} bytes.
     *
     * @param data   source array
     * @param offset index of the first byte to write
     * @param len    number of bytes to write
     * @throws IOException               if the engine is closed, cannot produce output even after its
     *                                   buffers are resized, or writing fails
     * @throws IndexOutOfBoundsException if {@code offset}/{@code len} do not describe a range of {@code data}
     */
    public void write(final byte[] data, final int offset, final int len) throws IOException {
        logger.debug("{} Writing {} bytes of data", this, len);

        if (!connected) {
            connect();
        }

        int iterations = len / MAX_WRITE_SIZE;
        if (len % MAX_WRITE_SIZE > 0) {
            iterations++;
        }

        for (int i = 0; i < iterations; i++) {
            streamOutManager.clear();
            // The chunk length must be derived from the bytes already covered (relative to offset), not from
            // the absolute array index. Otherwise a non-zero offset yields a short or negative chunk.
            final int chunkStart = i * MAX_WRITE_SIZE;
            final int itrOffset = offset + chunkStart;
            final int itrLen = Math.min(len - chunkStart, MAX_WRITE_SIZE);
            final ByteBuffer byteBuffer = ByteBuffer.wrap(data, itrOffset, itrLen);

            final BufferStateManager buffMan = new BufferStateManager(byteBuffer, Direction.READ);
            Status status = encryptAndWriteFully(buffMan);
            if (status == Status.BUFFER_OVERFLOW) {
                // Grow the buffers to the current session sizes and retry the rest of this chunk. The source
                // buffer's position still reflects what was already sent, so nothing is dropped or repeated.
                streamOutManager.ensureSize(engine.getSession().getPacketBufferSize());
                appDataManager.ensureSize(engine.getSession().getApplicationBufferSize());
                status = encryptAndWriteFully(buffMan);
            }

            switch (status) {
                case OK:
                    continue;
                case CLOSED:
                    throw new IOException("Channel is closed");
                case BUFFER_OVERFLOW:
                    throw new IOException("SSLEngine reported BUFFER_OVERFLOW while encrypting data even after resizing buffers");
                case BUFFER_UNDERFLOW:
                    throw new AssertionError("Got Buffer Underflow but should not have...");
            }
        }
    }

    /**
     * Disables further transmission. Subsequent connect, read and write polling fails with a
     * {@link TransmissionDisabledException}.
     */
    public void interrupt() {
        this.interrupted = true;
    }
}