blob: 9d1e01adeb2047b5264a30365d07ac11311d6e28 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Literal
from fastapi import status, APIRouter, HTTPException
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
RAGRequest,
GraphConfigRequest,
LLMConfigRequest,
RerankerConfigRequest, GraphRAGRequest,
)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings, prompt
from hugegraph_llm.utils.log import log
def graph_rag_recall(
    text: str,
    rerank_method: Literal["bleu", "reranker"],
    near_neighbor_first: bool,
    custom_related_information: str
) -> dict:
    """Run the graph-RAG retrieval pipeline for *text* and return its context dict.

    Args:
        text: The user query to retrieve graph context for.
        rerank_method: Reranking strategy applied after dedup ("bleu" or "reranker").
        near_neighbor_first: Whether near-neighbor results are prioritized in reranking.
        custom_related_information: Extra caller-supplied context forwarded to reranking.
    """
    # Imported lazily so the pipeline machinery is only loaded when a recall
    # request actually arrives.
    from hugegraph_llm.operators.graph_rag_task import RAGPipeline

    pipeline = RAGPipeline()
    pipeline.extract_keywords().keywords_to_vid().query_graphdb().merge_dedup_rerank(
        rerank_method=rerank_method,
        near_neighbor_first=near_neighbor_first,
        custom_related_information=custom_related_information,
    )
    return pipeline.run(verbose=True, query=text, graph_search=True)
def rag_http_api(
    router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
):
    """Register the RAG query and configuration endpoints on *router*.

    Args:
        router: FastAPI router the endpoints are attached to.
        rag_answer_func: Callable returning the tuple
            (raw_answer, vector_only, graph_only, graph_vector_answer) for a query.
        apply_graph_conf: Applies graph-store settings; returns an HTTP-like status code.
        apply_llm_conf: Applies LLM settings; returns an HTTP-like status code.
        apply_embedding_conf: Applies embedding settings; returns an HTTP-like status code.
        apply_reranker_conf: Applies reranker settings; returns an HTTP-like status code.
    """

    @router.post("/rag", status_code=status.HTTP_200_OK)
    def rag_answer_api(req: RAGRequest):
        result = rag_answer_func(
            req.query,
            req.raw_answer,
            req.vector_only,
            req.graph_only,
            req.graph_vector_answer,
            req.graph_ratio,
            req.rerank_method,
            req.near_neighbor_first,
            req.custom_priority_info,
            req.answer_prompt or prompt.answer_prompt,
        )
        # Only include the answer variants the caller actually requested
        # (each key mirrors a boolean flag on the request model).
        return {
            key: value
            for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
            if getattr(req, key)
        }

    @router.post("/rag/graph", status_code=status.HTTP_200_OK)
    def graph_rag_recall_api(req: GraphRAGRequest):
        try:
            result = graph_rag_recall(
                text=req.query,
                rerank_method=req.rerank_method,
                near_neighbor_first=req.near_neighbor_first,
                custom_related_information=req.custom_priority_info,
            )
            if isinstance(result, dict):
                return {"graph_recall": result}
            # Non-dict pipeline results are serialized to a JSON string so the
            # response shape stays {"graph_recall": ...}.
            return {"graph_recall": json.dumps(result)}
        except TypeError as e:
            # log.exception (vs. log.error) records the full traceback,
            # which is what we want for a 400 triggered by bad arguments.
            log.exception("TypeError in graph_rag_recall_api: %s", e)
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
        except Exception as e:
            # Boundary handler: log with traceback, surface an opaque 500.
            log.exception("Unexpected error occurred: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred."
            ) from e

    @router.post("/config/graph", status_code=status.HTTP_201_CREATED)
    def graph_config_api(req: GraphConfigRequest):
        # apply_graph_conf returns a status code; generate_response maps it
        # onto the HTTP reply (message is only used for the failure path).
        res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    # TODO: restructure the implement of llm to three types, like "/config/chat_llm"
    @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
    def llm_config_api(req: LLMConfigRequest):
        settings.llm_type = req.llm_type
        # The positional arguments apply_llm_conf receives differ per provider.
        if req.llm_type == "openai":
            res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
        elif req.llm_type == "qianfan_wenxin":
            res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
        else:
            # Fallback (e.g. locally hosted model): host/port instead of API keys.
            res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/embedding", status_code=status.HTTP_201_CREATED)
    def embedding_config_api(req: LLMConfigRequest):
        settings.embedding_type = req.llm_type
        if req.llm_type == "openai":
            res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http")
        elif req.llm_type == "qianfan_wenxin":
            # NOTE(review): passes req.api_base here, while the llm endpoint's
            # qianfan branch passes req.secret_key — confirm this asymmetry is
            # intentional before changing it.
            res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http")
        else:
            res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))

    @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
    def rerank_config_api(req: RerankerConfigRequest):
        settings.reranker_type = req.reranker_type
        if req.reranker_type == "cohere":
            res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
        elif req.reranker_type == "siliconflow":
            res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
        else:
            # Unknown reranker types are rejected rather than silently applied.
            res = status.HTTP_501_NOT_IMPLEMENTED
        return generate_response(RAGResponse(status_code=res, message="Missing Value"))