| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.qpid.jms.provider.amqp; |
| |
| import java.util.function.Function; |
| |
| import org.apache.qpid.jms.provider.exceptions.ProviderConnectionSecurityException; |
| import org.apache.qpid.jms.provider.exceptions.ProviderConnectionSecuritySaslException; |
| import org.apache.qpid.jms.sasl.Mechanism; |
| import org.apache.qpid.jms.sasl.SaslSecurityRuntimeException; |
| import org.apache.qpid.proton.engine.Sasl; |
| import org.apache.qpid.proton.engine.Sasl.SaslOutcome; |
| import org.apache.qpid.proton.engine.Transport; |
| |
| /** |
| * Manage the SASL authentication process |
| */ |
| public class AmqpSaslAuthenticator { |
| |
| private final Function<String[], Mechanism> mechanismFinder; |
| |
| private Mechanism mechanism; |
| private boolean complete; |
| private ProviderConnectionSecurityException failureCause; |
| |
| /** |
| * Create the authenticator and initialize it. |
| * |
| * @param mechanismFinder |
| * An object that is used to locate the most correct SASL Mechanism to perform the authentication. |
| */ |
| public AmqpSaslAuthenticator(Function<String[], Mechanism> mechanismFinder) { |
| this.mechanismFinder = mechanismFinder; |
| } |
| |
| public boolean isComplete() { |
| return complete; |
| } |
| |
| public ProviderConnectionSecurityException getFailureCause() { |
| return failureCause; |
| } |
| |
| public boolean wasSuccessful() throws IllegalStateException { |
| if (complete) { |
| return failureCause == null; |
| } else { |
| throw new IllegalStateException("Authentication has not completed yet."); |
| } |
| } |
| |
| //----- SaslListener implementation --------------------------------------// |
| |
| public void handleSaslMechanisms(Sasl sasl, Transport transport) { |
| try { |
| String[] remoteMechanisms = sasl.getRemoteMechanisms(); |
| if (remoteMechanisms != null && remoteMechanisms.length != 0) { |
| try { |
| mechanism = mechanismFinder.apply(remoteMechanisms); |
| } catch (SaslSecurityRuntimeException ssre){ |
| recordFailure("Could not find a suitable SASL mechanism. " + ssre.getMessage(), ssre); |
| return; |
| } |
| |
| byte[] response = mechanism.getInitialResponse(); |
| if (response != null) { |
| sasl.send(response, 0, response.length); |
| } |
| sasl.setMechanisms(mechanism.getName()); |
| } |
| } catch (Throwable error) { |
| recordFailure("Exception while processing SASL init: " + error.getMessage(), error); |
| } |
| } |
| |
| public void handleSaslChallenge(Sasl sasl, Transport transport) { |
| try { |
| if (sasl.pending() >= 0) { |
| byte[] challenge = new byte[sasl.pending()]; |
| sasl.recv(challenge, 0, challenge.length); |
| byte[] response = mechanism.getChallengeResponse(challenge); |
| if (response != null) { |
| sasl.send(response, 0, response.length); |
| } |
| } |
| } catch (Throwable error) { |
| recordFailure("Exception while processing SASL step: " + error.getMessage(), error); |
| } |
| } |
| |
| public void handleSaslOutcome(Sasl sasl, Transport transport) { |
| try { |
| switch (sasl.getState()) { |
| case PN_SASL_FAIL: |
| handleSaslFail(sasl); |
| break; |
| case PN_SASL_PASS: |
| handleSaslCompletion(sasl); |
| break; |
| default: |
| break; |
| } |
| } catch (Throwable error) { |
| recordFailure(error.getMessage(), error); |
| } |
| } |
| |
| //----- Internal support methods -----------------------------------------// |
| |
| private void handleSaslFail(Sasl sasl) { |
| StringBuilder message = new StringBuilder("Client failed to authenticate"); |
| if (mechanism != null) { |
| message.append(" using SASL: ").append(mechanism.getName()); |
| if (mechanism.getAdditionalFailureInformation() != null) { |
| message.append(" (").append(mechanism.getAdditionalFailureInformation()).append(")"); |
| } |
| } |
| |
| SaslOutcome outcome = sasl.getOutcome(); |
| if (outcome.equals(SaslOutcome.PN_SASL_TEMP)) { |
| message.append(", due to temporary system error."); |
| } |
| |
| recordFailure(message.toString(), null, outcome.getCode()); |
| } |
| |
| private void handleSaslCompletion(Sasl sasl) { |
| try { |
| if (sasl.pending() != 0) { |
| byte[] additionalData = new byte[sasl.pending()]; |
| sasl.recv(additionalData, 0, additionalData.length); |
| mechanism.getChallengeResponse(additionalData); |
| } |
| mechanism.verifyCompletion(); |
| complete = true; |
| } catch (Throwable error) { |
| recordFailure("Exception while processing SASL exchange completion: " + error.getMessage(), error); |
| } |
| } |
| |
| private void recordFailure(String message, Throwable cause) { |
| recordFailure(message, cause, SaslOutcome.PN_SASL_NONE.getCode()); |
| } |
| |
| private void recordFailure(String message, Throwable cause, int outcome) { |
| failureCause = new ProviderConnectionSecuritySaslException(message, outcome, cause); |
| complete = true; |
| } |
| } |