// blob: eea4053682c7fd3078def41b34c6ea53cfe23d17 [file] [log] [blame]
/*
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
package org.apache.qpid.proton.engine.impl;
import static org.apache.qpid.proton.engine.impl.ByteBufferUtils.newWriteableBuffer;
import static org.apache.qpid.proton.engine.impl.ByteBufferUtils.pourAll;
import static org.apache.qpid.proton.engine.impl.ByteBufferUtils.pourBufferToArray;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.qpid.proton.amqp.Binary;
import org.apache.qpid.proton.amqp.Symbol;
import org.apache.qpid.proton.amqp.security.SaslChallenge;
import org.apache.qpid.proton.amqp.security.SaslCode;
import org.apache.qpid.proton.amqp.security.SaslFrameBody;
import org.apache.qpid.proton.amqp.security.SaslInit;
import org.apache.qpid.proton.amqp.security.SaslMechanisms;
import org.apache.qpid.proton.amqp.security.SaslResponse;
import org.apache.qpid.proton.codec.AMQPDefinedTypes;
import org.apache.qpid.proton.codec.DecoderImpl;
import org.apache.qpid.proton.codec.EncoderImpl;
import org.apache.qpid.proton.engine.Sasl;
import org.apache.qpid.proton.engine.SaslListener;
import org.apache.qpid.proton.engine.Transport;
import org.apache.qpid.proton.engine.TransportException;
public class SaslImpl implements Sasl, SaslFrameBody.SaslFrameBodyHandler<Void>, SaslFrameHandler, TransportLayer
{
// Class-wide logger, named after this class.
private static final Logger _logger = Logger.getLogger(SaslImpl.class.getName());
// Not referenced inside this class; the constructor passes FrameWriter.SASL_FRAME_TYPE to the frame writer instead.
public static final byte SASL_FRAME_TYPE = (byte) 1;
// Text handed to the transport log and protocol tracer when the SASL header is written (see logHeader()).
private static final String HEADER_DESCRIPTION = "SASL";
// Codec pair: the encoder is built around the decoder, and the constructor registers the AMQP types on both.
private final DecoderImpl _decoder = new DecoderImpl();
private final EncoderImpl _encoder = new EncoderImpl(_decoder);
// Owning transport; used for logging/tracing and passed to the frame parser and frame writer.
private final TransportImpl _transport;
// Set by the transport wrapper's close_tail(), or when pouring input downstream reports END_OF_STREAM.
private boolean _tail_closed = false;
// Set when the tail is closed while input is still in SASL mode; makes the wrapper's pending() report END_OF_STREAM once its output buffer is empty.
private boolean _head_closed = false;
// Sizes the wrapper's input/output buffers and is passed to the frame parser and frame writer.
private final int _maxFrameSize;
// Produces outgoing SASL header and frames; drained into the wrapper's output buffer.
private final FrameWriter _frameWriter;
// Bytes received from the peer (initial response, response, challenge or outcome data) waiting to be read via recv(); null when there is nothing to read.
private ByteBuffer _pending;
// True once processHeader() has written the SASL protocol header.
private boolean _headerWritten;
// Bytes supplied locally via send()/plain(), waiting to be placed in the next outgoing frame by process().
private Binary _challengeResponse;
// Parses incoming SASL frames and dispatches them to this object; assigned once in the constructor.
private SaslFrameParser _frameParser;
// Server side: set when a sasl-init frame has been handled (or by fail()).
private boolean _initReceived;
// Server side: set once the mechanisms frame has been written.
private boolean _mechanismsSent;
// Client side: set once the sasl-init frame has been written (or by fail()).
private boolean _initSent;
// Which end of the SASL exchange this instance plays; the _role field stays null until decided.
enum Role { CLIENT, SERVER }
// Negotiation result; PN_SASL_NONE until done(), handleOutcome() or fail() records one.
private SaslOutcome _outcome = SaslOutcome.PN_SASL_NONE;
// Current state reported by getState(); starts at PN_SASL_IDLE.
private SaslState _state = SaslState.PN_SASL_IDLE;
// Server: hostname taken from the received sasl-init. Client: value set via setRemoteHostname() and sent in sasl-init.
private String _hostname;
// Set by done(), handleOutcome() and fail().
private boolean _done;
// Server: mechanisms to advertise (setMechanisms). Client: mechanisms received from the server (handleMechanisms).
private Symbol[] _mechanisms;
// Client: mechanism to put in sasl-init. Server: mechanism named in the received sasl-init.
private Symbol _chosenMechanism;
// Null until client()/server() is called or the first mechanisms/init frame is handled.
private Role _role;
// Consulted in wrap(): only a server with this flag set defers to the sniffer's own determination.
private boolean _allowSkip = true;
// Optional callback target notified from the handleXxx() methods; may be null.
private SaslListener _saslListener;
/**
 * Builds the SASL layer for the given transport: registers the AMQP
 * defined types on this instance's decoder/encoder and creates the SASL
 * frame parser and frame writer.
 *
 * @param transport the transport used for logging/tracing and handed to
 * the frame parser and frame writer
 * @param maxFrameSize the size of the input and output buffers
 * {@link SaslTransportWrapper#_inputBuffer} and
 * {@link SaslTransportWrapper#_outputBuffer}.
 */
SaslImpl(TransportImpl transport, int maxFrameSize)
{
    this._transport = transport;
    this._maxFrameSize = maxFrameSize;

    AMQPDefinedTypes.registerAllTypes(this._decoder, this._encoder);

    this._frameParser = new SaslFrameParser(this, this._decoder, maxFrameSize, transport);
    this._frameWriter = new FrameWriter(this._encoder, maxFrameSize, FrameWriter.SASL_FRAME_TYPE, transport);
}
/**
 * Forces the negotiation into a finished, failed state with outcome
 * {@code PN_SASL_SYS}.
 *
 * An undecided role is resolved to CLIENT. The client path marks the init
 * as sent and the server path marks it as received, so the SASL-mode
 * checks in the transport wrapper see a completed exchange.
 */
void fail()
{
    final boolean actingAsServer = (_role != null && _role != Role.CLIENT);
    if (actingAsServer)
    {
        _initReceived = true;
    }
    else
    {
        _role = Role.CLIENT;
        _initSent = true;
    }

    _done = true;
    _outcome = SaslOutcome.PN_SASL_SYS;
}
/**
 * Reports whether negotiation has finished.
 *
 * @return {@code true} when the done flag is set and this is either a
 * client or a server that has already received the init frame
 */
@Override
public boolean isDone()
{
    if (!_done)
    {
        return false;
    }
    return _role == Role.CLIENT || _initReceived;
}
/**
 * Writes whatever SASL output is currently due: the protocol header (first
 * call only) followed by the role-specific frames. Does nothing beyond the
 * header while the role is still undecided.
 */
private void process()
{
    processHeader();

    if (_role == Role.SERVER)
    {
        processAsServer();
    }
    else if (_role == Role.CLIENT)
    {
        processAsClient();
    }
}

/**
 * Server side: advertises the mechanisms once, forwards a queued challenge
 * while stepping, and writes the outcome frame when negotiation is done.
 */
private void processAsServer()
{
    if (_mechanisms != null && !_mechanismsSent)
    {
        SaslMechanisms offered = new SaslMechanisms();
        offered.setSaslServerMechanisms(_mechanisms);
        writeFrame(offered);
        _mechanismsSent = true;
        _state = SaslState.PN_SASL_STEP;
    }

    if (getState() == SaslState.PN_SASL_STEP && getChallengeResponse() != null)
    {
        SaslChallenge challengeFrame = new SaslChallenge();
        challengeFrame.setChallenge(getChallengeResponse());
        writeFrame(challengeFrame);
        setChallengeResponse(null);
    }

    if (_done)
    {
        org.apache.qpid.proton.amqp.security.SaslOutcome outcomeFrame =
                new org.apache.qpid.proton.amqp.security.SaslOutcome();
        // The engine outcome's numeric code is used as an index into the wire-level SaslCode values.
        outcomeFrame.setCode(SaslCode.values()[_outcome.getCode()]);
        // Only a successful outcome carries the queued bytes as additional data.
        if (_outcome == SaslOutcome.PN_SASL_OK)
        {
            outcomeFrame.setAdditionalData(getChallengeResponse());
        }
        writeFrame(outcomeFrame);
        setChallengeResponse(null);
    }
}

/**
 * Client side: sends the init frame once a mechanism has been chosen, then
 * forwards any queued response while stepping.
 */
private void processAsClient()
{
    if (getState() == SaslState.PN_SASL_IDLE && _chosenMechanism != null)
    {
        processInit();
        _state = SaslState.PN_SASL_STEP;
        //HACK: if we received an outcome before
        //we sent our init, change the state now
        if (_outcome != SaslOutcome.PN_SASL_NONE)
        {
            _state = classifyStateFromOutcome(_outcome);
        }
    }

    if (getState() == SaslState.PN_SASL_STEP && getChallengeResponse() != null)
    {
        processResponse();
    }
}
/**
 * Hands a SASL frame body to the frame writer.
 *
 * @param frameBody the frame body to write
 */
private void writeFrame(SaslFrameBody frameBody)
{
    this._frameWriter.writeFrame(frameBody);
}
/**
 * Copies bytes received from the peer into the caller's array.
 *
 * @param bytes destination array
 * @param offset index in {@code bytes} at which to start writing
 * @param size maximum number of bytes to copy
 * @return {@code -1} when nothing is pending; otherwise the value returned
 * by {@code pourBufferToArray} (presumably the number of bytes copied)
 */
@Override
final public int recv(byte[] bytes, int offset, int size)
{
    final ByteBuffer source = _pending;
    if (source == null)
    {
        return -1;
    }

    final int copied = pourBufferToArray(source, bytes, offset, size);
    // Drop the buffer once fully consumed so later calls report "nothing pending".
    if (!source.hasRemaining())
    {
        _pending = null;
    }
    return copied;
}
/**
 * Queues a copy of the given bytes as the next challenge/response payload,
 * replacing any payload queued earlier.
 *
 * @param bytes source array
 * @param offset index of the first byte to copy
 * @param size number of bytes to copy
 * @return {@code size}
 */
@Override
final public int send(byte[] bytes, int offset, int size)
{
    // Copy defensively so later changes to the caller's array do not affect the queued payload.
    final byte[] copy = new byte[size];
    System.arraycopy(bytes, offset, copy, 0, size);

    final Binary payload = new Binary(copy);
    setChallengeResponse(payload);
    return size;
}
/**
 * Writes the SASL protocol header the first time it is called.
 *
 * @return the header length when the header was written by this call,
 * {@code 0} when it had already been written
 */
final int processHeader()
{
    if (_headerWritten)
    {
        return 0;
    }

    logHeader();
    _frameWriter.writeHeader(AmqpHeader.SASL_HEADER);
    _headerWritten = true;
    return AmqpHeader.SASL_HEADER.length;
}
/**
 * When frame tracing is enabled on the transport, logs the outgoing SASL
 * header and notifies the protocol tracer, if one is installed.
 */
private void logHeader()
{
    if (!_transport.isFrameTracingEnabled())
    {
        return;
    }

    _transport.log(TransportImpl.OUTGOING, HEADER_DESCRIPTION);

    final ProtocolTracer protocolTracer = _transport.getProtocolTracer();
    if (protocolTracer != null)
    {
        protocolTracer.sentHeader(HEADER_DESCRIPTION);
    }
}
/**
 * @return the number of received bytes still available to {@code recv},
 * or {@code 0} when nothing is pending
 */
@Override
public int pending()
{
    if (_pending == null)
    {
        return 0;
    }
    return _pending.remaining();
}
/**
 * Replaces the buffer of received bytes exposed through {@code recv}.
 *
 * @param pending the new buffer, or {@code null} for "nothing pending"
 */
void setPending(ByteBuffer pending)
{
    this._pending = pending;
}
/**
 * @return the current SASL negotiation state
 */
@Override
public SaslState getState()
{
    return this._state;
}
/**
 * @return the locally queued challenge/response payload, or {@code null}
 * when none is queued
 */
final Binary getChallengeResponse()
{
    return this._challengeResponse;
}
/**
 * Replaces the queued challenge/response payload.
 *
 * @param challengeResponse the payload to queue, or {@code null} to clear it
 */
final void setChallengeResponse(Binary challengeResponse)
{
    this._challengeResponse = challengeResponse;
}
/**
 * Records the given mechanism names as symbols. When this instance is
 * already a client, the first name also becomes the chosen mechanism.
 *
 * @param mechanisms mechanism names; a {@code null} array leaves the stored
 * mechanisms untouched. A client is expected (via assertions) to pass
 * exactly one name.
 */
@Override
public void setMechanisms(String... mechanisms)
{
    if (mechanisms != null)
    {
        final Symbol[] symbols = new Symbol[mechanisms.length];
        _mechanisms = symbols;
        int slot = 0;
        for (String mechanismName : mechanisms)
        {
            symbols[slot++] = Symbol.valueOf(mechanismName);
        }
    }

    if (_role != Role.CLIENT)
    {
        return;
    }

    assert mechanisms != null;
    assert mechanisms.length == 1;
    _chosenMechanism = Symbol.valueOf(mechanisms[0]);
}
/**
 * Returns the mechanisms named by the remote peer.
 *
 * @return for a server, the single mechanism chosen by the client (empty
 * when none has been chosen); for a client, the mechanisms stored on this
 * instance (empty when none are stored)
 * @throws IllegalStateException when the role has not been decided yet
 */
@Override
public String[] getRemoteMechanisms()
{
    if (_role == Role.SERVER)
    {
        if (_chosenMechanism == null)
        {
            return new String[0];
        }
        return new String[] { _chosenMechanism.toString() };
    }

    if (_role != Role.CLIENT)
    {
        throw new IllegalStateException();
    }

    if (_mechanisms == null)
    {
        return new String[0];
    }

    final String[] names = new String[_mechanisms.length];
    int slot = 0;
    for (Symbol mechanism : _mechanisms)
    {
        names[slot++] = mechanism.toString();
    }
    return names;
}
/**
 * Sets the chosen mechanism directly.
 *
 * @param mechanism the mechanism symbol to store as chosen
 */
public void setMechanism(Symbol mechanism)
{
    this._chosenMechanism = mechanism;
}
/**
 * @return the chosen mechanism, or {@code null} when none has been set
 */
public Symbol getChosenMechanism()
{
    return this._chosenMechanism;
}
/**
 * Stores the given bytes as the pending data readable through
 * {@code recv}.
 *
 * @param initialResponse the bytes to expose; must not be {@code null}
 */
public void setResponse(Binary initialResponse)
{
    final ByteBuffer view = initialResponse.asByteBuffer();
    setPending(view);
}
/**
 * Entry point for a parsed incoming SASL frame: logs it, notifies the
 * protocol tracer when one is installed, then dispatches the body to the
 * matching {@code handleXxx} method on this object.
 *
 * @param frameBody the decoded frame body
 * @param payload the payload passed through to the body's handler
 */
@Override
public void handle(SaslFrameBody frameBody, Binary payload)
{
    _transport.log(TransportImpl.INCOMING, frameBody);

    final ProtocolTracer protocolTracer = _transport.getProtocolTracer();
    if (protocolTracer != null)
    {
        protocolTracer.receivedSaslBody(frameBody);
    }

    frameBody.invoke(this, payload, null);
}
/**
 * Handles an incoming sasl-init frame. An undecided role becomes SERVER.
 * The hostname and mechanism are recorded, any initial response is made
 * available through {@code recv}, and the listener is notified.
 *
 * @throws IllegalStateException when this instance is not a server
 */
@Override
public void handleInit(SaslInit saslInit, Binary payload, Void context)
{
    if (_role == null)
    {
        server();
    }
    checkRole(Role.SERVER);

    _hostname = saslInit.getHostname();
    _chosenMechanism = saslInit.getMechanism();
    _initReceived = true;

    final Binary initialResponse = saslInit.getInitialResponse();
    if (initialResponse != null)
    {
        setPending(initialResponse.asByteBuffer());
    }

    if (_saslListener != null)
    {
        _saslListener.onSaslInit(this, _transport);
    }
}
/**
 * Handles an incoming sasl-response frame: exposes its bytes through
 * {@code recv} (or clears the pending data when the frame carries none)
 * and notifies the listener.
 *
 * @throws IllegalStateException when this instance is not a server
 */
@Override
public void handleResponse(SaslResponse saslResponse, Binary payload, Void context)
{
    checkRole(Role.SERVER);

    ByteBuffer received = null;
    if (saslResponse.getResponse() != null)
    {
        received = saslResponse.getResponse().asByteBuffer();
    }
    setPending(received);

    if (_saslListener != null)
    {
        _saslListener.onSaslResponse(this, _transport);
    }
}
/**
 * Server-side completion: records the outcome, marks negotiation as done
 * and moves the state to PASS or FAIL accordingly.
 *
 * @param outcome the negotiation result to report to the client
 * @throws IllegalStateException when this instance is not a server
 */
@Override
public void done(SaslOutcome outcome)
{
    checkRole(Role.SERVER);
    _outcome = outcome;
    _done = true;
    _state = classifyStateFromOutcome(outcome);
    // Guarded like the logging in handleOutcome() so the message (and toString()) is only built when FINE is enabled.
    if (_logger.isLoggable(Level.FINE))
    {
        _logger.fine("SASL negotiation done: " + this);
    }
}
/**
 * Verifies that this instance currently plays the expected role.
 *
 * @param role the role required by the caller
 * @throws IllegalStateException when the current role (including an
 * undecided, {@code null} role) differs from {@code role}
 */
private void checkRole(Role role)
{
    if (role == _role)
    {
        return;
    }
    throw new IllegalStateException("Role is " + _role + " but should be " + role);
}
/**
 * Handles an incoming sasl-mechanisms frame. An undecided role becomes
 * CLIENT. The advertised mechanisms are stored and the listener is
 * notified.
 *
 * @throws IllegalStateException when this instance is not a client
 */
@Override
public void handleMechanisms(SaslMechanisms saslMechanisms, Binary payload, Void context)
{
    if (_role == null)
    {
        client();
    }
    checkRole(Role.CLIENT);

    _mechanisms = saslMechanisms.getSaslServerMechanisms();

    if (_saslListener != null)
    {
        _saslListener.onSaslMechanisms(this, _transport);
    }
}
/**
 * Handles an incoming sasl-challenge frame: exposes its bytes through
 * {@code recv} (or clears the pending data when the frame carries none)
 * and notifies the listener.
 *
 * @throws IllegalStateException when this instance is not a client
 */
@Override
public void handleChallenge(SaslChallenge saslChallenge, Binary payload, Void context)
{
    checkRole(Role.CLIENT);

    ByteBuffer received = null;
    if (saslChallenge.getChallenge() != null)
    {
        received = saslChallenge.getChallenge().asByteBuffer();
    }
    setPending(received);

    if (_saslListener != null)
    {
        _saslListener.onSaslChallenge(this, _transport);
    }
}
/**
 * Handles an incoming sasl-outcome frame: exposes any additional data
 * through {@code recv}, maps the wire-level code onto the engine's
 * outcome enum, marks negotiation as done and notifies the listener.
 *
 * The state is only moved to PASS/FAIL here when the init has already
 * been sent (state is no longer IDLE); otherwise {@code process()} applies
 * the recorded outcome after sending the init.
 *
 * @throws IllegalStateException when this instance is not a client
 */
@Override
public void handleOutcome(org.apache.qpid.proton.amqp.security.SaslOutcome saslOutcome,
                          Binary payload,
                          Void context)
{
    checkRole(Role.CLIENT);

    // Independent of the outcome code, so done once up front rather than on every loop iteration.
    final Binary additionalData = saslOutcome.getAdditionalData();
    setPending(additionalData == null ? null : additionalData.asByteBuffer());

    final int wireCode = saslOutcome.getCode().ordinal();
    for (SaslOutcome outcome : SaslOutcome.values())
    {
        if (outcome.getCode() == wireCode)
        {
            _outcome = outcome;
            if (_state != SaslState.PN_SASL_IDLE)
            {
                _state = classifyStateFromOutcome(outcome);
            }
            break;
        }
    }

    _done = true;

    if (_logger.isLoggable(Level.FINE))
    {
        _logger.fine("Handled outcome: " + this);
    }

    if (_saslListener != null)
    {
        _saslListener.onSaslOutcome(this, _transport);
    }
}
/**
 * Maps a negotiation outcome onto the final state.
 *
 * @param outcome the outcome to classify
 * @return {@code PN_SASL_PASS} for {@code PN_SASL_OK}, otherwise
 * {@code PN_SASL_FAIL}
 */
private SaslState classifyStateFromOutcome(SaslOutcome outcome)
{
    if (outcome == SaslOutcome.PN_SASL_OK)
    {
        return SaslState.PN_SASL_PASS;
    }
    return SaslState.PN_SASL_FAIL;
}
/**
 * Client side: wraps the queued payload in a sasl-response frame, clears
 * the queue and writes the frame.
 */
private void processResponse()
{
    final Binary queued = getChallengeResponse();

    SaslResponse responseFrame = new SaslResponse();
    responseFrame.setResponse(queued);
    setChallengeResponse(null);

    writeFrame(responseFrame);
}
/**
 * Client side: builds and writes the sasl-init frame from the stored
 * hostname and chosen mechanism, consuming any queued payload as the
 * initial response, and records that the init has been sent.
 */
private void processInit()
{
    SaslInit initFrame = new SaslInit();
    initFrame.setHostname(_hostname);
    initFrame.setMechanism(_chosenMechanism);

    final Binary initialResponse = getChallengeResponse();
    if (initialResponse != null)
    {
        initFrame.setInitialResponse(initialResponse);
        setChallengeResponse(null);
    }

    _initSent = true;
    writeFrame(initFrame);
}
/**
 * Configures this instance as a client using the PLAIN mechanism and
 * queues the credentials as the initial response.
 *
 * The queued message is a zero byte, the UTF-8 user name, a zero byte and
 * the UTF-8 password.
 *
 * @param username the user name; must not be {@code null}
 * @param password the password; must not be {@code null}
 */
@Override
public void plain(String username, String password)
{
    client();
    _chosenMechanism = Symbol.valueOf("PLAIN");

    final byte[] user = username.getBytes(StandardCharsets.UTF_8);
    final byte[] secret = password.getBytes(StandardCharsets.UTF_8);

    // New arrays are zero-filled, so the two separator bytes never need to be written explicitly.
    final int userStart = 1;
    final int secretStart = user.length + 2;
    final byte[] message = new byte[user.length + secret.length + 2];
    System.arraycopy(user, 0, message, userStart, user.length);
    System.arraycopy(secret, 0, message, secretStart, secret.length);

    setChallengeResponse(new Binary(message));
}
/**
 * @return the recorded negotiation outcome; {@code PN_SASL_NONE} until one
 * has been recorded
 */
@Override
public SaslOutcome getOutcome()
{
    return this._outcome;
}
/**
 * Switches this instance to the client role. If mechanisms were stored
 * beforehand, the first one becomes the chosen mechanism (exactly one is
 * expected, checked by assertion).
 */
@Override
public void client()
{
    _role = Role.CLIENT;

    if (_mechanisms == null)
    {
        return;
    }

    assert _mechanisms.length == 1;
    _chosenMechanism = _mechanisms[0];
}
/**
 * Switches this instance to the server role.
 */
@Override
public void server()
{
    this._role = Role.SERVER;
}
/**
 * Sets the flag consulted by {@code wrap}: only a server with this flag
 * enabled leaves the SASL-versus-plain decision to the sniffer.
 *
 * @param allowSkip the new flag value
 */
@Override
public void allowSkip(boolean allowSkip)
{
    this._allowSkip = allowSkip;
}
/**
 * Wraps the next transport layer with a sniffer that selects between a
 * SASL-processing wrapper and a plain wrapper.
 *
 * Unless this instance is a server with skipping allowed, the first
 * (SASL) wrapper is selected unconditionally; otherwise the sniffer's own
 * determination logic decides.
 *
 * @param input the downstream input
 * @param output the downstream output
 * @return the sniffing wrapper
 */
@Override
public TransportWrapper wrap(final TransportInput input, final TransportOutput output)
{
    final SwitchingSaslTransportWrapper saslWrapper = new SwitchingSaslTransportWrapper(input, output);
    final PlainTransportWrapper plainWrapper = new PlainTransportWrapper(output, input);

    return new SaslSniffer(saslWrapper, plainWrapper)
    {
        @Override
        protected boolean isDeterminationMade()
        {
            final boolean sniffingAllowed = (_role == Role.SERVER && _allowSkip);
            if (!sniffingAllowed)
            {
                _selectedTransportWrapper = _wrapper1;
                return true;
            }
            return super.isDeterminationMade();
        }
    };
}
/**
 * @return a short description containing the outcome, state, done flag and
 * role
 */
@Override
public String toString()
{
    return "SaslImpl [_outcome=" + _outcome
            + ", state=" + _state
            + ", done=" + _done
            + ", role=" + _role
            + "]";
}
/**
 * Transport wrapper that performs the SASL exchange. While the input or
 * output side is still in SASL mode it buffers bytes locally and drives the
 * SASL frame parser / frame writer. Once a side has left SASL mode it
 * forwards to the underlying layer and tells the parent switcher to bypass
 * this wrapper for that direction.
 */
private class SaslTransportWrapper implements TransportWrapper
{
    private final TransportInput _underlyingInput;
    private final TransportOutput _underlyingOutput;
    // Server side: set once output was produced after negotiation finished.
    private boolean _outputComplete;
    private final ByteBuffer _outputBuffer;
    private final ByteBuffer _inputBuffer;
    // View over _outputBuffer handed out by head(); its limit tracks the bytes written so far.
    private final ByteBuffer _head;
    private final SwitchingSaslTransportWrapper _parent;

    /**
     * @param parent the switcher to notify when a direction leaves SASL mode
     * @param input the downstream input
     * @param output the downstream output
     */
    private SaslTransportWrapper(SwitchingSaslTransportWrapper parent, TransportInput input, TransportOutput output)
    {
        _underlyingInput = input;
        _underlyingOutput = output;
        _inputBuffer = newWriteableBuffer(_maxFrameSize);
        _outputBuffer = newWriteableBuffer(_maxFrameSize);
        _parent = parent;

        _head = _transport.isUseReadOnlyOutputBuffer()
                ? _outputBuffer.asReadOnlyBuffer()
                : _outputBuffer.duplicate();
        _head.limit(0);
    }

    /**
     * Produces SASL output into the output buffer while the output side is
     * in SASL mode, and records completion once negotiation is done.
     */
    private void fillOutputBuffer()
    {
        if (!isOutputInSaslMode())
        {
            return;
        }

        writeSaslOutput();
        if (_done)
        {
            _outputComplete = true;
        }
    }

    /**
     * TODO rationalise this method with respect to the other similar checks of _role/_initReceived etc
     * @see SaslImpl#isDone()
     */
    private boolean isInputInSaslMode()
    {
        if (_role == null)
        {
            return true;
        }
        if (_role == Role.CLIENT)
        {
            return !_done;
        }
        // A server stays in SASL mode until it has both received the init and finished.
        return _role == Role.SERVER && !(_initReceived && _done);
    }

    /**
     * @return {@code true} while outgoing bytes must still come from the
     * SASL layer: role undecided, a client that is not both done and past
     * its init, or a server whose output is not complete
     */
    private boolean isOutputInSaslMode()
    {
        if (_role == null)
        {
            return true;
        }
        if (_role == Role.CLIENT)
        {
            return !(_done && _initSent);
        }
        return _role == Role.SERVER && !_outputComplete;
    }

    /**
     * @return {@code true} when output must be served from the local
     * buffer: either still in SASL mode or buffered bytes remain
     */
    private boolean hasSaslOutput()
    {
        return isOutputInSaslMode() || _outputBuffer.position() != 0;
    }

    @Override
    public int capacity()
    {
        if (_tail_closed)
        {
            return Transport.END_OF_STREAM;
        }
        return isInputInSaslMode() ? _inputBuffer.remaining() : _underlyingInput.capacity();
    }

    @Override
    public int position()
    {
        if (_tail_closed)
        {
            return Transport.END_OF_STREAM;
        }
        return isInputInSaslMode() ? _inputBuffer.position() : _underlyingInput.position();
    }

    @Override
    public ByteBuffer tail()
    {
        if (isInputInSaslMode())
        {
            return _inputBuffer;
        }
        return _underlyingInput.tail();
    }

    /**
     * Processes the bytes written into the input buffer. The buffer is
     * flipped for reading and always compacted afterwards, even when
     * processing throws.
     */
    @Override
    public void process() throws TransportException
    {
        _inputBuffer.flip();
        try
        {
            reallyProcessInput();
        }
        finally
        {
            _inputBuffer.compact();
        }
    }

    @Override
    public void close_tail()
    {
        _tail_closed = true;
        // Closing the tail mid-negotiation also closes the head side of this wrapper.
        if (isInputInSaslMode())
        {
            _head_closed = true;
        }
        _underlyingInput.close_tail();
    }

    /**
     * Feeds buffered input to the SASL frame parser while in SASL mode.
     * Once SASL mode has ended, pours any leftover bytes into the
     * underlying input and lets it process them.
     */
    private void reallyProcessInput() throws TransportException
    {
        if (isInputInSaslMode())
        {
            if (_logger.isLoggable(Level.FINER))
            {
                _logger.log(Level.FINER, SaslImpl.this + " about to call input.");
            }
            _frameParser.input(_inputBuffer);
        }

        // Checked again on purpose: parsing above may have ended SASL mode.
        if (isInputInSaslMode())
        {
            return;
        }

        if (_logger.isLoggable(Level.FINER))
        {
            _logger.log(Level.FINER, SaslImpl.this + " about to call plain input");
        }

        if (_inputBuffer.hasRemaining())
        {
            final int poured = pourAll(_inputBuffer, _underlyingInput);
            if (poured == Transport.END_OF_STREAM)
            {
                _tail_closed = true;
            }
            // Only bypass this wrapper once every leftover byte has been handed on.
            if (!_inputBuffer.hasRemaining())
            {
                _parent.switchToNextInput();
            }
        }
        else
        {
            _parent.switchToNextInput();
        }

        _underlyingInput.process();
    }

    @Override
    public int pending()
    {
        if (!hasSaslOutput())
        {
            _parent.switchToNextOutput();
            return _underlyingOutput.pending();
        }

        fillOutputBuffer();
        final int buffered = _outputBuffer.position();
        _head.limit(buffered);
        if (_head_closed && buffered == 0)
        {
            return Transport.END_OF_STREAM;
        }
        return buffered;
    }

    @Override
    public ByteBuffer head()
    {
        if (!hasSaslOutput())
        {
            _parent.switchToNextOutput();
            return _underlyingOutput.head();
        }

        // pending() fills the output buffer and refreshes the limit of _head.
        pending();
        return _head;
    }

    @Override
    public void pop(int bytes)
    {
        if (!hasSaslOutput())
        {
            _parent.switchToNextOutput();
            _underlyingOutput.pop(bytes);
            return;
        }

        // Discard the first 'bytes' bytes and slide the remainder to the front.
        _outputBuffer.flip();
        _outputBuffer.position(bytes);
        _outputBuffer.compact();
        _head.position(0);
        _head.limit(_outputBuffer.position());
    }

    @Override
    public void close_head()
    {
        _parent.switchToNextOutput();
        _underlyingOutput.close_head();
    }

    /**
     * Runs the SASL state machine of the enclosing instance and moves the
     * bytes it produced from the frame writer into the output buffer.
     */
    private void writeSaslOutput()
    {
        SaslImpl.this.process();
        _frameWriter.readBytes(_outputBuffer);

        if (_logger.isLoggable(Level.FINER))
        {
            _logger.log(Level.FINER, "Finished writing SASL output. Output Buffer : " + _outputBuffer);
        }
    }
}
/**
 * Delegating wrapper that initially routes both directions through a
 * {@link SaslTransportWrapper}. Each direction can be switched
 * independently to the underlying layer once SASL no longer needs it.
 */
private class SwitchingSaslTransportWrapper implements TransportWrapper
{
    private final TransportInput _underlyingInput;
    private final TransportOutput _underlyingOutput;

    // Current delegates; both start as the SASL processor and are swapped by switchToNextInput/Output.
    private TransportInput _currentInput;
    private TransportOutput _currentOutput;

    /**
     * @param input the downstream input to switch to later
     * @param output the downstream output to switch to later
     */
    private SwitchingSaslTransportWrapper(TransportInput input, TransportOutput output)
    {
        _underlyingInput = input;
        _underlyingOutput = output;

        // The wrapper can be GC'd after both current's are switched to next.
        final SaslTransportWrapper saslProcessor = new SaslTransportWrapper(this, input, output);
        _currentInput = saslProcessor;
        _currentOutput = saslProcessor;
    }

    @Override
    public int capacity()
    {
        return _currentInput.capacity();
    }

    @Override
    public int position()
    {
        return _currentInput.position();
    }

    @Override
    public ByteBuffer tail() throws TransportException
    {
        return _currentInput.tail();
    }

    @Override
    public void process() throws TransportException
    {
        _currentInput.process();
    }

    @Override
    public void close_tail()
    {
        _currentInput.close_tail();
    }

    @Override
    public int pending()
    {
        return _currentOutput.pending();
    }

    @Override
    public ByteBuffer head()
    {
        return _currentOutput.head();
    }

    @Override
    public void pop(int bytes)
    {
        _currentOutput.pop(bytes);
    }

    @Override
    public void close_head()
    {
        _currentOutput.close_head();
    }

    /** Routes all further input calls straight to the underlying input. */
    void switchToNextInput()
    {
        _currentInput = _underlyingInput;
    }

    /** Routes all further output calls straight to the underlying output. */
    void switchToNextOutput()
    {
        _currentOutput = _underlyingOutput;
    }
}
/**
 * Returns the stored hostname. The call is allowed while the role is
 * undecided, but a decided role must be SERVER.
 *
 * @return the stored hostname, possibly {@code null}
 * @throws IllegalStateException when this instance is a client
 */
@Override
public String getHostname()
{
    final boolean roleDecided = (_role != null);
    if (roleDecided)
    {
        checkRole(Role.SERVER);
    }
    return _hostname;
}
/**
 * Stores the hostname to send in the sasl-init frame. The call is allowed
 * while the role is undecided, but a decided role must be CLIENT.
 *
 * @param hostname the hostname to store
 * @throws IllegalStateException when this instance is a server
 */
@Override
public void setRemoteHostname(String hostname)
{
    final boolean roleDecided = (_role != null);
    if (roleDecided)
    {
        checkRole(Role.CLIENT);
    }
    _hostname = hostname;
}
/**
 * Installs the listener notified when SASL frames are handled.
 *
 * @param saslListener the listener, or {@code null} for none
 */
@Override
public void setListener(SaslListener saslListener)
{
    this._saslListener = saslListener;
}
}