| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import json |
| |
| from fastapi import status, APIRouter, HTTPException |
| |
| from hugegraph_llm.api.exceptions.rag_exceptions import generate_response |
| from hugegraph_llm.api.models.rag_requests import ( |
| RAGRequest, |
| GraphConfigRequest, |
| LLMConfigRequest, |
| RerankerConfigRequest, |
| GraphRAGRequest, |
| GremlinGenerateRequest, |
| ) |
| from hugegraph_llm.api.models.rag_response import RAGResponse |
| from hugegraph_llm.config import huge_settings |
| from hugegraph_llm.config import llm_settings, prompt |
| from hugegraph_llm.utils.graph_index_utils import get_vertex_details |
| from hugegraph_llm.utils.log import log |
| |
| |
| # pylint: disable=too-many-statements |
def rag_http_api(
    router: APIRouter,
    rag_answer_func,
    graph_rag_recall_func,
    apply_graph_conf,
    apply_llm_conf,
    apply_embedding_conf,
    apply_reranker_conf,
    gremlin_generate_selective_func,
):
    """Register the RAG, config, and text2gremlin HTTP endpoints on *router*.

    Every ``*_func`` / ``apply_*`` argument is a callable injected by the
    application wiring; this function only adapts them to FastAPI routes
    and performs request validation / response shaping.
    """

    def set_graph_config(req):
        """Apply the per-request HugeGraph connection override, if present.

        NOTE(review): this mutates process-global ``huge_settings`` — assumes
        requests are not served concurrently with different client configs.
        """
        if req.client_config:
            huge_settings.graph_url = req.client_config.url
            huge_settings.graph_name = req.client_config.graph
            huge_settings.graph_user = req.client_config.user
            huge_settings.graph_pwd = req.client_config.pwd
            huge_settings.graph_space = req.client_config.gs

    def validate_query(query):
        """Reject a missing or whitespace-only query with HTTP 400."""
        if not query or not str(query).strip():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query must not be empty.",
            )

    @router.post("/rag", status_code=status.HTTP_200_OK)
    def rag_answer_api(req: RAGRequest):
        """Answer a user query, returning only the answer flavors requested."""
        set_graph_config(req)
        validate_query(req.query)

        result = rag_answer_func(
            text=req.query,
            raw_answer=req.raw_answer,
            vector_only_answer=req.vector_only,
            graph_only_answer=req.graph_only,
            graph_vector_answer=req.graph_vector_answer,
            graph_ratio=req.graph_ratio,
            rerank_method=req.rerank_method,
            near_neighbor_first=req.near_neighbor_first,
            gremlin_tmpl_num=req.gremlin_tmpl_num,
            max_graph_items=req.max_graph_items,
            topk_return_results=req.topk_return_results,
            vector_dis_threshold=req.vector_dis_threshold,
            topk_per_keyword=req.topk_per_keyword,
            # Keep prompt params in the end
            custom_related_information=req.custom_priority_info,
            answer_prompt=req.answer_prompt or prompt.answer_prompt,
            keywords_extract_prompt=req.keywords_extract_prompt
            or prompt.keywords_extract_prompt,
            gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
        )
        # TODO: we need more info in the response for users to understand the query logic
        # ``result`` is positionally aligned with the four answer flags; only
        # the flavors the caller enabled (getattr check) make it into the body.
        return {
            "query": req.query,
            **{
                key: value
                for key, value in zip(
                    ["raw_answer", "vector_only", "graph_only", "graph_vector_answer"],
                    result,
                )
                if getattr(req, key)
            },
        }

    @router.post("/rag/graph", status_code=status.HTTP_200_OK)
    def graph_rag_recall_api(req: GraphRAGRequest):
        """Run graph-based recall for a query and return the recall details."""
        try:
            set_graph_config(req)
            validate_query(req.query)

            result = graph_rag_recall_func(
                query=req.query,
                max_graph_items=req.max_graph_items,
                topk_return_results=req.topk_return_results,
                vector_dis_threshold=req.vector_dis_threshold,
                topk_per_keyword=req.topk_per_keyword,
                gremlin_tmpl_num=req.gremlin_tmpl_num,
                rerank_method=req.rerank_method,
                near_neighbor_first=req.near_neighbor_first,
                custom_related_information=req.custom_priority_info,
                gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
                get_vertex_only=req.get_vertex_only,
            )

            if isinstance(result, dict):
                # Only index into ``result`` once we know it is a dict;
                # enrich the matched vertex ids with full vertex details.
                if req.get_vertex_only:
                    vertex_details = get_vertex_details(result["match_vids"], result)
                    if vertex_details:
                        result["match_vids"] = vertex_details

                params = [
                    "query",
                    "keywords",
                    "match_vids",
                    "graph_result_flag",
                    "gremlin",
                    "graph_result",
                    "vertex_degree_list",
                ]
                user_result = {key: result[key] for key in params if key in result}
                return {"graph_recall": user_result}
            # Fallback for non-dict results (e.g. plain text / list payloads).
            return {"graph_recall": json.dumps(result)}

        except HTTPException:
            # Re-raise client errors (e.g. the 400 from validate_query) as-is
            # instead of masking them as 500 in the generic handler below.
            raise
        except TypeError as e:
            log.error("TypeError in graph_rag_recall_api: %s", e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
            ) from e
        except Exception as e:
            log.error("Unexpected error occurred: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="An unexpected error occurred.",
            ) from e

    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
    def graph_config_api(req: GraphConfigRequest):
        """Apply a HugeGraph connection config; echoes the resulting status."""
        # Accept status code
        res = apply_graph_conf(
            req.url, req.name, req.user, req.pwd, req.gs, origin_call="http"
        )
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    # TODO: restructure the LLM implementation into three types, like "/config/chat_llm"
    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
    def llm_config_api(req: LLMConfigRequest):
        """Apply an LLM config (OpenAI-style or host/port-style backends)."""
        llm_settings.llm_type = req.llm_type

        if req.llm_type == "openai":
            res = apply_llm_conf(
                req.api_key,
                req.api_base,
                req.language_model,
                req.max_tokens,
                origin_call="http",
            )
        else:
            # Non-OpenAI backends are addressed by host/port; max_tokens N/A.
            res = apply_llm_conf(
                req.host, req.port, req.language_model, None, origin_call="http"
            )
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
    def embedding_config_api(req: LLMConfigRequest):
        """Apply an embedding-model config (shares the LLM request schema)."""
        llm_settings.embedding_type = req.llm_type

        if req.llm_type == "openai":
            res = apply_embedding_conf(
                req.api_key, req.api_base, req.language_model, origin_call="http"
            )
        else:
            res = apply_embedding_conf(
                req.host, req.port, req.language_model, origin_call="http"
            )
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
    def rerank_config_api(req: RerankerConfigRequest):
        """Apply a reranker config; unknown reranker types yield 501."""
        llm_settings.reranker_type = req.reranker_type

        if req.reranker_type == "cohere":
            res = apply_reranker_conf(
                req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http"
            )
        elif req.reranker_type == "siliconflow":
            res = apply_reranker_conf(
                req.api_key, req.reranker_model, None, origin_call="http"
            )
        else:
            res = status.HTTP_501_NOT_IMPLEMENTED
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/text2gremlin", status_code=status.HTTP_200_OK)
    def text2gremlin_api(req: GremlinGenerateRequest):
        """Translate a natural-language query into Gremlin, with selectable
        output sections."""
        try:
            set_graph_config(req)
            validate_query(req.query)

            # Convert enum members to their raw string values for the backend.
            output_types_str_list = None
            if req.output_types:
                output_types_str_list = [ot.value for ot in req.output_types]

            response_dict = gremlin_generate_selective_func(
                inp=req.query,
                example_num=req.example_num,
                schema_input=huge_settings.graph_name,
                gremlin_prompt_input=req.gremlin_prompt,
                requested_outputs=output_types_str_list,
            )
            return response_dict
        except HTTPException as e:
            # Preserve deliberate HTTP errors (e.g. the 400 validation above).
            raise e
        except Exception as e:
            log.error("Error in text2gremlin_api: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="An unexpected error occurred during Gremlin generation.",
            ) from e