blob: 15361f0b297aec45c60aff24caafe97af8688058 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.blur.thrift.sasl;
import java.io.UnsupportedEncodingException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.HashMap;
import java.util.Map;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import org.apache.blur.thirdparty.thrift_0_9_0.EncodingUtils;
import org.apache.blur.thirdparty.thrift_0_9_0.TByteArrayOutputStream;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TFramedTransport;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TMemoryInputTransport;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TSocket;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TTransport;
import org.apache.blur.thirdparty.thrift_0_9_0.transport.TTransportException;
import org.apache.blur.utils.ThreadValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A superclass for SASL client/server thrift transports. A subclass need only
 * implement the {@code handleSaslStartMessage} and {@code getRole} methods;
 * {@code open} drives the rest of the handshake.
 */
abstract class TSaslTransport extends TTransport {
  private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class);

  protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF;
  protected static final int MECHANISM_NAME_BYTES = 1;
  protected static final int STATUS_BYTES = 1;
  protected static final int PAYLOAD_LENGTH_BYTES = 4;

  protected static enum SaslRole {
    SERVER, CLIENT;
  }

  /**
   * Status bytes used during the initial Thrift SASL handshake.
   */
  protected static enum NegotiationStatus {
    START((byte) 0x01), OK((byte) 0x02), BAD((byte) 0x03), ERROR((byte) 0x04), COMPLETE((byte) 0x05);

    private final byte value;

    private static final Map<Byte, NegotiationStatus> reverseMap = new HashMap<Byte, NegotiationStatus>();

    static {
      for (NegotiationStatus s : values()) {
        reverseMap.put(s.getValue(), s);
      }
    }

    private NegotiationStatus(byte val) {
      this.value = val;
    }

    /** @return the wire byte for this status. */
    public byte getValue() {
      return value;
    }

    /**
     * Look up a status by its wire byte.
     *
     * @param val
     *          the status byte read from the wire.
     * @return the matching status, or {@code null} if the byte is unknown.
     */
    public static NegotiationStatus byValue(byte val) {
      return reverseMap.get(val);
    }
  }

  /**
   * Transport underlying this one.
   */
  protected TTransport underlyingTransport;

  /**
   * Either a SASL client or a SASL server. May be {@code null} on the server
   * side until {@link #setSaslServer(SaslServer)} has been called.
   */
  private SaslParticipant sasl;

  /**
   * Whether or not we should wrap/unwrap reads/writes. Determined by whether or
   * not a QOP other than "auth" is negotiated during the SASL handshake.
   */
  private boolean shouldWrap = false;

  /**
   * Buffer for input.
   */
  private final TMemoryInputTransport readBuffer = new TMemoryInputTransport();

  /**
   * Buffer for output.
   */
  private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024);

  /**
   * Create a TSaslTransport. It's assumed that setSaslServer will be called
   * later to initialize the SASL endpoint underlying this transport.
   *
   * @param underlyingTransport
   *          The thrift transport which this transport is wrapping.
   */
  protected TSaslTransport(TTransport underlyingTransport) {
    this.underlyingTransport = underlyingTransport;
  }

  /**
   * Create a TSaslTransport which acts as a client.
   *
   * @param saslClient
   *          The <code>SaslClient</code> which this transport will use for SASL
   *          negotiation.
   * @param underlyingTransport
   *          The thrift transport which this transport is wrapping.
   */
  protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport) {
    sasl = new SaslParticipant(saslClient);
    this.underlyingTransport = underlyingTransport;
  }

  /**
   * Install the server-side SASL endpoint used by this transport.
   *
   * @param saslServer
   *          the <code>SaslServer</code> to negotiate with.
   */
  protected void setSaslServer(SaslServer saslServer) {
    sasl = new SaslParticipant(saslServer);
  }

  // Used to read the status byte and payload length. Shared between send and
  // receive, so a single transport must not be used concurrently.
  private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];

  /**
   * Send a complete Thrift SASL message: one status byte, a 4-byte big-endian
   * payload length, then the payload, followed by a flush.
   *
   * @param status
   *          The status to send.
   * @param payload
   *          The data to send as the payload of this message; {@code null} is
   *          sent as an empty payload.
   * @throws TTransportException
   *           Thrown if writing to the underlying transport fails.
   */
  protected void sendSaslMessage(NegotiationStatus status, byte[] payload) throws TTransportException {
    if (payload == null) {
      payload = new byte[0];
    }
    messageHeader[0] = status.getValue();
    EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES);
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug(getRole() + ": Writing message with status {} and payload length {}", status, payload.length);
    }
    underlyingTransport.write(messageHeader);
    underlyingTransport.write(payload);
    underlyingTransport.flush();
  }

  /**
   * Read a complete Thrift SASL message.
   *
   * @return The SASL status and payload from this message.
   * @throws TTransportException
   *           Thrown if there is a failure reading from the underlying
   *           transport, if the header carries an unknown status or a negative
   *           payload length, or if a status code of BAD or ERROR is
   *           encountered.
   */
  protected SaslResponse receiveSaslMessage() throws TTransportException {
    underlyingTransport.readAll(messageHeader, 0, messageHeader.length);
    byte statusByte = messageHeader[0];
    int payloadLength = EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES);
    // The length comes straight off the wire; reject a negative value as a
    // protocol error instead of failing with NegativeArraySizeException.
    if (payloadLength < 0) {
      sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid payload header length: " + payloadLength);
    }
    byte[] payload = new byte[payloadLength];
    underlyingTransport.readAll(payload, 0, payload.length);
    NegotiationStatus status = NegotiationStatus.byValue(statusByte);
    if (status == null) {
      sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte);
    } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) {
      try {
        String remoteMessage = new String(payload, "UTF-8");
        throw new TTransportException("Peer indicated failure: " + remoteMessage);
      } catch (UnsupportedEncodingException e) {
        throw new TTransportException(e);
      }
    }
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug(getRole() + ": Received message with status {} and payload length {}", status, payload.length);
    }
    return new SaslResponse(status, payload);
  }

  /**
   * Send a Thrift SASL message with the given status (usually BAD or ERROR) and
   * string message, and then throw a TTransportException with the given
   * message.
   *
   * @param status
   *          The Thrift SASL status code to send. Usually BAD or ERROR.
   * @param message
   *          The optional message to send to the other side, encoded as UTF-8;
   *          {@code null} is sent as an empty payload.
   * @throws TTransportException
   *           Always thrown with the message provided.
   */
  protected void sendAndThrowMessage(NegotiationStatus status, String message) throws TTransportException {
    sendAndThrowMessage(status, message, null);
  }

  /**
   * Same as {@link #sendAndThrowMessage(NegotiationStatus, String)}, but chains
   * {@code cause} (if non-null) into the thrown exception.
   */
  private void sendAndThrowMessage(NegotiationStatus status, String message, Throwable cause)
      throws TTransportException {
    String text = (message == null) ? "" : message;
    try {
      // UTF-8 to match the decoding performed in receiveSaslMessage().
      sendSaslMessage(status, text.getBytes("UTF-8"));
    } catch (Exception e) {
      LOGGER.warn("Could not send failure response", e);
      text += "\nAlso, could not send response: " + e.toString();
    }
    if (cause == null) {
      throw new TTransportException(text);
    }
    throw new TTransportException(text, cause);
  }

  /**
   * Implemented by subclasses to start the Thrift SASL handshake process. When
   * this method completes, the <code>SaslParticipant</code> in this class is
   * assumed to be initialized.
   *
   * @throws TTransportException
   * @throws SaslException
   */
  abstract protected void handleSaslStartMessage() throws TTransportException, SaslException;

  /** @return whether this end of the transport is the SASL client or server. */
  protected abstract SaslRole getRole();

  /**
   * Opens the underlying transport if it's not already open and then performs
   * SASL negotiation. If a QOP other than "auth" is negotiated during this SASL
   * handshake, it is used to wrap/unwrap all communication on this transport
   * after this call is complete.
   *
   * @throws TTransportException
   *           if the transport is already open, the handshake fails, or the
   *           peer sends an unexpected status. On a {@link SaslException} the
   *           underlying transport is closed and the exception is chained as
   *           the cause.
   */
  @Override
  public void open() throws TTransportException {
    LOGGER.debug("opening transport {}", this);
    if (sasl != null && sasl.isComplete()) {
      throw new TTransportException("SASL transport already open");
    }
    if (!underlyingTransport.isOpen()) {
      underlyingTransport.open();
    }
    try {
      // Negotiate a SASL mechanism. The client also sends its
      // initial response, or an empty one.
      handleSaslStartMessage();
      LOGGER.debug("{}: Start message handled", getRole());
      SaslResponse message = null;
      while (!sasl.isComplete()) {
        message = receiveSaslMessage();
        if (message.status != NegotiationStatus.COMPLETE && message.status != NegotiationStatus.OK) {
          throw new TTransportException("Expected COMPLETE or OK, got " + message.status);
        }
        byte[] challenge;
        try {
          // Expose the peer address to SASL callbacks for the duration of the
          // evaluation only.
          setupConnectionInfo();
          challenge = sasl.evaluateChallengeOrResponse(message.payload);
        } finally {
          resetConnectionInfo();
        }
        // If we are the client, and the server indicates COMPLETE, we don't
        // need to send back any further response.
        if (message.status == NegotiationStatus.COMPLETE && getRole() == SaslRole.CLIENT) {
          LOGGER.debug("{}: All done!", getRole());
          break;
        }
        sendSaslMessage(sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK, challenge);
      }
      LOGGER.debug("{}: Main negotiation loop complete", getRole());
      assert sasl.isComplete();
      // If we're the client, and we're complete, but the server isn't
      // complete yet, we need to wait for its response. This will occur
      // with ANONYMOUS auth, for example, where we send an initial response
      // and are immediately complete.
      if (getRole() == SaslRole.CLIENT && (message == null || message.status == NegotiationStatus.OK)) {
        LOGGER.debug("{}: SASL Client receiving last message", getRole());
        message = receiveSaslMessage();
        if (message.status != NegotiationStatus.COMPLETE) {
          throw new TTransportException("Expected SASL COMPLETE, but got " + message.status);
        }
      }
    } catch (SaslException e) {
      try {
        LOGGER.error("SASL negotiation failure", e);
        sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage(), e);
      } finally {
        underlyingTransport.close();
      }
    }
    String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
    if (qop != null && !qop.equalsIgnoreCase("auth")) {
      shouldWrap = true;
    }
  }

  /** Clears the per-thread peer address published by setupConnectionInfo(). */
  private void resetConnectionInfo() {
    _currentConnection.set(null);
  }

  static ThreadValue<InetSocketAddress> _currentConnection = new ThreadValue<InetSocketAddress>();

  /**
   * Publishes the remote address of the underlying socket (when the underlying
   * transport is a connected {@link TSocket}) into {@code _currentConnection}.
   */
  private void setupConnectionInfo() {
    if (!(underlyingTransport instanceof TSocket)) {
      return;
    }
    Socket socket = ((TSocket) underlyingTransport).getSocket();
    if (socket == null) {
      return;
    }
    SocketAddress remoteSocketAddress = socket.getRemoteSocketAddress();
    if (remoteSocketAddress instanceof InetSocketAddress) {
      _currentConnection.set((InetSocketAddress) remoteSocketAddress);
    }
  }

  /**
   * Get the underlying <code>SaslClient</code>.
   *
   * @return The <code>SaslClient</code>, or <code>null</code> if this transport
   *         is backed by a <code>SaslServer</code> or no SASL endpoint has been
   *         set yet.
   */
  public SaslClient getSaslClient() {
    return sasl == null ? null : sasl.saslClient;
  }

  /**
   * Get the underlying transport that Sasl is using.
   *
   * @return The <code>TTransport</code> transport
   */
  public TTransport getUnderlyingTransport() {
    return underlyingTransport;
  }

  /**
   * Get the underlying <code>SaslServer</code>.
   *
   * @return The <code>SaslServer</code>, or <code>null</code> if this transport
   *         is backed by a <code>SaslClient</code> or no SASL endpoint has been
   *         set yet.
   */
  public SaslServer getSaslServer() {
    return sasl == null ? null : sasl.saslServer;
  }

  /**
   * Read a 4-byte word from the underlying transport and interpret it as an
   * integer.
   *
   * @return The length prefix of the next SASL message to read.
   * @throws TTransportException
   *           Thrown if reading from the underlying transport fails.
   */
  protected int readLength() throws TTransportException {
    byte[] lenBuf = new byte[4];
    underlyingTransport.readAll(lenBuf, 0, lenBuf.length);
    return EncodingUtils.decodeBigEndian(lenBuf);
  }

  /**
   * Write the given integer as 4 bytes to the underlying transport.
   *
   * @param length
   *          The length prefix of the next SASL message to write.
   * @throws TTransportException
   *           Thrown if writing to the underlying transport fails.
   */
  protected void writeLength(int length) throws TTransportException {
    byte[] lenBuf = new byte[4];
    TFramedTransport.encodeFrameSize(length, lenBuf);
    underlyingTransport.write(lenBuf);
  }

  // Below is the SASL implementation of the TTransport interface.

  /**
   * Closes the underlying transport and disposes of the SASL implementation
   * underlying this transport (if one was ever set).
   */
  @Override
  public void close() {
    underlyingTransport.close();
    if (sasl == null) {
      // No SASL endpoint was ever installed; nothing to dispose.
      return;
    }
    try {
      sasl.dispose();
    } catch (SaslException e) {
      // Not much we can do here; the transport is already closed.
      LOGGER.debug("Could not dispose SASL participant", e);
    }
  }

  /**
   * True if the underlying transport is open and the SASL handshake is
   * complete.
   */
  @Override
  public boolean isOpen() {
    return underlyingTransport.isOpen() && sasl != null && sasl.isComplete();
  }

  /**
   * Read from the underlying transport. Unwraps the contents if a QOP was
   * negotiated during the SASL handshake.
   */
  @Override
  public int read(byte[] buf, int off, int len) throws TTransportException {
    if (!isOpen()) {
      throw new TTransportException("SASL authentication not complete");
    }
    int got = readBuffer.read(buf, off, len);
    if (got > 0) {
      return got;
    }
    // Read another frame of data
    try {
      readFrame();
    } catch (SaslException e) {
      throw new TTransportException(e);
    }
    return readBuffer.read(buf, off, len);
  }

  /**
   * Read a single frame of data from the underlying transport, unwrapping if
   * necessary.
   *
   * @throws TTransportException
   *           Thrown if there's an error reading from the underlying transport.
   * @throws SaslException
   *           Thrown if there's an error unwrapping the data.
   */
  private void readFrame() throws TTransportException, SaslException {
    int dataLength = readLength();
    if (dataLength < 0) {
      throw new TTransportException("Read a negative frame size (" + dataLength + ")!");
    }
    byte[] buff = new byte[dataLength];
    LOGGER.debug("{}: reading data length: {}", getRole(), dataLength);
    underlyingTransport.readAll(buff, 0, dataLength);
    if (shouldWrap) {
      buff = sasl.unwrap(buff, 0, buff.length);
      LOGGER.debug("data length after unwrap: {}", buff.length);
    }
    readBuffer.reset(buff);
  }

  /**
   * Write to the underlying transport. Data is buffered until
   * {@link #flush()}.
   */
  @Override
  public void write(byte[] buf, int off, int len) throws TTransportException {
    if (!isOpen()) {
      throw new TTransportException("SASL authentication not complete");
    }
    writeBuffer.write(buf, off, len);
  }

  /**
   * Flushes to the underlying transport. Wraps the contents if a QOP was
   * negotiated during the SASL handshake.
   */
  @Override
  public void flush() throws TTransportException {
    byte[] buf = writeBuffer.get();
    int dataLength = writeBuffer.len();
    writeBuffer.reset();

    if (shouldWrap) {
      LOGGER.debug("data length before wrap: {}", dataLength);
      try {
        buf = sasl.wrap(buf, 0, dataLength);
      } catch (SaslException e) {
        throw new TTransportException(e);
      }
      dataLength = buf.length;
    }
    LOGGER.debug("writing data length: {}", dataLength);
    writeLength(dataLength);
    underlyingTransport.write(buf, 0, dataLength);
    underlyingTransport.flush();
  }

  /**
   * Used exclusively by receiveSaslMessage to return both a status and data.
   */
  protected static class SaslResponse {
    public NegotiationStatus status;
    public byte[] payload;

    public SaslResponse(NegotiationStatus status, byte[] payload) {
      this.status = status;
      this.payload = payload;
    }
  }

  /**
   * Used to abstract over the <code>SaslServer</code> and
   * <code>SaslClient</code> classes, which share a lot of their interface, but
   * unfortunately don't share a common superclass.
   */
  private static class SaslParticipant {
    // One of these will always be null.
    public SaslServer saslServer;
    public SaslClient saslClient;

    public SaslParticipant(SaslServer saslServer) {
      this.saslServer = saslServer;
    }

    public SaslParticipant(SaslClient saslClient) {
      this.saslClient = saslClient;
    }

    public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException {
      if (saslClient != null) {
        return saslClient.evaluateChallenge(challengeOrResponse);
      } else {
        return saslServer.evaluateResponse(challengeOrResponse);
      }
    }

    public boolean isComplete() {
      if (saslClient != null) {
        return saslClient.isComplete();
      } else {
        return saslServer.isComplete();
      }
    }

    public void dispose() throws SaslException {
      if (saslClient != null) {
        saslClient.dispose();
      } else {
        saslServer.dispose();
      }
    }

    public byte[] unwrap(byte[] buf, int off, int len) throws SaslException {
      if (saslClient != null) {
        return saslClient.unwrap(buf, off, len);
      } else {
        return saslServer.unwrap(buf, off, len);
      }
    }

    public byte[] wrap(byte[] buf, int off, int len) throws SaslException {
      if (saslClient != null) {
        return saslClient.wrap(buf, off, len);
      } else {
        return saslServer.wrap(buf, off, len);
      }
    }

    public Object getNegotiatedProperty(String propName) {
      if (saslClient != null) {
        return saslClient.getNegotiatedProperty(propName);
      } else {
        return saslServer.getNegotiatedProperty(propName);
      }
    }
  }
}