blob: c1d2e635ec6b843fb19ae4da6445524b74a3e2ac [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <string>
#include <string_view>

#include "udf/udf.h"
using namespace impala_udf;
namespace impala {
using impala_udf::FunctionContext;
using impala_udf::StringVal;
/// Static-only collection of the AI text-generation builtins, together with the
/// helpers and constants they share. All members are static; the definitions of
/// the non-inline members live outside this header.
class AiFunctions {
 public:
  /// Error strings associated with the ai_generate_text functions. They are only
  /// declared here; the text of each message is defined out of line.
  static const std::string AI_GENERATE_TXT_JSON_PARSE_ERROR;
  static const std::string AI_GENERATE_TXT_INVALID_PROTOCOL_ERROR;
  static const std::string AI_GENERATE_TXT_UNSUPPORTED_ENDPOINT_ERROR;
  static const std::string AI_GENERATE_TXT_INVALID_PROMPT_ERROR;
  static const std::string AI_GENERATE_TXT_MSG_OVERRIDE_FORBIDDEN_ERROR;
  static const std::string AI_GENERATE_TXT_N_OVERRIDE_FORBIDDEN_ERROR;
  /// Declared here, defined out of line.
  /// NOTE(review): presumably the content-type header sent with OpenAI requests --
  /// confirm against the definition.
  static const char* OPEN_AI_REQUEST_FIELD_CONTENT_TYPE_HEADER;

  /// Sends a prompt to the input AI endpoint using the input model, api_key and
  /// optional params.
  static StringVal AiGenerateText(FunctionContext* ctx, const StringVal& endpoint,
      const StringVal& prompt, const StringVal& model,
      const StringVal& api_key_jceks_secret, const StringVal& params);

  /// Sends a prompt to the default endpoint and uses the default model, default
  /// api-key and default params.
  static StringVal AiGenerateTextDefault(FunctionContext* ctx, const StringVal& prompt);

  /// Set the ai_api_key_ member. The argument is copied, so it is taken by const
  /// reference; const strings and temporaries are accepted as well as lvalues.
  static void set_api_key(const std::string& api_key) { ai_api_key_ = api_key; }

  /// Validate api end point.
  static bool is_api_endpoint_valid(const std::string_view& endpoint);

  /// Check if endpoint is supported.
  static bool is_api_endpoint_supported(const std::string_view& endpoint);

 private:
  /// The default api_key used for communicating with external APIs.
  static std::string ai_api_key_;

  /// Internal function which implements the logic of parsing user input and sending
  /// request to the external API endpoint. If 'dry_run' is set, the POST request is
  /// returned. 'dry_run' mode is used only for unit tests.
  /// NOTE(review): the meaning of the 'fastpath' template parameter is not visible
  /// from this declaration -- see the definition.
  template <bool fastpath>
  static StringVal AiGenerateTextInternal(FunctionContext* ctx,
      const StringVal& endpoint, const StringVal& prompt, const StringVal& model,
      const StringVal& api_key_jceks_secret, const StringVal& params,
      const bool dry_run);

  /// Internal helper function for parsing OPEN AI's API response. Input parameter is
  /// the json representation of the OPEN AI's API response.
  /// NOTE(review): the result is a non-owning std::string_view; confirm at the
  /// definition which buffer it refers to and how long that buffer lives.
  static std::string_view AiGenerateTextParseOpenAiResponse(
      const std::string_view& response);

  /// Grants the unit test access to the private members above.
  friend class ExprTest_AiFunctionsTest_Test;
};
} // namespace impala