// blob: e1c15c2c22dd865ef4e95df51752d1713a898e0b [file]
/*
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
#include "ConnectionContext.h"
#include "qpid/messaging/amqp/Sasl.h"
#include "qpid/messaging/exceptions.h"
#include "qpid/sys/SecurityLayer.h"
#include "qpid/log/Statement.h"
#include "qpid/Sasl.h"
#include "qpid/SaslFactory.h"
#include "qpid/StringUtils.h"
#include <sstream>
namespace qpid {
namespace messaging {
namespace amqp {
/**
 * Build the client-side SASL handler for a connection.
 *
 * @param id        identifier forwarded to the qpid::amqp::SaslClient base
 * @param c         connection context; its username, password, service,
 *                  minSsf and maxSsf are handed to the SASL factory, and it
 *                  is kept in the `context` member
 * @param hostname_ host name handed to the SASL factory and kept in the
 *                  `hostname` member for the later init() call
 */
Sasl::Sasl(const std::string& id, ConnectionContext& c, const std::string& hostname_)
    : qpid::amqp::SaslClient(id),
      context(c),
      sasl(qpid::SaslFactory::getInstance().create(
               c.username, c.password, c.service, hostname_,
               c.minSsf, c.maxSsf,
               false /* NOTE(review): presumably turns off interactive prompting - confirm against SaslFactory::create */)),
      hostname(hostname_),
      readHeader(true),   // protocol header still has to be read
      writeHeader(true),  // protocol header still has to be written
      haveOutput(false),
      state(NONE)
{
}
Sasl::~Sasl()
{
    // Nothing is released explicitly here; members are cleaned up by their
    // own destructors.
}
/**
 * Consume received bytes for the SASL exchange.
 *
 * While the protocol header is still outstanding it is read first; whatever
 * is left in the buffer is then passed to read(), but only while the
 * negotiation state is still NONE.
 *
 * @param buffer start of the received bytes
 * @param size   number of bytes available at buffer
 * @return number of bytes consumed from buffer
 */
std::size_t Sasl::decode(const char* buffer, std::size_t size)
{
    std::size_t consumed = 0;
    if (readHeader) {
        consumed = readProtocolHeader(buffer, size);
        // Keep expecting the header until at least part of the input was taken.
        readHeader = (consumed == 0);
    }
    const bool negotiating = (state == NONE);
    if (negotiating && consumed < size) {
        consumed += read(buffer + consumed, size - consumed);
    }
    QPID_LOG(trace, id << " Sasl::decode(" << size << "): " << consumed);
    return consumed;
}
/**
 * Produce outgoing bytes for the SASL exchange.
 *
 * The protocol header is written first while it is still outstanding; any
 * space left in the buffer is then offered to write().
 *
 * @param buffer destination for the encoded bytes
 * @param size   capacity of buffer in bytes
 * @return number of bytes written into buffer
 */
std::size_t Sasl::encode(char* buffer, std::size_t size)
{
    std::size_t produced = 0;
    if (writeHeader) {
        produced = writeProtocolHeader(buffer, size);
        // Keep trying to emit the header until something was written.
        writeHeader = (produced == 0);
    }
    if (produced < size) {
        produced += write(buffer + produced, size - produced);
    }
    // Output is flagged as pending only when the buffer was filled completely
    // (presumably because more data may still be waiting - confirm).
    haveOutput = (produced == size);
    QPID_LOG(trace, id << " Sasl::encode(" << size << "): " << produced);
    return produced;
}
/**
 * Report whether encode() has anything to do.
 *
 * @return true while the protocol header is unwritten or output is flagged
 *         as pending
 */
bool Sasl::canEncode()
{
    QPID_LOG(trace, id << " Sasl::canEncode(): " << writeHeader << " || " << haveOutput);
    if (writeHeader) {
        return true;
    }
    return haveOutput;
}
/**
 * Handle the peer's SASL-MECHANISMS frame and start the exchange.
 *
 * When the connection context names specific mechanisms, only those that the
 * peer also offered are passed on (space separated, in the context's order);
 * otherwise the offered list is used unchanged. The SASL implementation is
 * then started and init() is called with the chosen mechanism, the initial
 * response if start() reported one, and the host name if it is non-empty.
 * Any std::exception raised along the way is recorded through failed().
 *
 * @param offered space-separated mechanism names received from the peer
 */
void Sasl::mechanisms(const std::string& offered)
{
    QPID_LOG_CAT(debug, protocol, id << " Received SASL-MECHANISMS(" << offered << ")");

    std::string candidates;
    if (!context.mechanism.size()) {
        // No restriction configured: let the SASL layer pick from everything offered.
        candidates = offered;
    } else {
        typedef std::vector<std::string> Names;
        const Names allowed = split(context.mechanism, " ");
        const Names supported = split(offered, " ");
        for (Names::const_iterator name = allowed.begin(); name != allowed.end(); ++name) {
            const bool isOffered =
                std::find(supported.begin(), supported.end(), *name) != supported.end();
            if (isOffered) {
                candidates += *name;
                candidates += " ";
            }
        }
    }

    try {
        std::string initialResponse;
        const bool hasInitialResponse =
            sasl->start(candidates, initialResponse, context.getTransportSecuritySettings());
        // A null pointer tells init() that the respective field is absent.
        init(sasl->getMechanism(),
             hasInitialResponse ? &initialResponse : 0,
             hostname.size() ? &hostname : 0);
        haveOutput = true;
        context.activateOutput();
    } catch (const std::exception& e) {
        failed(e.what());
    }
}
/**
 * Handle a SASL-CHALLENGE frame that carries data.
 *
 * The challenge is fed to the SASL implementation and the resulting bytes are
 * queued via response(); output is then flagged as pending and the context is
 * asked to activate output. A std::exception from any of these steps is
 * recorded through failed().
 *
 * @param challenge challenge bytes received from the peer
 */
void Sasl::challenge(const std::string& challenge)
{
    QPID_LOG_CAT(debug, protocol, id << " Received SASL-CHALLENGE(" << challenge.size() << " bytes)");
    try {
        std::string reply(sasl->step(challenge));
        response(&reply);
        haveOutput = true;
        context.activateOutput();
    } catch (const std::exception& failure) {
        failed(failure.what());
    }
}
namespace {
// Empty string passed to sasl->step() when a SASL-CHALLENGE arrives without data.
const std::string EMPTY;
}
void Sasl::challenge()
{
QPID_LOG_CAT(debug, protocol, id << " Received SASL-CHALLENGE(null)");
try {
std::string r = sasl->step(EMPTY);
response(&r);
} catch (const std::exception& e) {
failed(e.what());
}
}
/**
 * Handle a SASL-OUTCOME frame that carries additional data.
 *
 * The extra data is only logged; processing is delegated to the
 * single-argument overload (which logs the result again).
 *
 * @param result outcome code from the peer
 * @param extra  additional data sent with the outcome
 */
void Sasl::outcome(uint8_t result, const std::string& extra)
{
    // uint8_t is a character type to iostreams; widen it so the numeric code
    // is logged rather than a raw (often unprintable) character.
    QPID_LOG_CAT(debug, protocol, id << " Received SASL-OUTCOME(" << static_cast<int>(result) << ", " << extra << ")");
    outcome(result);
}
/**
 * Handle a SASL-OUTCOME frame.
 *
 * A non-zero result moves the negotiation to FAILED, zero to SUCCEEDED. The
 * security layer is then requested from the SASL implementation and, if one
 * is returned, installed on the connection context. Finally the context is
 * asked to activate output.
 *
 * @param result outcome code from the peer; zero is treated as success
 */
void Sasl::outcome(uint8_t result)
{
    // uint8_t is a character type to iostreams; widen it so the numeric code
    // is logged rather than a raw (often unprintable) character.
    QPID_LOG_CAT(debug, protocol, id << " Received SASL-OUTCOME(" << static_cast<int>(result) << ")");
    state = result ? FAILED : SUCCEEDED;
    // NOTE(review): the security layer is requested even when the outcome is
    // a failure - confirm sasl->getSecurityLayer() is safe in that case.
    securityLayer = sasl->getSecurityLayer(context.maxFrameSize);
    if (securityLayer.get()) {
        context.initSecurityLayer(*securityLayer);
    }
    context.activateOutput();
}
/**
 * @return true once the negotiation has left the NONE state (i.e. it has
 *         either succeeded or failed), false while it is still in progress
 */
bool Sasl::stopReading()
{
    const bool stillNegotiating = (state == NONE);
    return !stillNegotiating;
}
/**
 * @return the security layer obtained in outcome(), or a null pointer when
 *         none is held; the pointer comes from securityLayer.get(), so this
 *         object keeps whatever ownership securityLayer has
 */
qpid::sys::Codec* Sasl::getSecurityLayer()
{
    qpid::sys::Codec* const layer = securityLayer.get();
    return layer;
}
namespace {
// Message used for AuthenticationFailure when no specific error text was recorded.
const std::string DEFAULT_ERROR("Authentication failed");
}
bool Sasl::authenticated()
{
switch (state) {
case SUCCEEDED: return true;
case FAILED: throw qpid::messaging::AuthenticationFailure(error.size() ? error : DEFAULT_ERROR);
case NONE: default: return false;
}
}
/**
 * Record an authentication failure.
 *
 * Logs the failure, moves the negotiation to FAILED and stores the text so
 * that authenticated() can report it.
 *
 * @param text description of what went wrong
 */
void Sasl::failed(const std::string& text)
{
    QPID_LOG_CAT(info, client, id << " Failure during authentication: " << text);
    state = FAILED;
    error = text;
}
/**
 * @return the user id reported by the underlying SASL implementation
 */
std::string Sasl::getAuthenticatedUsername()
{
    std::string userId(sasl->getUserId());
    return userId;
}
}}} // namespace qpid::messaging::amqp