| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.blur.thrift.sasl; |
| |
| import java.io.UnsupportedEncodingException; |
| import java.net.InetSocketAddress; |
| import java.net.Socket; |
| import java.net.SocketAddress; |
| import java.util.HashMap; |
| import java.util.Map; |
| |
| import javax.security.sasl.Sasl; |
| import javax.security.sasl.SaslClient; |
| import javax.security.sasl.SaslException; |
| import javax.security.sasl.SaslServer; |
| |
| import org.apache.blur.thirdparty.thrift_0_9_0.EncodingUtils; |
| import org.apache.blur.thirdparty.thrift_0_9_0.TByteArrayOutputStream; |
| import org.apache.blur.thirdparty.thrift_0_9_0.transport.TFramedTransport; |
| import org.apache.blur.thirdparty.thrift_0_9_0.transport.TMemoryInputTransport; |
| import org.apache.blur.thirdparty.thrift_0_9_0.transport.TSocket; |
| import org.apache.blur.thirdparty.thrift_0_9_0.transport.TTransport; |
| import org.apache.blur.thirdparty.thrift_0_9_0.transport.TTransportException; |
| import org.apache.blur.utils.ThreadValue; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** |
| * A superclass for SASL client/server thrift transports. A subclass need only |
| * implement the <code>open</open> method. |
| */ |
| abstract class TSaslTransport extends TTransport { |
| |
| private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class); |
| |
| protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF; |
| |
| protected static final int MECHANISM_NAME_BYTES = 1; |
| protected static final int STATUS_BYTES = 1; |
| protected static final int PAYLOAD_LENGTH_BYTES = 4; |
| |
| protected static enum SaslRole { |
| SERVER, CLIENT; |
| } |
| |
| /** |
| * Status bytes used during the initial Thrift SASL handshake. |
| */ |
| protected static enum NegotiationStatus { |
| START((byte) 0x01), OK((byte) 0x02), BAD((byte) 0x03), ERROR((byte) 0x04), COMPLETE((byte) 0x05); |
| |
| private final byte value; |
| |
| private static final Map<Byte, NegotiationStatus> reverseMap = new HashMap<Byte, NegotiationStatus>(); |
| static { |
| for (NegotiationStatus s : NegotiationStatus.class.getEnumConstants()) { |
| reverseMap.put(s.getValue(), s); |
| } |
| } |
| |
| private NegotiationStatus(byte val) { |
| this.value = val; |
| } |
| |
| public byte getValue() { |
| return value; |
| } |
| |
| public static NegotiationStatus byValue(byte val) { |
| return reverseMap.get(val); |
| } |
| } |
| |
| /** |
| * Transport underlying this one. |
| */ |
| protected TTransport underlyingTransport; |
| |
| /** |
| * Either a SASL client or a SASL server. |
| */ |
| private SaslParticipant sasl; |
| |
| /** |
| * Whether or not we should wrap/unwrap reads/writes. Determined by whether or |
| * not a QOP is negotiated during the SASL handshake. |
| */ |
| private boolean shouldWrap = false; |
| |
| /** |
| * Buffer for input. |
| */ |
| private TMemoryInputTransport readBuffer = new TMemoryInputTransport(); |
| |
| /** |
| * Buffer for output. |
| */ |
| private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024); |
| |
| /** |
| * Create a TSaslTransport. It's assumed that setSaslServer will be called |
| * later to initialize the SASL endpoint underlying this transport. |
| * |
| * @param underlyingTransport |
| * The thrift transport which this transport is wrapping. |
| */ |
| protected TSaslTransport(TTransport underlyingTransport) { |
| this.underlyingTransport = underlyingTransport; |
| } |
| |
| /** |
| * Create a TSaslTransport which acts as a client. |
| * |
| * @param saslClient |
| * The <code>SaslClient</code> which this transport will use for SASL |
| * negotiation. |
| * @param underlyingTransport |
| * The thrift transport which this transport is wrapping. |
| */ |
| protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport) { |
| sasl = new SaslParticipant(saslClient); |
| this.underlyingTransport = underlyingTransport; |
| } |
| |
| protected void setSaslServer(SaslServer saslServer) { |
| sasl = new SaslParticipant(saslServer); |
| } |
| |
| // Used to read the status byte and payload length. |
| private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES]; |
| |
| /** |
| * Send a complete Thrift SASL message. |
| * |
| * @param status |
| * The status to send. |
| * @param payload |
| * The data to send as the payload of this message. |
| * @throws TTransportException |
| */ |
| protected void sendSaslMessage(NegotiationStatus status, byte[] payload) throws TTransportException { |
| if (payload == null) |
| payload = new byte[0]; |
| |
| messageHeader[0] = status.getValue(); |
| EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES); |
| |
| if (LOGGER.isDebugEnabled()) |
| LOGGER.debug(getRole() + ": Writing message with status {} and payload length {}", status, payload.length); |
| underlyingTransport.write(messageHeader); |
| underlyingTransport.write(payload); |
| underlyingTransport.flush(); |
| } |
| |
| /** |
| * Read a complete Thrift SASL message. |
| * |
| * @return The SASL status and payload from this message. |
| * @throws TTransportException |
| * Thrown if there is a failure reading from the underlying |
| * transport, or if a status code of BAD or ERROR is encountered. |
| */ |
| protected SaslResponse receiveSaslMessage() throws TTransportException { |
| underlyingTransport.readAll(messageHeader, 0, messageHeader.length); |
| |
| byte statusByte = messageHeader[0]; |
| byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)]; |
| underlyingTransport.readAll(payload, 0, payload.length); |
| |
| NegotiationStatus status = NegotiationStatus.byValue(statusByte); |
| if (status == null) { |
| sendAndThrowMessage(NegotiationStatus.ERROR, "Invalid status " + statusByte); |
| } else if (status == NegotiationStatus.BAD || status == NegotiationStatus.ERROR) { |
| try { |
| String remoteMessage = new String(payload, "UTF-8"); |
| throw new TTransportException("Peer indicated failure: " + remoteMessage); |
| } catch (UnsupportedEncodingException e) { |
| throw new TTransportException(e); |
| } |
| } |
| |
| if (LOGGER.isDebugEnabled()) |
| LOGGER.debug(getRole() + ": Received message with status {} and payload length {}", status, payload.length); |
| return new SaslResponse(status, payload); |
| } |
| |
| /** |
| * Send a Thrift SASL message with the given status (usually BAD or ERROR) and |
| * string message, and then throw a TTransportException with the given |
| * message. |
| * |
| * @param status |
| * The Thrift SASL status code to send. Usually BAD or ERROR. |
| * @param message |
| * The optional message to send to the other side. |
| * @throws TTransportException |
| * Always thrown with the message provided. |
| */ |
| protected void sendAndThrowMessage(NegotiationStatus status, String message) throws TTransportException { |
| try { |
| sendSaslMessage(status, message.getBytes()); |
| } catch (Exception e) { |
| LOGGER.warn("Could not send failure response", e); |
| message += "\nAlso, could not send response: " + e.toString(); |
| } |
| throw new TTransportException(message); |
| } |
| |
| /** |
| * Implemented by subclasses to start the Thrift SASL handshake process. When |
| * this method completes, the <code>SaslParticipant</code> in this class is |
| * assumed to be initialized. |
| * |
| * @throws TTransportException |
| * @throws SaslException |
| */ |
| abstract protected void handleSaslStartMessage() throws TTransportException, SaslException; |
| |
| protected abstract SaslRole getRole(); |
| |
| /** |
| * Opens the underlying transport if it's not already open and then performs |
| * SASL negotiation. If a QOP is negotiated during this SASL handshake, it |
| * used for all communication on this transport after this call is complete. |
| */ |
| @Override |
| public void open() throws TTransportException { |
| LOGGER.debug("opening transport {}", this); |
| if (sasl != null && sasl.isComplete()) |
| throw new TTransportException("SASL transport already open"); |
| |
| if (!underlyingTransport.isOpen()) |
| underlyingTransport.open(); |
| |
| try { |
| // Negotiate a SASL mechanism. The client also sends its |
| // initial response, or an empty one. |
| handleSaslStartMessage(); |
| LOGGER.debug("{}: Start message handled", getRole()); |
| |
| SaslResponse message = null; |
| while (!sasl.isComplete()) { |
| message = receiveSaslMessage(); |
| if (message.status != NegotiationStatus.COMPLETE && message.status != NegotiationStatus.OK) { |
| throw new TTransportException("Expected COMPLETE or OK, got " + message.status); |
| } |
| byte[] challenge; |
| try { |
| setupConnectionInfo(); |
| challenge = sasl.evaluateChallengeOrResponse(message.payload); |
| } finally { |
| resetConnectionInfo(); |
| } |
| |
| // If we are the client, and the server indicates COMPLETE, we don't |
| // need to |
| // send back any further response. |
| if (message.status == NegotiationStatus.COMPLETE && getRole() == SaslRole.CLIENT) { |
| LOGGER.debug("{}: All done!", getRole()); |
| break; |
| } |
| |
| sendSaslMessage(sasl.isComplete() ? NegotiationStatus.COMPLETE : NegotiationStatus.OK, challenge); |
| } |
| LOGGER.debug("{}: Main negotiation loop complete", getRole()); |
| |
| assert sasl.isComplete(); |
| |
| // If we're the client, and we're complete, but the server isn't |
| // complete yet, we need to wait for its response. This will occur |
| // with ANONYMOUS auth, for example, where we send an initial response |
| // and are immediately complete. |
| if (getRole() == SaslRole.CLIENT && (message == null || message.status == NegotiationStatus.OK)) { |
| LOGGER.debug("{}: SASL Client receiving last message", getRole()); |
| message = receiveSaslMessage(); |
| if (message.status != NegotiationStatus.COMPLETE) { |
| throw new TTransportException("Expected SASL COMPLETE, but got " + message.status); |
| } |
| } |
| } catch (SaslException e) { |
| try { |
| LOGGER.error("SASL negotiation failure", e); |
| sendAndThrowMessage(NegotiationStatus.BAD, e.getMessage()); |
| } finally { |
| underlyingTransport.close(); |
| } |
| } |
| |
| String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP); |
| if (qop != null && !qop.equalsIgnoreCase("auth")) |
| shouldWrap = true; |
| } |
| |
| private void resetConnectionInfo() { |
| _currentConnection.set(null); |
| } |
| |
| static ThreadValue<InetSocketAddress> _currentConnection = new ThreadValue<InetSocketAddress>(); |
| |
| private void setupConnectionInfo() { |
| if (underlyingTransport instanceof TSocket) { |
| TSocket tSocket = (TSocket) underlyingTransport; |
| Socket socket = tSocket.getSocket(); |
| SocketAddress remoteSocketAddress = socket.getRemoteSocketAddress(); |
| InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteSocketAddress; |
| _currentConnection.set(inetSocketAddress); |
| } |
| } |
| |
| /** |
| * Get the underlying <code>SaslClient</code>. |
| * |
| * @return The <code>SaslClient</code>, or <code>null</code> if this transport |
| * is backed by a <code>SaslServer</code>. |
| */ |
| public SaslClient getSaslClient() { |
| return sasl.saslClient; |
| } |
| |
| /** |
| * Get the underlying transport that Sasl is using. |
| * |
| * @return The <code>TTransport</code> transport |
| */ |
| public TTransport getUnderlyingTransport() { |
| return underlyingTransport; |
| } |
| |
| /** |
| * Get the underlying <code>SaslServer</code>. |
| * |
| * @return The <code>SaslServer</code>, or <code>null</code> if this transport |
| * is backed by a <code>SaslClient</code>. |
| */ |
| public SaslServer getSaslServer() { |
| return sasl.saslServer; |
| } |
| |
| /** |
| * Read a 4-byte word from the underlying transport and interpret it as an |
| * integer. |
| * |
| * @return The length prefix of the next SASL message to read. |
| * @throws TTransportException |
| * Thrown if reading from the underlying transport fails. |
| */ |
| protected int readLength() throws TTransportException { |
| byte[] lenBuf = new byte[4]; |
| underlyingTransport.readAll(lenBuf, 0, lenBuf.length); |
| return EncodingUtils.decodeBigEndian(lenBuf); |
| } |
| |
| /** |
| * Write the given integer as 4 bytes to the underlying transport. |
| * |
| * @param length |
| * The length prefix of the next SASL message to write. |
| * @throws TTransportException |
| * Thrown if writing to the underlying transport fails. |
| */ |
| protected void writeLength(int length) throws TTransportException { |
| byte[] lenBuf = new byte[4]; |
| TFramedTransport.encodeFrameSize(length, lenBuf); |
| underlyingTransport.write(lenBuf); |
| } |
| |
| // Below is the SASL implementation of the TTransport interface. |
| |
| /** |
| * Closes the underlying transport and disposes of the SASL implementation |
| * underlying this transport. |
| */ |
| @Override |
| public void close() { |
| underlyingTransport.close(); |
| try { |
| sasl.dispose(); |
| } catch (SaslException e) { |
| // Not much we can do here. |
| } |
| } |
| |
| /** |
| * True if the underlying transport is open and the SASL handshake is |
| * complete. |
| */ |
| @Override |
| public boolean isOpen() { |
| return underlyingTransport.isOpen() && sasl != null && sasl.isComplete(); |
| } |
| |
| /** |
| * Read from the underlying transport. Unwraps the contents if a QOP was |
| * negotiated during the SASL handshake. |
| */ |
| @Override |
| public int read(byte[] buf, int off, int len) throws TTransportException { |
| if (!isOpen()) |
| throw new TTransportException("SASL authentication not complete"); |
| |
| int got = readBuffer.read(buf, off, len); |
| if (got > 0) { |
| return got; |
| } |
| |
| // Read another frame of data |
| try { |
| readFrame(); |
| } catch (SaslException e) { |
| throw new TTransportException(e); |
| } |
| |
| return readBuffer.read(buf, off, len); |
| } |
| |
| /** |
| * Read a single frame of data from the underlying transport, unwrapping if |
| * necessary. |
| * |
| * @throws TTransportException |
| * Thrown if there's an error reading from the underlying transport. |
| * @throws SaslException |
| * Thrown if there's an error unwrapping the data. |
| */ |
| private void readFrame() throws TTransportException, SaslException { |
| int dataLength = readLength(); |
| |
| if (dataLength < 0) |
| throw new TTransportException("Read a negative frame size (" + dataLength + ")!"); |
| |
| byte[] buff = new byte[dataLength]; |
| LOGGER.debug("{}: reading data length: {}", getRole(), dataLength); |
| underlyingTransport.readAll(buff, 0, dataLength); |
| if (shouldWrap) { |
| buff = sasl.unwrap(buff, 0, buff.length); |
| LOGGER.debug("data length after unwrap: {}", buff.length); |
| } |
| readBuffer.reset(buff); |
| } |
| |
| /** |
| * Write to the underlying transport. |
| */ |
| @Override |
| public void write(byte[] buf, int off, int len) throws TTransportException { |
| if (!isOpen()) |
| throw new TTransportException("SASL authentication not complete"); |
| |
| writeBuffer.write(buf, off, len); |
| } |
| |
| /** |
| * Flushes to the underlying transport. Wraps the contents if a QOP was |
| * negotiated during the SASL handshake. |
| */ |
| @Override |
| public void flush() throws TTransportException { |
| byte[] buf = writeBuffer.get(); |
| int dataLength = writeBuffer.len(); |
| writeBuffer.reset(); |
| |
| if (shouldWrap) { |
| LOGGER.debug("data length before wrap: {}", dataLength); |
| try { |
| buf = sasl.wrap(buf, 0, dataLength); |
| } catch (SaslException e) { |
| throw new TTransportException(e); |
| } |
| dataLength = buf.length; |
| } |
| LOGGER.debug("writing data length: {}", dataLength); |
| writeLength(dataLength); |
| underlyingTransport.write(buf, 0, dataLength); |
| underlyingTransport.flush(); |
| } |
| |
| /** |
| * Used exclusively by readSaslMessage to return both a status and data. |
| */ |
| protected static class SaslResponse { |
| public NegotiationStatus status; |
| public byte[] payload; |
| |
| public SaslResponse(NegotiationStatus status, byte[] payload) { |
| this.status = status; |
| this.payload = payload; |
| } |
| } |
| |
| /** |
| * Used to abstract over the <code>SaslServer</code> and |
| * <code>SaslClient</code> classes, which share a lot of their interface, but |
| * unfortunately don't share a common superclass. |
| */ |
| private static class SaslParticipant { |
| // One of these will always be null. |
| public SaslServer saslServer; |
| public SaslClient saslClient; |
| |
| public SaslParticipant(SaslServer saslServer) { |
| this.saslServer = saslServer; |
| } |
| |
| public SaslParticipant(SaslClient saslClient) { |
| this.saslClient = saslClient; |
| } |
| |
| public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException { |
| if (saslClient != null) { |
| return saslClient.evaluateChallenge(challengeOrResponse); |
| } else { |
| return saslServer.evaluateResponse(challengeOrResponse); |
| } |
| } |
| |
| public boolean isComplete() { |
| if (saslClient != null) |
| return saslClient.isComplete(); |
| else |
| return saslServer.isComplete(); |
| } |
| |
| public void dispose() throws SaslException { |
| if (saslClient != null) |
| saslClient.dispose(); |
| else |
| saslServer.dispose(); |
| } |
| |
| public byte[] unwrap(byte[] buf, int off, int len) throws SaslException { |
| if (saslClient != null) |
| return saslClient.unwrap(buf, off, len); |
| else |
| return saslServer.unwrap(buf, off, len); |
| } |
| |
| public byte[] wrap(byte[] buf, int off, int len) throws SaslException { |
| if (saslClient != null) |
| return saslClient.wrap(buf, off, len); |
| else |
| return saslServer.wrap(buf, off, len); |
| } |
| |
| public Object getNegotiatedProperty(String propName) { |
| if (saslClient != null) |
| return saslClient.getNegotiatedProperty(propName); |
| else |
| return saslServer.getNegotiatedProperty(propName); |
| } |
| } |
| } |