/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RunLlamaCppInference.h"
#include "api/core/ProcessContext.h"
#include "api/core/ProcessSession.h"
#include "api/core/Resource.h"
#include "rapidjson/document.h"
#include "rapidjson/error/en.h"
#include "LlamaContext.h"
#include "api/utils/ProcessorConfigUtils.h"
#include "DefaultLlamaContext.h"
namespace org::apache::nifi::minifi::extensions::llamacpp::processors {
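
// Reads the processor properties once per schedule: the model path and optional system prompt,
// the sampler settings (temperature, top_k, top_p, min_p, min_keep) and the llama.cpp context
// settings (context size, batch sizes, maximum sequence count, thread counts), then builds the
// LlamaContext that onTrigger uses for inference.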
MinifiStatus RunLlamaCppInference::onScheduleImpl(api::core::ProcessContext& context) {
  model_path_.clear();
  model_path_ = api::utils::parseProperty(context, ModelPath);
  system_prompt_ = context.getProperty(SystemPrompt).value_or("");

  LlamaSamplerParams llama_sampler_params;
  llama_sampler_params.temperature = api::utils::parseOptionalFloatProperty(context, Temperature);
  if (auto top_k = api::utils::parseOptionalI64Property(context, TopK)) {
    llama_sampler_params.top_k = gsl::narrow<int32_t>(*top_k);
  }
  llama_sampler_params.top_p = api::utils::parseOptionalFloatProperty(context, TopP);
  llama_sampler_params.min_p = api::utils::parseOptionalFloatProperty(context, MinP);
  llama_sampler_params.min_keep = api::utils::parseU64Property(context, MinKeep);

  LlamaContextParams llama_ctx_params;
  llama_ctx_params.n_ctx = gsl::narrow<uint32_t>(api::utils::parseU64Property(context, TextContextSize));
  llama_ctx_params.n_batch = gsl::narrow<uint32_t>(api::utils::parseU64Property(context, LogicalMaximumBatchSize));
  llama_ctx_params.n_ubatch = gsl::narrow<uint32_t>(api::utils::parseU64Property(context, PhysicalMaximumBatchSize));
  llama_ctx_params.n_seq_max = gsl::narrow<uint32_t>(api::utils::parseU64Property(context, MaxNumberOfSequences));
  llama_ctx_params.n_threads = gsl::narrow<int32_t>(api::utils::parseI64Property(context, ThreadsForGeneration));
  llama_ctx_params.n_threads_batch = gsl::narrow<int32_t>(api::utils::parseI64Property(context, ThreadsForBatchProcessing));

  if (llama_context_provider_) {
    llama_ctx_ = llama_context_provider_(model_path_, llama_sampler_params, llama_ctx_params);
  } else {
    llama_ctx_ = std::make_unique<DefaultLlamaContext>(model_path_, llama_sampler_params, llama_ctx_params);
  }

  return MINIFI_STATUS_SUCCESS;
}
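
// The provider branch above lets tests (or alternative backends) supply their own LlamaContext
// instead of the DefaultLlamaContext built from the model file. A minimal sketch of such an
// injection, assuming llama_context_provider_ is a std::function returning
// std::unique_ptr<LlamaContext>, and that a setter and a StubLlamaContext test double exist
// (all three are assumptions for illustration, not part of this file):
//
//   processor.setLlamaContextProvider(
//       [](const std::string& /*model_path*/, const LlamaSamplerParams&, const LlamaContextParams&) {
//         return std::make_unique<StubLlamaContext>();  // hypothetical test double
//       });

// The two counters below accumulate per-trigger token usage into the processor's metrics.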
void RunLlamaCppInference::increaseTokensIn(uint64_t token_count) {
  metrics_.tokens_in += token_count;
}

void RunLlamaCppInference::increaseTokensOut(uint64_t token_count) {
  metrics_.tokens_out += token_count;
}
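
// Runs one inference per incoming flow file: the flow file content (if any) and the Prompt
// property are combined into a single user message, the configured system prompt is prepended,
// the loaded model's chat template is applied, and the generated text replaces the flow file
// content. Timing and token statistics are logged and attached as attributes; template or
// generation failures route the flow file to the Failure relationship.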
MinifiStatus RunLlamaCppInference::onTriggerImpl(api::core::ProcessContext& context, api::core::ProcessSession& session) {
  auto flow_file = session.get();
  if (!flow_file) {
    return MINIFI_STATUS_PROCESSOR_YIELD;
  }

  auto prompt = context.getProperty(Prompt, &flow_file).value_or("");

  auto read_result = session.readBuffer(flow_file);
  std::string input_data_and_prompt;
  if (!read_result.empty()) {
    input_data_and_prompt.append("Input data (or flow file content):\n");
    input_data_and_prompt.append({reinterpret_cast<const char*>(read_result.data()), read_result.size()});
    input_data_and_prompt.append("\n\n");
  }
  input_data_and_prompt.append(prompt);

  if (input_data_and_prompt.empty()) {
    logger_->log_error("Input data and prompt are empty");
    session.transfer(std::move(flow_file), Failure);
    return MINIFI_STATUS_SUCCESS;
  }

  auto input = [&] {
    std::vector<LlamaChatMessage> messages;
    if (!system_prompt_.empty()) {
      messages.push_back({.role = "system", .content = system_prompt_});
    }
    messages.push_back({.role = "user", .content = input_data_and_prompt});
    return llama_ctx_->applyTemplate(messages);
  }();
  if (!input) {
    logger_->log_error("Inference failed while applying template");
    session.transfer(std::move(flow_file), Failure);
    return MINIFI_STATUS_SUCCESS;
  }
logger_->log_debug("AI model input: {}", *input);
auto start_time = std::chrono::steady_clock::now();
std::string text;
auto generation_result = llama_ctx_->generate(*input, [&] (std::string_view token) {
text += token;
});
auto elapsed_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - start_time).count();
if (!generation_result) {
logger_->log_error("Inference failed with generation error: '{}'", generation_result.error());
session.transfer(std::move(flow_file), Failure);
return MINIFI_STATUS_SUCCESS;
}
increaseTokensIn(generation_result->num_tokens_in);
increaseTokensOut(generation_result->num_tokens_out);
logger_->log_debug("Number of tokens generated: {}", generation_result->num_tokens_out);
logger_->log_debug("AI model inference time: {} ms", elapsed_time);
logger_->log_debug("AI model output: {}", text);
session.setAttribute(flow_file, LlamaCppTimeToFirstToken.name, std::to_string(generation_result->time_to_first_token.count()) + " ms");
session.setAttribute(flow_file, LlamaCppTokensPerSecond.name, fmt::format("{:.2f}", generation_result->tokens_per_second));
session.writeBuffer(flow_file, text);
session.transfer(std::move(flow_file), Success);
return MINIFI_STATUS_SUCCESS;
}
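
// Releases the LlamaContext (and the model it holds) when the processor is unscheduled.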
void RunLlamaCppInference::onUnSchedule() {
  llama_ctx_.reset();
}

} // namespace org::apache::nifi::minifi::extensions::llamacpp::processors