blob: ff5cb98e56595f20ade1e93f9ea8b43b2916be96 [file]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <gtest/gtest.h>
#include <pulsar/Client.h>
#include <pulsar/ConsumerCryptoFailureAction.h>
#include <pulsar/MessageBatch.h>
#include <optional>
#include <stdexcept>
#include "lib/CompressionCodec.h"
#include "lib/MessageCrypto.h"
#include "lib/SharedBuffer.h"
// Broker service URL used by every test in this file; never modified, hence const.
static const std::string lookupUrl = "pulsar://localhost:6650";
using namespace pulsar;
// Builds a DefaultCryptoKeyReader over the "client-rsa" PEM key pair stored in the
// test configuration directory (TEST_CONF_DIR is supplied by the build).
static CryptoKeyReaderPtr getDefaultCryptoKeyReader() {
    const std::string publicKeyPath = TEST_CONF_DIR "/public-key.client-rsa.pem";
    const std::string privateKeyPath = TEST_CONF_DIR "/private-key.client-rsa.pem";
    return std::make_shared<DefaultCryptoKeyReader>(publicKeyPath, privateKeyPath);
}
static std::vector<std::string> decryptValue(const char* data, size_t length,
std::optional<const EncryptionContext*> context) {
if (!context.has_value()) {
return {std::string(data, length)};
}
if (!context.value()->isDecryptionFailed()) {
return {std::string(data, length)};
}
MessageCrypto crypto{"test", false};
SharedBuffer decryptedPayload;
auto originalPayload = SharedBuffer::copy(data, length);
if (!crypto.decrypt(*context.value(), originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) {
throw std::runtime_error("Decryption failed");
}
SharedBuffer uncompressedPayload;
if (!CompressionCodecProvider::getCodec(context.value()->compressionType())
.decode(decryptedPayload, context.value()->uncompressedMessageSize(), uncompressedPayload)) {
throw std::runtime_error("Decompression failed");
}
std::vector<std::string> values;
if (auto batchSize = context.value()->batchSize(); batchSize > 0) {
MessageBatch batch;
for (auto&& msg : batch.parseFrom(uncompressedPayload, batchSize).messages()) {
values.emplace_back(msg.getDataAsString());
}
} else {
// non-batched message
values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes());
}
return values;
}
/**
 * Publishes encrypted and unencrypted messages to `topic`, consumes them back, and checks
 * that the recovered values match what was sent.
 *
 * Messages published, in order:
 * - "msg-0".."msg-4": sent asynchronously and flushed together, so they are expected to
 *   form one batch.
 * - "last-msg": sent by the same LZ4-compressed, encrypting producer.
 * - "unencrypted-msg": sent by a second producer with the default configuration.
 *
 * Each received payload goes through decryptValue(). That helper decrypts payloads the
 * consumer could not decrypt itself, so the values can be compared with what was sent.
 *
 * @param client             client used to create the producers and the consumer
 * @param topic              topic to publish to and consume from
 * @param withDecryption     true: the consumer gets the crypto key reader.
 *                           false: it uses ConsumerCryptoFailureAction::CONSUME.
 * @param numMessageReceived number of messages expected. Every message except the last
 *                           must carry an encryption context; the last must not.
 */
static void testDecryption(Client& client, const std::string& topic, bool withDecryption,
                           int numMessageReceived) {
    ProducerConfiguration producerConf;
    producerConf.setCompressionType(CompressionLZ4);
    producerConf.addEncryptionKey("client-rsa.pem");
    producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());
    Producer producer;
    ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer));

    std::vector<std::string> sentValues;
    // `producer` is captured by reference so the lambda also uses the producer assigned below.
    auto send = [&producer, &sentValues](const std::string& value) {
        Message msg = MessageBuilder().setContent(value).build();
        producer.sendAsync(msg, nullptr);
        sentValues.emplace_back(value);
    };

    for (int i = 0; i < 5; i++) {
        send("msg-" + std::to_string(i));
    }
    ASSERT_EQ(ResultOk, producer.flush());
    send("last-msg");
    ASSERT_EQ(ResultOk, producer.flush());
    // Close the encrypting producer explicitly before the handle is reassigned;
    // otherwise it would be abandoned without ever being closed.
    ASSERT_EQ(ResultOk, producer.close());

    // Reuse the handle for a producer without encryption.
    ASSERT_EQ(ResultOk, client.createProducer(topic, producer));
    send("unencrypted-msg");
    ASSERT_EQ(ResultOk, producer.flush());
    ASSERT_EQ(ResultOk, producer.close());

    ConsumerConfiguration consumerConf;
    consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest);
    if (withDecryption) {
        consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader());
    } else {
        // No key reader: CONSUME lets messages that cannot be decrypted still be received.
        consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME);
    }
    Consumer consumer;
    ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer));

    std::vector<std::string> values;
    for (int i = 0; i < numMessageReceived; i++) {
        Message msg;
        ASSERT_EQ(ResultOk, consumer.receive(msg, 3000));
        const auto context = msg.getEncryptionContext();
        // Only the final message ("unencrypted-msg") has no encryption context.
        if (i < numMessageReceived - 1) {
            ASSERT_TRUE(context.has_value());
        } else {
            ASSERT_FALSE(context.has_value());
        }
        for (auto&& value : decryptValue(static_cast<const char*>(msg.getData()), msg.getLength(), context)) {
            values.emplace_back(value);
        }
    }
    ASSERT_EQ(values, sentValues);
    ASSERT_EQ(ResultOk, consumer.close());
}
TEST(EncryptionTests, testDecryptionSuccess) {
    const std::string topic = "test-decryption-success-" + std::to_string(time(nullptr));
    Client client{lookupUrl};
    // The consumer has the key reader, so the 5-message batch is unpacked:
    // 5 batched + "last-msg" + "unencrypted-msg" = 7 messages.
    testDecryption(client, topic, true, 7);
    client.close();
}
TEST(EncryptionTests, testDecryptionFailure) {
    const std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr));
    Client client{lookupUrl};
    // The first batch (5 messages) cannot be decrypted by the consumer, so it is delivered
    // as a single message: that batch + "last-msg" + "unencrypted-msg" = 3 messages.
    testDecryption(client, topic, false, 3);
    client.close();
}