blob: 4c5cac5bf621516d5b19851ed9234a8496c5c51a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.qpid.jms.provider.amqp;
import java.util.function.Function;
import org.apache.qpid.jms.provider.exceptions.ProviderConnectionSecurityException;
import org.apache.qpid.jms.provider.exceptions.ProviderConnectionSecuritySaslException;
import org.apache.qpid.jms.sasl.Mechanism;
import org.apache.qpid.jms.sasl.SaslSecurityRuntimeException;
import org.apache.qpid.proton.engine.Sasl;
import org.apache.qpid.proton.engine.Sasl.SaslOutcome;
import org.apache.qpid.proton.engine.Transport;
/**
 * Manages the client side of a SASL authentication exchange.
 * <p>
 * Callers forward SASL events to the matching {@code handleSasl*} method: mechanisms offered,
 * challenge received, outcome received. Once {@link #isComplete()} returns {@code true}, the
 * result is available from {@link #wasSuccessful()} and {@link #getFailureCause()}.
 */
public class AmqpSaslAuthenticator {

    private final Function<String[], Mechanism> mechanismFinder;

    private Mechanism mechanism;
    private boolean complete;
    private ProviderConnectionSecurityException failureCause;

    /**
     * Create the authenticator and initialize it.
     *
     * @param mechanismFinder
     *        An object that is used to locate the most correct SASL Mechanism to perform the authentication.
     */
    public AmqpSaslAuthenticator(Function<String[], Mechanism> mechanismFinder) {
        this.mechanismFinder = mechanismFinder;
    }

    /**
     * Reports whether the exchange has finished.
     *
     * @return {@code true} once the exchange has finished, either successfully or with a recorded failure.
     */
    public boolean isComplete() {
        return complete;
    }

    /**
     * Returns the failure recorded during the exchange, if any.
     *
     * @return the recorded failure, or {@code null} if no failure has been recorded.
     */
    public ProviderConnectionSecurityException getFailureCause() {
        return failureCause;
    }

    /**
     * Reports whether the finished exchange succeeded.
     *
     * @return {@code true} if the exchange completed without a recorded failure.
     * @throws IllegalStateException if the exchange has not completed yet.
     */
    public boolean wasSuccessful() throws IllegalStateException {
        if (!complete) {
            throw new IllegalStateException("Authentication has not completed yet.");
        }
        return failureCause == null;
    }

    //----- SASL event handling ----------------------------------------------//

    /**
     * Selects a mechanism from those offered by the remote peer and sends the initial response.
     * <p>
     * If the remote mechanisms are {@code null} or empty, this method does nothing. Any error
     * is recorded as a failure rather than thrown.
     *
     * @param sasl      the SASL engine whose remote mechanisms are read and to which the
     *                  selection and initial response are written.
     * @param transport not used by this method.
     */
    public void handleSaslMechanisms(Sasl sasl, Transport transport) {
        try {
            String[] remoteMechanisms = sasl.getRemoteMechanisms();
            // NOTE(review): presumably an empty list means the server's mechanisms have not
            // arrived yet and the caller will invoke this again later — confirm against caller.
            if (remoteMechanisms == null || remoteMechanisms.length == 0) {
                return;
            }

            try {
                mechanism = mechanismFinder.apply(remoteMechanisms);
            } catch (SaslSecurityRuntimeException ssre) {
                recordFailure("Could not find a suitable SASL mechanism. " + ssre.getMessage(), ssre);
                return;
            }

            // A finder that returns null instead of throwing would otherwise surface below
            // as an opaque NullPointerException with a null message.
            if (mechanism == null) {
                recordFailure("Could not find a suitable SASL mechanism. Server offered: "
                        + String.join(", ", remoteMechanisms), null);
                return;
            }

            byte[] response = mechanism.getInitialResponse();
            if (response != null) {
                sasl.send(response, 0, response.length);
            }
            sasl.setMechanisms(mechanism.getName());
        } catch (Throwable error) {
            recordFailure("Exception while processing SASL init: " + error.getMessage(), error);
        }
    }

    /**
     * Reads the pending challenge bytes, passes them to the selected mechanism, and sends the
     * mechanism's response if it produced one.
     * <p>
     * Any error is recorded as a failure rather than thrown. This includes the case where no
     * mechanism has been selected.
     *
     * @param sasl      the SASL engine the challenge is read from and the response written to.
     * @param transport not used by this method.
     */
    public void handleSaslChallenge(Sasl sasl, Transport transport) {
        try {
            int pending = sasl.pending();
            // Defensive guard: a negative count would otherwise fail the array allocation.
            if (pending >= 0) {
                byte[] challenge = new byte[pending];
                sasl.recv(challenge, 0, challenge.length);
                byte[] response = requireMechanism().getChallengeResponse(challenge);
                if (response != null) {
                    sasl.send(response, 0, response.length);
                }
            }
        } catch (Throwable error) {
            recordFailure("Exception while processing SASL step: " + error.getMessage(), error);
        }
    }

    /**
     * Handles the outcome reported by the SASL engine.
     * <p>
     * A {@code PN_SASL_FAIL} state records a failure. A {@code PN_SASL_PASS} state processes
     * any additional data and verifies completion with the mechanism. Other states are ignored.
     * Any error is recorded as a failure rather than thrown.
     *
     * @param sasl      the SASL engine whose state and outcome are read.
     * @param transport not used by this method.
     */
    public void handleSaslOutcome(Sasl sasl, Transport transport) {
        try {
            switch (sasl.getState()) {
                case PN_SASL_FAIL:
                    handleSaslFail(sasl);
                    break;
                case PN_SASL_PASS:
                    handleSaslCompletion(sasl);
                    break;
                default:
                    // Other states are ignored here.
                    break;
            }
        } catch (Throwable error) {
            // Prefixed for consistency with the other handlers; getMessage() alone may be null.
            recordFailure("Exception while processing SASL outcome: " + error.getMessage(), error);
        }
    }

    //----- Internal support methods -----------------------------------------//

    /**
     * Records an authentication failure carrying the outcome code reported by the engine.
     * The message names the mechanism, if one was selected.
     */
    private void handleSaslFail(Sasl sasl) {
        StringBuilder message = new StringBuilder("Client failed to authenticate");
        if (mechanism != null) {
            message.append(" using SASL: ").append(mechanism.getName());
            String additionalInfo = mechanism.getAdditionalFailureInformation();
            if (additionalInfo != null) {
                message.append(" (").append(additionalInfo).append(")");
            }
        }

        SaslOutcome outcome = sasl.getOutcome();
        if (outcome == SaslOutcome.PN_SASL_TEMP) {
            message.append(", due to temporary system error.");
        }

        recordFailure(message.toString(), null, outcome.getCode());
    }

    /**
     * Completes a successful exchange. Any additional data is passed to the mechanism, the
     * mechanism verifies completion, and then the exchange is marked complete. Errors are
     * recorded as failures.
     */
    private void handleSaslCompletion(Sasl sasl) {
        try {
            int pending = sasl.pending();
            if (pending != 0) {
                byte[] additionalData = new byte[pending];
                sasl.recv(additionalData, 0, additionalData.length);
                // Only reached on PN_SASL_PASS, so the mechanism's response (if any) is not sent.
                requireMechanism().getChallengeResponse(additionalData);
            }
            requireMechanism().verifyCompletion();
            complete = true;
        } catch (Throwable error) {
            recordFailure("Exception while processing SASL exchange completion: " + error.getMessage(), error);
        }
    }

    /**
     * Returns the selected mechanism, or throws a descriptive IllegalStateException if none has
     * been selected. Callers catch this and record it as a failure.
     */
    private Mechanism requireMechanism() {
        if (mechanism == null) {
            throw new IllegalStateException("No SASL mechanism has been selected");
        }
        return mechanism;
    }

    /** Records a failure with no specific SASL outcome code ({@code PN_SASL_NONE}). */
    private void recordFailure(String message, Throwable cause) {
        recordFailure(message, cause, SaslOutcome.PN_SASL_NONE.getCode());
    }

    /** Stores the failure and marks the exchange complete. A later call replaces an earlier failure. */
    private void recordFailure(String message, Throwable cause, int outcome) {
        failureCause = new ProviderConnectionSecuritySaslException(message, outcome, cause);
        complete = true;
    }
}