| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "vec/columns/column_array.h" |
| #include "vec/functions/ai/ai_classify.h" |
| #include "vec/functions/ai/ai_extract.h" |
| #include "vec/functions/ai/ai_filter.h" |
| #include "vec/functions/ai/ai_fix_grammar.h" |
| #include "vec/functions/ai/ai_generate.h" |
| #include "vec/functions/ai/ai_mask.h" |
| #include "vec/functions/ai/ai_sentiment.h" |
| #include "vec/functions/ai/ai_similarity.h" |
| #include "vec/functions/ai/ai_summarize.h" |
| #include "vec/functions/ai/ai_translate.h" |
| #include "vec/functions/ai/embed.h" |
| #include "vec/functions/simple_function_factory.h" |
| |
| namespace doris::vectorized { |
| Status FunctionAIClassify::build_prompt(const Block& block, const ColumnNumbers& arguments, |
| size_t row_num, std::string& prompt) const { |
| // Get the text column |
| const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
| StringRef text = text_column.column->get_data_at(row_num); |
| std::string text_str = std::string(text.data, text.size); |
| |
| // Get the labels array column |
| const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
| const auto& [array_column, array_row_num] = |
| check_column_const_set_readability(*labels_column.column, row_num); |
| const auto* col_array = check_and_get_column<ColumnArray>(*array_column); |
| if (col_array == nullptr) { |
| return Status::InternalError( |
| "labels argument for {} must be Array(String) or Array(Varchar)", name); |
| } |
| |
| std::vector<std::string> label_values; |
| const auto& data = col_array->get_data(); |
| const auto& offsets = col_array->get_offsets(); |
| size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
| size_t end = offsets[array_row_num]; |
| for (size_t i = start; i < end; ++i) { |
| Field field; |
| data.get(i, field); |
| label_values.emplace_back(field.template get<TYPE_STRING>()); |
| } |
| |
| std::string labels_str = "["; |
| for (size_t i = 0; i < label_values.size(); ++i) { |
| if (i > 0) { |
| labels_str += ", "; |
| } |
| labels_str += "\"" + label_values[i] + "\""; |
| } |
| labels_str += "]"; |
| |
| prompt = "Labels: " + labels_str + "\nText: " + text_str; |
| |
| return Status::OK(); |
| } |
| |
| Status FunctionAIExtract::build_prompt(const Block& block, const ColumnNumbers& arguments, |
| size_t row_num, std::string& prompt) const { |
| // Get the text column |
| const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
| StringRef text = text_column.column->get_data_at(row_num); |
| std::string text_str = std::string(text.data, text.size); |
| |
| // Get the labels array column |
| const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
| const auto& [array_column, array_row_num] = |
| check_column_const_set_readability(*labels_column.column, row_num); |
| const auto* col_array = check_and_get_column<ColumnArray>(*array_column); |
| if (col_array == nullptr) { |
| return Status::InternalError( |
| "labels argument for {} must be Array(String) or Array(Varchar)", name); |
| } |
| |
| std::vector<std::string> label_values; |
| const auto& offsets = col_array->get_offsets(); |
| const auto& data = col_array->get_data(); |
| size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
| size_t end = offsets[array_row_num]; |
| for (size_t i = start; i < end; ++i) { |
| Field field; |
| data.get(i, field); |
| label_values.emplace_back(field.template get<TYPE_STRING>()); |
| } |
| |
| std::string labels_str = "["; |
| for (size_t i = 0; i < label_values.size(); ++i) { |
| if (i > 0) { |
| labels_str += ", "; |
| } |
| labels_str += "\"" + label_values[i] + "\""; |
| } |
| labels_str += "]"; |
| |
| prompt = "Labels: " + labels_str + "\nText: " + text_str; |
| |
| return Status::OK(); |
| } |
| |
| Status FunctionAIGenerate::build_prompt(const Block& block, const ColumnNumbers& arguments, |
| size_t row_num, std::string& prompt) const { |
| const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
| StringRef text_ref = text_column.column->get_data_at(row_num); |
| prompt = std::string(text_ref.data, text_ref.size); |
| |
| return Status::OK(); |
| } |
| |
| Status FunctionAIMask::build_prompt(const Block& block, const ColumnNumbers& arguments, |
| size_t row_num, std::string& prompt) const { |
| // Get the text column |
| const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
| StringRef text = text_column.column->get_data_at(row_num); |
| std::string text_str = std::string(text.data, text.size); |
| |
| // Get the labels array column |
| const ColumnWithTypeAndName& labels_column = block.get_by_position(arguments[2]); |
| const auto& [array_column, array_row_num] = |
| check_column_const_set_readability(*labels_column.column, row_num); |
| const auto* col_array = check_and_get_column<ColumnArray>(*array_column); |
| if (col_array == nullptr) { |
| return Status::InternalError( |
| "labels argument for {} must be Array(String) or Array(Varchar)", name); |
| } |
| |
| std::vector<std::string> label_values; |
| const auto& offsets = col_array->get_offsets(); |
| const auto& data = col_array->get_data(); |
| size_t start = array_row_num > 0 ? offsets[array_row_num - 1] : 0; |
| size_t end = offsets[array_row_num]; |
| for (size_t i = start; i < end; ++i) { |
| Field field; |
| data.get(i, field); |
| label_values.emplace_back(field.template get<TYPE_STRING>()); |
| } |
| |
| std::string labels_str = "["; |
| for (size_t i = 0; i < label_values.size(); ++i) { |
| if (i > 0) { |
| labels_str += ", "; |
| } |
| labels_str += "\"" + label_values[i] + "\""; |
| } |
| labels_str += "]"; |
| |
| prompt = "Labels: " + labels_str + "\nText: " + text_str; |
| |
| return Status::OK(); |
| } |
| |
| Status FunctionAISimilarity::build_prompt(const Block& block, const ColumnNumbers& arguments, |
| size_t row_num, std::string& prompt) const { |
| // text1 |
| const ColumnWithTypeAndName& text_column_1 = block.get_by_position(arguments[1]); |
| StringRef text_1 = text_column_1.column.get()->get_data_at(row_num); |
| std::string text_str_1 = std::string(text_1.data, text_1.size); |
| |
| // text2 |
| const ColumnWithTypeAndName& text_column_2 = block.get_by_position(arguments[2]); |
| StringRef text_2 = text_column_2.column.get()->get_data_at(row_num); |
| std::string text_str_2 = std::string(text_2.data, text_2.size); |
| |
| prompt = "Text 1: " + text_str_1 + "\nText 2: " + text_str_2; |
| |
| return Status::OK(); |
| } |
| |
| Status FunctionAITranslate::build_prompt(const Block& block, const ColumnNumbers& arguments, |
| size_t row_num, std::string& prompt) const { |
| // text |
| const ColumnWithTypeAndName& text_column = block.get_by_position(arguments[1]); |
| StringRef text = text_column.column.get()->get_data_at(row_num); |
| std::string text_str = std::string(text.data, text.size); |
| |
| // target language |
| const ColumnWithTypeAndName& lang_column = block.get_by_position(arguments[2]); |
| StringRef lang = lang_column.column.get()->get_data_at(row_num); |
| std::string target_lang = std::string(lang.data, lang.size); |
| |
| prompt = "Translate the following text to " + target_lang + ".\nText: " + text_str; |
| |
| return Status::OK(); |
| } |
| |
| void register_function_ai(SimpleFunctionFactory& factory) { |
| factory.register_function<FunctionEmbed>(); |
| factory.register_function<FunctionAIClassify>(); |
| factory.register_function<FunctionAIExtract>(); |
| factory.register_function<FunctionAIFilter>(); |
| factory.register_function<FunctionAIFixGrammar>(); |
| factory.register_function<FunctionAIGenerate>(); |
| factory.register_function<FunctionAIMask>(); |
| factory.register_function<FunctionAISentiment>(); |
| factory.register_function<FunctionAISimilarity>(); |
| factory.register_function<FunctionAISummarize>(); |
| factory.register_function<FunctionAITranslate>(); |
| } |
| |
| } // namespace doris::vectorized |