/**
 *
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <chrono>
#include <filesystem>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "RunLlamaCppInference.h"
#include "api/core/Resource.h"
#include "minifi-cpp/core/FlowFile.h"
#include "unit/Catch.h"
#include "unit/SingleProcessorTestController.h"
#include "unit/TestBase.h"
#include "utils/CProcessor.h"
#include "unit/TestUtils.h"
#include "CProcessorTestUtils.h"
namespace org::apache::nifi::minifi::extensions::llamacpp::test {
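
// LlamaContext test double: records the chat messages passed to applyTemplate() and the input passed to
// generate(), streams a fixed "Test generated content" response with fixed metrics, and can be switched
// to fail either step.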
class MockLlamaContext : public processors::LlamaContext {
 public:
  std::optional<std::string> applyTemplate(const std::vector<processors::LlamaChatMessage>& messages) override {
    if (fail_apply_template_) {
      return std::nullopt;
    }
    messages_ = messages;
    return "Test input";
  }

  nonstd::expected<processors::GenerationResult, std::string> generate(const std::string& input, std::function<void(std::string_view/*token*/)> token_handler) override {
    if (fail_generation_) {
      return nonstd::make_unexpected("Generation failed");
    }
    processors::GenerationResult result;
    input_ = input;
    token_handler("Test ");
    token_handler("generated");
    token_handler(" content");
    result.time_to_first_token = std::chrono::milliseconds(100);
    result.num_tokens_in = 10;
    result.num_tokens_out = 3;
    result.tokens_per_second = 2.0;
    return result;
  }

  [[nodiscard]] const std::vector<processors::LlamaChatMessage>& getMessages() const {
    return messages_;
  }

  [[nodiscard]] const std::string& getInput() const {
    return input_;
  }

  void setGenerationFailure() {
    fail_generation_ = true;
  }

  void setApplyTemplateFailure() {
    fail_apply_template_ = true;
  }

 private:
  bool fail_generation_{false};
  bool fail_apply_template_{false};
  std::vector<processors::LlamaChatMessage> messages_;
  std::string input_;
};
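
// Each test constructs the processor through make_custom_c_processor and injects the mock via the
// context-factory callback, which also captures the model path, sampler and context parameters that
// the processor resolves from its properties.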
TEST_CASE("Prompt is generated correctly with default parameters") {
  auto mock_llama_context = std::make_unique<MockLlamaContext>();
  auto mock_llama_context_ptr = mock_llama_context.get();
  std::filesystem::path test_model_path;
  processors::LlamaSamplerParams test_sampler_params;
  processors::LlamaContextParams test_context_params;
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path& model_path, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams& context_params) {
        test_model_path = model_path;
        test_sampler_params = sampler_params;
        test_context_params = context_params;
        return std::move(mock_llama_context);
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
  CHECK(test_model_path == "Dummy model");
  CHECK(test_sampler_params.temperature == 0.8F);
  CHECK(test_sampler_params.top_k == 40);
  CHECK(test_sampler_params.top_p == 0.9F);
  CHECK(test_sampler_params.min_p == std::nullopt);
  CHECK(test_sampler_params.min_keep == 0);
  CHECK(test_context_params.n_ctx == 4096);
  CHECK(test_context_params.n_batch == 2048);
  CHECK(test_context_params.n_ubatch == 512);
  CHECK(test_context_params.n_seq_max == 1);
  CHECK(test_context_params.n_threads == 4);
  CHECK(test_context_params.n_threads_batch == 4);
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
  auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0];
  CHECK(*output_flow_file->getAttribute(processors::RunLlamaCppInference::LlamaCppTimeToFirstToken.name) == "100 ms");
  CHECK(*output_flow_file->getAttribute(processors::RunLlamaCppInference::LlamaCppTokensPerSecond.name) == "2.00");
  CHECK(controller.plan->getContent(output_flow_file) == "Test generated content");
  CHECK(mock_llama_context_ptr->getInput() == "Test input");
  REQUIRE(mock_llama_context_ptr->getMessages().size() == 2);
  CHECK(mock_llama_context_ptr->getMessages()[0].role == "system");
  CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful assistant. You are given a question with some possible input data otherwise called flow file content. "
      "You are expected to generate a response based on the question and the input data.");
  CHECK(mock_llama_context_ptr->getMessages()[1].role == "user");
  CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?");
}

TEST_CASE("Prompt is generated correctly with custom parameters") {
  auto mock_llama_context = std::make_unique<MockLlamaContext>();
  auto mock_llama_context_ptr = mock_llama_context.get();
  std::filesystem::path test_model_path;
  processors::LlamaSamplerParams test_sampler_params;
  processors::LlamaContextParams test_context_params;
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path& model_path, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams& context_params) {
        test_model_path = model_path;
        test_sampler_params = sampler_params;
        test_context_params = context_params;
        return std::move(mock_llama_context);
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature.name, "0.4"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, "20"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP.name, ""));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP.name, "0.1"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinKeep.name, "1"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TextContextSize.name, "4096"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::LogicalMaximumBatchSize.name, "1024"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::PhysicalMaximumBatchSize.name, "796"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MaxNumberOfSequences.name, "2"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForGeneration.name, "12"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ThreadsForBatchProcessing.name, "8"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt.name, "Whatever"));
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
  CHECK(test_model_path == "/path/to/model");
  CHECK(test_sampler_params.temperature == 0.4F);
  CHECK(test_sampler_params.top_k == 20);
  CHECK(test_sampler_params.top_p == std::nullopt);
  CHECK(test_sampler_params.min_p == 0.1F);
  CHECK(test_sampler_params.min_keep == 1);
  CHECK(test_context_params.n_ctx == 4096);
  CHECK(test_context_params.n_batch == 1024);
  CHECK(test_context_params.n_ubatch == 796);
  CHECK(test_context_params.n_seq_max == 2);
  CHECK(test_context_params.n_threads == 12);
  CHECK(test_context_params.n_threads_batch == 8);
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
  auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0];
  CHECK(controller.plan->getContent(output_flow_file) == "Test generated content");
  CHECK(mock_llama_context_ptr->getInput() == "Test input");
  REQUIRE(mock_llama_context_ptr->getMessages().size() == 2);
  CHECK(mock_llama_context_ptr->getMessages()[0].role == "system");
  CHECK(mock_llama_context_ptr->getMessages()[0].content == "Whatever");
  CHECK(mock_llama_context_ptr->getMessages()[1].role == "user");
  CHECK(mock_llama_context_ptr->getMessages()[1].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?");
}

TEST_CASE("Empty flow file does not include input data in prompt") {
  auto mock_llama_context = std::make_unique<MockLlamaContext>();
  auto mock_llama_context_ptr = mock_llama_context.get();
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) {
        return std::move(mock_llama_context);
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "", .attributes = {}});
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
  auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0];
  CHECK(controller.plan->getContent(output_flow_file) == "Test generated content");
  CHECK(mock_llama_context_ptr->getInput() == "Test input");
  REQUIRE(mock_llama_context_ptr->getMessages().size() == 2);
  CHECK(mock_llama_context_ptr->getMessages()[0].role == "system");
  CHECK(mock_llama_context_ptr->getMessages()[0].content == "You are a helpful assistant. You are given a question with some possible input data otherwise called flow file content. "
      "You are expected to generate a response based on the question and the input data.");
  CHECK(mock_llama_context_ptr->getMessages()[1].role == "user");
  CHECK(mock_llama_context_ptr->getMessages()[1].content == "Question: What is the answer to life, the universe and everything?");
}

TEST_CASE("Invalid values for optional double type properties throw exception") {
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) {
        return std::make_unique<MockLlamaContext>();
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  std::string property_name;
  SECTION("Invalid value for Temperature property") {
    REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Temperature.name, "invalid_value"));
    property_name = processors::RunLlamaCppInference::Temperature.name;
  }
  SECTION("Invalid value for Top P property") {
    REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopP.name, "invalid_value"));
    property_name = processors::RunLlamaCppInference::TopP.name;
  }
  SECTION("Invalid value for Min P property") {
    REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::MinP.name, "invalid_value"));
    property_name = processors::RunLlamaCppInference::MinP.name;
  }
  REQUIRE_THROWS(controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}));
  CHECK(minifi::test::utils::verifyLogLinePresenceInPollTime(1s,
      fmt::format("Expected parsable float from RunLlamaCppInference::{}, but got GeneralParsingError (Parsing Error:0)", property_name)));
}

TEST_CASE("Top K property empty and invalid values are handled properly") {
  std::optional<int32_t> test_top_k = 0;
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams& sampler_params, const processors::LlamaContextParams&) {
        test_top_k = sampler_params.top_k;
        return std::make_unique<MockLlamaContext>();
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  SECTION("Empty value for Top K property") {
    REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, ""));
    auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
    REQUIRE(test_top_k == std::nullopt);
  }
  SECTION("Invalid value for Top K property") {
    REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::TopK.name, "invalid_value"));
    REQUIRE_THROWS(controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}}));
    CHECK(minifi::test::utils::verifyLogLinePresenceInPollTime(1s,
        "Expected parsable int64_t from \"RunLlamaCppInference::Top K\", but got GeneralParsingError (Parsing Error:0)"));
  }
}

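// Both failure modes (template application and generation) should route the incoming flow file to
// Failure with its original content unchanged.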
TEST_CASE("Error handling during generation and applying template") {
  auto mock_llama_context = std::make_unique<MockLlamaContext>();
  SECTION("Generation fails with error") {
    mock_llama_context->setGenerationFailure();
  }
  SECTION("Applying template fails with error") {
    mock_llama_context->setApplyTemplateFailure();
  }
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) {
        return std::move(mock_llama_context);
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).empty());
  REQUIRE(results.at(processors::RunLlamaCppInference::Failure).size() == 1);
  auto& output_flow_file = results.at(processors::RunLlamaCppInference::Failure)[0];
  CHECK(controller.plan->getContent(output_flow_file) == "42");
}

TEST_CASE("Route flow file to failure when prompt and input data is empty") {
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) {
        return std::make_unique<MockLlamaContext>();
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "/path/to/model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, ""));
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "", .attributes = {}});
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).empty());
  REQUIRE(results.at(processors::RunLlamaCppInference::Failure).size() == 1);
  auto& output_flow_file = results.at(processors::RunLlamaCppInference::Failure)[0];
  CHECK(controller.plan->getContent(output_flow_file).empty());
}

TEST_CASE("System prompt is optional") {
  auto mock_llama_context = std::make_unique<MockLlamaContext>();
  auto mock_llama_context_ptr = mock_llama_context.get();
  minifi::test::SingleProcessorTestController controller(minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) {
        return std::move(mock_llama_context);
      }));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::SystemPrompt.name, ""));
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
  auto& output_flow_file = results.at(processors::RunLlamaCppInference::Success)[0];
  CHECK(controller.plan->getContent(output_flow_file) == "Test generated content");
  CHECK(mock_llama_context_ptr->getInput() == "Test input");
  REQUIRE(mock_llama_context_ptr->getMessages().size() == 1);
  CHECK(mock_llama_context_ptr->getMessages()[0].role == "user");
  CHECK(mock_llama_context_ptr->getMessages()[0].content == "Input data (or flow file content):\n42\n\nQuestion: What is the answer to life, the universe and everything?");
}

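// The processor is triggered twice, so the token counters accumulate the mock's fixed per-run values:
// 2 x 10 = 20 tokens in and 2 x 3 = 6 tokens out.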
TEST_CASE("Test output metrics") {
  auto processor = minifi::test::utils::make_custom_c_processor<processors::RunLlamaCppInference>(
      core::ProcessorMetadata{utils::Identifier{}, "RunLlamaCppInference", logging::LoggerFactory<processors::RunLlamaCppInference>::getLogger()},
      [&](const std::filesystem::path&, const processors::LlamaSamplerParams&, const processors::LlamaContextParams&) {
        return std::make_unique<MockLlamaContext>();
      });
  auto processor_metrics = processor->getMetrics();
  minifi::test::SingleProcessorTestController controller(std::move(processor));
  LogTestController::getInstance().setTrace<processors::RunLlamaCppInference>();
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::ModelPath.name, "Dummy model"));
  REQUIRE(controller.getProcessor()->setProperty(processors::RunLlamaCppInference::Prompt.name, "Question: What is the answer to life, the universe and everything?"));
  controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
  auto results = controller.trigger(minifi::test::InputFlowFileData{.content = "42", .attributes = {}});
  REQUIRE(results.at(processors::RunLlamaCppInference::Success).size() == 1);
  auto prometheus_metrics = processor_metrics->calculateMetrics();
  REQUIRE(prometheus_metrics.size() >= 2);
  CHECK(prometheus_metrics[prometheus_metrics.size() - 2].name == "tokens_in");
  CHECK(prometheus_metrics[prometheus_metrics.size() - 2].value == 20);
  CHECK(prometheus_metrics[prometheus_metrics.size() - 1].name == "tokens_out");
  CHECK(prometheus_metrics[prometheus_metrics.size() - 1].value == 6);
  auto c2_metrics = processor_metrics->serialize();
  REQUIRE_FALSE(c2_metrics.empty());
  REQUIRE(c2_metrics[0].children.size() >= 2);
  CHECK(c2_metrics[0].children[c2_metrics[0].children.size() - 2].name == "TokensIn");
  CHECK(c2_metrics[0].children[c2_metrics[0].children.size() - 2].value.to_string() == "20");
  CHECK(c2_metrics[0].children[c2_metrics[0].children.size() - 1].name == "TokensOut");
  CHECK(c2_metrics[0].children[c2_metrics[0].children.size() - 1].value.to_string() == "6");
}

} // namespace org::apache::nifi::minifi::extensions::llamacpp::test