| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| #include <gtest/gtest.h> |
| #include <pulsar/Client.h> |
| #include <pulsar/ConsumerCryptoFailureAction.h> |
| #include <pulsar/MessageBatch.h> |
| |
| #include <optional> |
| #include <stdexcept> |
| |
| #include "lib/CompressionCodec.h" |
| #include "lib/MessageCrypto.h" |
| #include "lib/SharedBuffer.h" |
| |
// Service URL of the broker that every test in this file connects to.
static std::string lookupUrl{"pulsar://localhost:6650"};
| |
| using namespace pulsar; |
| |
| static CryptoKeyReaderPtr getDefaultCryptoKeyReader() { |
| return std::make_shared<DefaultCryptoKeyReader>(TEST_CONF_DIR "/public-key.client-rsa.pem", |
| TEST_CONF_DIR "/private-key.client-rsa.pem"); |
| } |
| |
| static std::vector<std::string> decryptValue(const char* data, size_t length, |
| std::optional<const EncryptionContext*> context) { |
| if (!context.has_value()) { |
| return {std::string(data, length)}; |
| } |
| if (!context.value()->isDecryptionFailed()) { |
| return {std::string(data, length)}; |
| } |
| |
| MessageCrypto crypto{"test", false}; |
| SharedBuffer decryptedPayload; |
| auto originalPayload = SharedBuffer::copy(data, length); |
| if (!crypto.decrypt(*context.value(), originalPayload, getDefaultCryptoKeyReader(), decryptedPayload)) { |
| throw std::runtime_error("Decryption failed"); |
| } |
| |
| SharedBuffer uncompressedPayload; |
| if (!CompressionCodecProvider::getCodec(context.value()->compressionType()) |
| .decode(decryptedPayload, context.value()->uncompressedMessageSize(), uncompressedPayload)) { |
| throw std::runtime_error("Decompression failed"); |
| } |
| |
| std::vector<std::string> values; |
| if (auto batchSize = context.value()->batchSize(); batchSize > 0) { |
| MessageBatch batch; |
| for (auto&& msg : batch.parseFrom(uncompressedPayload, batchSize).messages()) { |
| values.emplace_back(msg.getDataAsString()); |
| } |
| } else { |
| // non-batched message |
| values.emplace_back(uncompressedPayload.data(), uncompressedPayload.readableBytes()); |
| } |
| return values; |
| } |
| |
| static void testDecryption(Client& client, const std::string& topic, bool withDecryption, |
| int numMessageReceived) { |
| ProducerConfiguration producerConf; |
| producerConf.setCompressionType(CompressionLZ4); |
| producerConf.addEncryptionKey("client-rsa.pem"); |
| producerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); |
| |
| Producer producer; |
| ASSERT_EQ(ResultOk, client.createProducer(topic, producerConf, producer)); |
| |
| std::vector<std::string> sentValues; |
| auto send = [&producer, &sentValues](const std::string& value) { |
| Message msg = MessageBuilder().setContent(value).build(); |
| producer.sendAsync(msg, nullptr); |
| sentValues.emplace_back(value); |
| }; |
| |
| for (int i = 0; i < 5; i++) { |
| send("msg-" + std::to_string(i)); |
| } |
| producer.flush(); |
| send("last-msg"); |
| producer.flush(); |
| |
| ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); |
| send("unencrypted-msg"); |
| producer.flush(); |
| producer.close(); |
| |
| ConsumerConfiguration consumerConf; |
| consumerConf.setSubscriptionInitialPosition(InitialPositionEarliest); |
| if (withDecryption) { |
| consumerConf.setCryptoKeyReader(getDefaultCryptoKeyReader()); |
| } else { |
| consumerConf.setCryptoFailureAction(ConsumerCryptoFailureAction::CONSUME); |
| } |
| Consumer consumer; |
| ASSERT_EQ(ResultOk, client.subscribe(topic, "sub", consumerConf, consumer)); |
| |
| std::vector<std::string> values; |
| for (int i = 0; i < numMessageReceived; i++) { |
| Message msg; |
| ASSERT_EQ(ResultOk, consumer.receive(msg, 3000)); |
| if (i < numMessageReceived - 1) { |
| ASSERT_TRUE(msg.getEncryptionContext().has_value()); |
| } else { |
| ASSERT_FALSE(msg.getEncryptionContext().has_value()); |
| } |
| for (auto&& value : decryptValue(static_cast<const char*>(msg.getData()), msg.getLength(), |
| msg.getEncryptionContext())) { |
| values.emplace_back(value); |
| } |
| } |
| ASSERT_EQ(values, sentValues); |
| consumer.close(); |
| } |
| |
| TEST(EncryptionTests, testDecryptionSuccess) { |
| Client client{lookupUrl}; |
| std::string topic = "test-decryption-success-" + std::to_string(time(nullptr)); |
| testDecryption(client, topic, true, 7); |
| client.close(); |
| } |
| |
| TEST(EncryptionTests, testDecryptionFailure) { |
| Client client{lookupUrl}; |
| std::string topic = "test-decryption-failure-" + std::to_string(time(nullptr)); |
| // The 1st batch that has 5 messages cannot be decrypted, so they can be received only once |
| testDecryption(client, topic, false, 3); |
| client.close(); |
| } |