feat(llm): support multiple rerankers & enhance the UI (#73)
* refactor BLEU rerank & support rerank UI mapping in Gradio
---------
Co-authored-by: imbajin <jin@apache.org>
diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
index ce0eaa7..a211bb8 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
@@ -52,3 +52,10 @@
# ollama-only properties
host: str = None
port: str = None
+
+
+class RerankerConfigRequest(BaseModel):
+ reranker_model: str
+ reranker_type: str
+ api_key: str
+ cohere_base_url: Optional[str] = None
diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
index 923e70f..64daf70 100644
--- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
+++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py
@@ -18,15 +18,23 @@
from fastapi import status, APIRouter
from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
-from hugegraph_llm.api.models.rag_requests import RAGRequest, GraphConfigRequest, LLMConfigRequest
+from hugegraph_llm.api.models.rag_requests import (
+ RAGRequest,
+ GraphConfigRequest,
+ LLMConfigRequest,
+ RerankerConfigRequest,
+)
from hugegraph_llm.api.models.rag_response import RAGResponse
from hugegraph_llm.config import settings
-def rag_http_api(router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf):
+def rag_http_api(
+ router: APIRouter, rag_answer_func, apply_graph_conf, apply_llm_conf, apply_embedding_conf, apply_reranker_conf
+):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
- result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector, req.answer_prompt)
+ result = rag_answer_func(req.query, req.raw_llm, req.vector_only, req.graph_only, req.graph_vector,
+ req.answer_prompt)
return {
key: value
for key, value in zip(["raw_llm", "vector_only", "graph_only", "graph_vector"], result)
@@ -44,9 +52,7 @@
settings.llm_type = req.llm_type
if req.llm_type == "openai":
- res = apply_llm_conf(
- req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http"
- )
+ res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http")
elif req.llm_type == "qianfan_wenxin":
res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http")
else:
@@ -64,3 +70,15 @@
else:
res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http")
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
+
+ @router.post("/config/rerank", status_code=status.HTTP_201_CREATED)
+ def rerank_config_api(req: RerankerConfigRequest):
+ settings.reranker_type = req.reranker_type
+
+ if req.reranker_type == "cohere":
+ res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http")
+ elif req.reranker_type == "siliconflow":
+ res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
+ else:
+ res = status.HTTP_501_NOT_IMPLEMENTED
+ return generate_response(RAGResponse(status_code=res, message="Missing Value"))
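
A minimal sketch of calling the new endpoint; the host, port, and router prefix below are assumptions (they depend on how the app is deployed), and the key/model values are placeholders:

    import requests  # sketch only; bind address and model name are assumed

    resp = requests.post(
        "http://127.0.0.1:8001/config/rerank",   # hypothetical host/port
        json={
            "reranker_type": "cohere",
            "reranker_model": "rerank-multilingual-v3.0",  # example model name
            "api_key": "<your-cohere-key>",
            "cohere_base_url": "https://api.cohere.com/v1/rerank",
        },
    )
    print(resp.status_code, resp.json())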
diff --git a/hugegraph-llm/src/hugegraph_llm/config/__init__.py b/hugegraph-llm/src/hugegraph_llm/config/__init__.py
index f801b88..3e6c9e9 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/__init__.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/__init__.py
@@ -22,8 +22,8 @@
]
import os
-from .config import Config
+from .config import Config
settings = Config()
settings.from_env()
diff --git a/hugegraph-llm/src/hugegraph_llm/config/config.py b/hugegraph-llm/src/hugegraph_llm/config/config.py
index 2fd8262..2a73b62 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config.py
@@ -35,31 +35,34 @@
# env_path: Optional[str] = ".env"
llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai"
embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai"
+ reranker_type: Optional[Literal["cohere", "siliconflow"]] = None
# 1. OpenAI settings
openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_language_model: Optional[str] = "gpt-4o-mini"
openai_embedding_model: Optional[str] = "text-embedding-3-small"
openai_max_tokens: int = 4096
- # 2. Ollama settings
+ # 2. Rerank settings
+ cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
+ reranker_api_key: Optional[str] = None
+ reranker_model: Optional[str] = None
+ # 3. Ollama settings
ollama_host: Optional[str] = "127.0.0.1"
ollama_port: Optional[int] = 11434
ollama_language_model: Optional[str] = None
ollama_embedding_model: Optional[str] = None
- # 3. QianFan/WenXin settings
+ # 4. QianFan/WenXin settings
qianfan_api_key: Optional[str] = None
qianfan_secret_key: Optional[str] = None
qianfan_access_token: Optional[str] = None
- # 3.1 url settings
- qianfan_url_prefix: Optional[str] = (
- "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
- )
+ # 4.1 URL settings
+ qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/"
qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
- # https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu
+ # refer to https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu for more details
qianfan_embedding_model: Optional[str] = "embedding-v1"
- # 4. ZhiPu(GLM) settings
+ # 5. ZhiPu(GLM) settings
zhipu_api_key: Optional[str] = None
zhipu_language_model: Optional[str] = "glm-4"
zhipu_embedding_model: Optional[str] = "embedding-2"
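
The three new settings can also be applied programmatically, mirroring what apply_reranker_config does in the demo; a minimal sketch (key and model values are placeholders):

    from hugegraph_llm.config import settings

    settings.reranker_type = "siliconflow"            # or "cohere"
    settings.reranker_api_key = "<your-api-key>"      # placeholder
    settings.reranker_model = "BAAI/bge-reranker-v2-m3"
    settings.update_env()  # persist the values, as apply_reranker_config does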
diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
index f065b6f..20ed5b0 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_web_demo.py
@@ -19,7 +19,7 @@
import argparse
import json
import os
-from typing import List, Union
+from typing import List, Union, Tuple, Literal, Optional
import docx
import gradio as gr
@@ -50,6 +50,7 @@
correct_token = os.getenv("TOKEN")
if credentials.credentials != correct_token:
from fastapi import HTTPException
+
raise HTTPException(
status_code=401,
detail=f"Invalid token {credentials.credentials}, please contact the admin",
@@ -58,8 +59,17 @@
def rag_answer(
- text: str, raw_answer: bool, vector_only_answer: bool, graph_only_answer: bool,
- graph_vector_answer: bool, answer_prompt: str) -> tuple:
+ text: str,
+ raw_answer: bool,
+ vector_only_answer: bool,
+ graph_only_answer: bool,
+ graph_vector_answer: bool,
+ graph_ratio: float,
+ rerank_method: Literal["bleu", "reranker"],
+ near_neighbor_first: bool,
+ custom_related_information: str,
+ answer_prompt: str,
+) -> Tuple:
vector_search = vector_only_answer or graph_vector_answer
graph_search = graph_only_answer or graph_vector_answer
@@ -72,16 +82,20 @@
if graph_search:
searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
# TODO: add more user-defined search strategies
- searcher.merge_dedup_rerank().synthesize_answer(
+ searcher.merge_dedup_rerank(
+ graph_ratio, rerank_method, near_neighbor_first, custom_related_information
+ ).synthesize_answer(
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
- answer_prompt=answer_prompt
+ answer_prompt=answer_prompt,
)
try:
- context = searcher.run(verbose=True, query=text)
+ context = searcher.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+ if context.get("switch_to_bleu"):
+ gr.Warning("Online reranker fails, automatically switches to local bleu method.")
return (
context.get("raw_answer", ""),
context.get("vector_only_answer", ""),
@@ -97,10 +111,10 @@
def build_kg( # pylint: disable=too-many-branches
- files: Union[NamedString, List[NamedString]],
- schema: str,
- example_prompt: str,
- build_mode: str
+ files: Union[NamedString, List[NamedString]],
+ schema: str,
+ example_prompt: str,
+ build_mode: str,
) -> str:
if isinstance(files, NamedString):
files = [files]
@@ -161,8 +175,7 @@
raise gr.Error(str(e))
-def test_api_connection(url, method="GET",
- headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
+def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int:
# TODO: use fastapi.request / starlette instead?
log.debug("Request URL: %s", url)
try:
@@ -188,7 +201,10 @@
log.error(msg)
# TODO: Only the message returned by rag can be processed, and the other return values can't be processed
if origin_call is None:
- raise gr.Error(json.loads(resp.text).get("message", msg))
+ try:
+ raise gr.Error(json.loads(resp.text).get("message", msg))
+ except (json.decoder.JSONDecodeError, AttributeError):
+ raise gr.Error(resp.text)
return resp.status_code
@@ -199,10 +215,11 @@
params = {
"grant_type": "client_credentials",
"client_id": arg1,
- "client_secret": arg2
+ "client_secret": arg2,
}
- status_code = test_api_connection("https://aip.baidubce.com/oauth/2.0/token", "POST", params=params,
- origin_call=origin_call)
+ status_code = test_api_connection(
+ "https://aip.baidubce.com/oauth/2.0/token", "POST", params=params, origin_call=origin_call
+ )
return status_code
@@ -229,6 +246,42 @@
return status_code
+def apply_reranker_config(
+ reranker_api_key: Optional[str] = None,
+ reranker_model: Optional[str] = None,
+ cohere_base_url: Optional[str] = None,
+ origin_call=None,
+) -> int:
+ status_code = -1
+ reranker_option = settings.reranker_type
+ if reranker_option == "cohere":
+ settings.reranker_api_key = reranker_api_key
+ settings.reranker_model = reranker_model
+ settings.cohere_base_url = cohere_base_url
+ headers = {"Authorization": f"Bearer {reranker_api_key}"}
+ status_code = test_api_connection(
+ cohere_base_url.rsplit("/", 1)[0] + "/check-api-key",
+ method="POST",
+ headers=headers,
+ origin_call=origin_call,
+ )
+ elif reranker_option == "siliconflow":
+ settings.reranker_api_key = reranker_api_key
+ settings.reranker_model = reranker_model
+ headers = {
+ "accept": "application/json",
+ "authorization": f"Bearer {reranker_api_key}",
+ }
+ status_code = test_api_connection(
+ "https://api.siliconflow.cn/v1/user/info",
+ headers=headers,
+ origin_call=origin_call,
+ )
+ settings.update_env()
+ gr.Info("Configured!")
+ return status_code
+
+
def apply_graph_config(ip, port, name, user, pwd, gs, origin_call=None) -> int:
settings.graph_ip = ip
settings.graph_port = port
@@ -274,9 +327,11 @@
def init_rag_ui() -> gr.Interface:
- with gr.Blocks(theme='default',
- title="HugeGraph RAG Platform",
- css="footer {visibility: hidden}") as hugegraph_llm_ui:
+ with gr.Blocks(
+ theme="default",
+ title="HugeGraph RAG Platform",
+ css="footer {visibility: hidden}",
+ ) as hugegraph_llm_ui:
gr.Markdown(
"""# HugeGraph LLM RAG Demo
1. Set up the HugeGraph server."""
@@ -324,7 +379,7 @@
gr.Textbox(value=settings.qianfan_language_model, label="model_name"),
gr.Textbox(value="", visible=False),
]
- log.debug(llm_config_input)
+ # log.debug(llm_config_input)
else:
llm_config_input = []
llm_config_button = gr.Button("apply configuration")
@@ -367,7 +422,47 @@
# Call the separate apply_embedding_configuration function here
embedding_config_button.click( # pylint: disable=no-member
- apply_embedding_config, inputs=embedding_config_input # pylint: disable=no-member
+ fn=apply_embedding_config,
+ inputs=embedding_config_input, # pylint: disable=no-member
+ )
+
+ gr.Markdown("4. Set up the Reranker (Optional).")
+ reranker_dropdown = gr.Dropdown(
+ choices=["cohere", "siliconflow", ("default/offline", "None")],
+ value=os.getenv("reranker_type") or "None",
+ label="Reranker",
+ )
+
+ @gr.render(inputs=[reranker_dropdown])
+ def reranker_settings(reranker_type):
+ settings.reranker_type = reranker_type if reranker_type != "None" else None
+ if reranker_type == "cohere":
+ with gr.Row():
+ reranker_config_input = [
+ gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"),
+ gr.Textbox(value=settings.reranker_model, label="model"),
+ gr.Textbox(value=settings.cohere_base_url, label="base_url"),
+ ]
+ elif reranker_type == "siliconflow":
+ with gr.Row():
+ reranker_config_input = [
+ gr.Textbox(value=settings.reranker_api_key, label="api_key", type="password"),
+ gr.Textbox(
+ value="BAAI/bge-reranker-v2-m3",
+ label="model",
+ info="Please refer to https://siliconflow.cn/pricing",
+ ),
+ ]
+ else:
+ reranker_config_input = []
+
+ reranker_config_button = gr.Button("apply configuration")
+
+ # TODO: use "gr.update()" or other way to update the config in time (refactor the click event)
+ # Call the separate apply_reranker_configuration function here
+ reranker_config_button.click( # pylint: disable=no-member
+ fn=apply_reranker_config,
+ inputs=reranker_config_input, # pylint: disable=no-member
)
gr.Markdown(
@@ -425,7 +520,8 @@
input_file = gr.File(
value=[os.path.join(resource_path, "demo", "test.txt")],
label="Docs (multi-files can be selected together)",
- file_count="multiple")
+ file_count="multiple",
+ )
input_schema = gr.Textbox(value=schema, label="Schema")
info_extract_template = gr.Textbox(value=SCHEMA_EXAMPLE_PROMPT, label="Info extract head")
with gr.Column():
@@ -438,7 +534,9 @@
with gr.Row():
out = gr.Textbox(label="Output", show_copy_button=True)
btn.click( # pylint: disable=no-member
- fn=build_kg, inputs=[input_file, input_schema, info_extract_template, mode], outputs=out
+ fn=build_kg,
+ inputs=[input_file, input_schema, info_extract_template, mode],
+ outputs=out,
)
gr.Markdown("""## 2. RAG with HugeGraph 📖""")
@@ -449,15 +547,44 @@
vector_only_out = gr.Textbox(label="Vector-only Answer", show_copy_button=True)
graph_only_out = gr.Textbox(label="Graph-only Answer", show_copy_button=True)
graph_vector_out = gr.Textbox(label="Graph-Vector Answer", show_copy_button=True)
- with gr.Column(scale=1):
- raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer")
- vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer")
- graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer")
- graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
- btn = gr.Button("Answer Question")
from hugegraph_llm.operators.llm_op.answer_synthesize import DEFAULT_ANSWER_TEMPLATE
- answer_prompt_input = gr.Textbox(value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt",
- show_copy_button=True)
+
+ answer_prompt_input = gr.Textbox(
+ value=DEFAULT_ANSWER_TEMPLATE, label="Custom Prompt", show_copy_button=True
+ )
+ with gr.Column(scale=1):
+ with gr.Row():
+ raw_radio = gr.Radio(choices=[True, False], value=True, label="Basic LLM Answer")
+ vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer")
+ with gr.Row():
+ graph_only_radio = gr.Radio(choices=[True, False], value=False, label="Graph-only Answer")
+ graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer")
+
+ def toggle_slider(enable):
+ return gr.update(interactive=enable)
+
+ with gr.Column():
+ with gr.Row():
+ online_rerank = os.getenv("reranker_type")
+ rerank_method = gr.Dropdown(
+ choices=["bleu", ("rerank (online)", "reranker")] if online_rerank else ["bleu"],
+ value="reranker" if online_rerank else "bleu",
+ label="Rerank method",
+ )
+ graph_ratio = gr.Slider(0, 1, 0.5, label="Graph Ratio", step=0.1, interactive=False)
+
+ graph_vector_radio.change(toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio)
+ near_neighbor_first = gr.Checkbox(
+ value=False,
+ label="Near neighbor first(Optional)",
+ info="One-depth neighbors > two-depth neighbors",
+ )
+ custom_related_information = gr.Text(
+ "",
+ label="Custom related information(Optional)",
+ )
+ btn = gr.Button("Answer Question", variant="primary")
+
btn.click( # pylint: disable=no-member
fn=rag_answer,
inputs=[
@@ -466,6 +593,10 @@
vector_only_radio,
graph_only_radio,
graph_vector_radio,
+ graph_ratio,
+ rerank_method,
+ near_neighbor_first,
+ custom_related_information,
answer_prompt_input,
],
outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out],
@@ -497,7 +628,14 @@
app_auth = APIRouter(dependencies=[Depends(authenticate)])
hugegraph_llm = init_rag_ui()
- rag_http_api(app_auth, rag_answer, apply_graph_config, apply_llm_config, apply_embedding_config)
+ rag_http_api(
+ app_auth,
+ rag_answer,
+ apply_graph_config,
+ apply_llm_config,
+ apply_embedding_config,
+ apply_reranker_config,
+ )
app.include_router(app_auth)
auth_enabled = os.getenv("ENABLE_LOGIN", "False").lower() == "true"
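
With the UI wiring above, rag_answer now takes ten positional inputs in a fixed order. A minimal sketch of a direct call (the query string is illustrative; DEFAULT_ANSWER_TEMPLATE is the prompt imported inside the RAG block above):

    raw, vector_only, graph_only, graph_vector = rag_answer(
        text="Tell me about HugeGraph.",   # illustrative query
        raw_answer=True,
        vector_only_answer=False,
        graph_only_answer=True,
        graph_vector_answer=False,
        graph_ratio=0.5,
        rerank_method="bleu",              # "reranker" needs a configured provider
        near_neighbor_first=False,
        custom_related_information="",
        answer_prompt=DEFAULT_ANSWER_TEMPLATE,
    )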
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py
new file mode 100644
index 0000000..13a8339
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py
new file mode 100644
index 0000000..b552717
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, List
+
+import requests
+
+
+class CohereReranker:
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ base_url: Optional[str] = None,
+ model: Optional[str] = None,
+ ):
+ self.api_key = api_key
+ self.base_url = base_url
+ self.model = model
+
+ def get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int] = None) -> List[str]:
+ if not top_n:
+ top_n = len(documents)
+ assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents"
+
+ if top_n == 0:
+ return []
+
+ url = self.base_url
+ headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "Authorization": f"Bearer {self.api_key}",
+ }
+ payload = {
+ "model": self.model,
+ "query": query,
+ "top_n": top_n,
+ "documents": documents,
+ }
+ response = requests.post(url, headers=headers, json=payload)
+ response.raise_for_status() # Raise an error for bad status codes
+ results = response.json()["results"]
+ sorted_docs = [documents[item["index"]] for item in results]
+ return sorted_docs
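
A brief usage sketch; the key and model name are placeholders, and the base_url matches the default added to config.py:

    reranker = CohereReranker(
        api_key="<your-cohere-key>",                  # placeholder
        base_url="https://api.cohere.com/v1/rerank",  # default from config.py
        model="rerank-english-v3.0",                  # example model name
    )
    docs = ["HugeGraph is an open-source graph database.", "Paris is in France."]
    print(reranker.get_rerank_lists("What is HugeGraph?", docs, top_n=1))
    # -> the single most relevant document, as scored by the Cohere API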
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
new file mode 100644
index 0000000..541f413
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from hugegraph_llm.config import settings
+from hugegraph_llm.models.rerankers.cohere import CohereReranker
+from hugegraph_llm.models.rerankers.siliconflow import SiliconReranker
+
+
+class Rerankers:
+ def __init__(self):
+ self.reranker_type = settings.reranker_type
+
+ def get_reranker(self):
+ if self.reranker_type == "cohere":
+ return CohereReranker(
+ api_key=settings.reranker_api_key, base_url=settings.cohere_base_url, model=settings.reranker_model
+ )
+ elif self.reranker_type == "siliconflow":
+ return SiliconReranker(api_key=settings.reranker_api_key, model=settings.reranker_model)
+ else:
+ raise Exception(f"reranker type is not supported !")
diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
new file mode 100644
index 0000000..a860a84
--- /dev/null
+++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, List
+
+import requests
+
+
+class SiliconReranker:
+ def __init__(
+ self,
+ api_key: Optional[str] = None,
+ model: Optional[str] = None,
+ ):
+ self.api_key = api_key
+ self.model = model
+
+ def get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int] = None) -> List[str]:
+ if not top_n:
+ top_n = len(documents)
+ assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents"
+
+ if top_n == 0:
+ return []
+
+ url = "https://api.siliconflow.cn/v1/rerank"
+ payload = {
+ "model": self.model,
+ "query": query,
+ "documents": documents,
+ "return_documents": False,
+ "max_chunks_per_doc": 1024,
+ "overlap_tokens": 80,
+ "top_n": top_n,
+ }
+ headers = {
+ "accept": "application/json",
+ "content-type": "application/json",
+ "authorization": f"Bearer {self.api_key}",
+ }
+ response = requests.post(url, json=payload, headers=headers)
+ response.raise_for_status() # Raise an error for bad status codes
+ results = response.json()["results"]
+ sorted_docs = [documents[item["index"]] for item in results]
+ return sorted_docs
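
SiliconReranker mirrors CohereReranker except that its endpoint is fixed; a brief sketch (the key is a placeholder):

    reranker = SiliconReranker(api_key="<your-siliconflow-key>",
                               model="BAAI/bge-reranker-v2-m3")
    ordered = reranker.get_rerank_lists("example query", ["doc a", "doc b"])
    # best match first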
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
index a34cbed..6e356e2 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py
@@ -16,59 +16,131 @@
# under the License.
-from typing import Dict, Any, List, Literal
+from typing import Literal, Dict, Any, List, Optional, Tuple
import jieba
-from hugegraph_llm.models.embeddings.base import BaseEmbedding
+import requests
from nltk.translate.bleu_score import sentence_bleu
+from hugegraph_llm.models.embeddings.base import BaseEmbedding
+from hugegraph_llm.models.rerankers.init_reranker import Rerankers
+from hugegraph_llm.utils.log import log
-def get_score(query: str, content: str) -> float:
+
+def get_bleu_score(query: str, content: str) -> float:
query_tokens = jieba.lcut(query)
content_tokens = jieba.lcut(content)
return sentence_bleu([query_tokens], content_tokens)
+def _bleu_rerank(query: str, results: List[str]) -> List[str]:
+ result_score_list = [[res, get_bleu_score(query, res)] for res in results]
+ result_score_list.sort(key=lambda x: x[1], reverse=True)
+ return [res[0] for res in result_score_list]
+
+
class MergeDedupRerank:
def __init__(
- self,
- embedding: BaseEmbedding,
- topk: int = 10,
- strategy: Literal["bleu", "priority"] = "bleu"
+ self,
+ embedding: BaseEmbedding,
+ topk: int = 20,
+ graph_ratio: float = 0.5,
+ method: Literal["bleu", "reranker"] = "bleu",
+ near_neighbor_first: bool = False,
+ custom_related_information: Optional[str] = None,
+ priority: bool = False, # TODO: implement priority
):
+ assert method in ["bleu", "reranker"], f"Unimplemented rerank method '{method}'."
self.embedding = embedding
+ self.graph_ratio = graph_ratio
self.topk = topk
- if strategy == "bleu":
- self.rerank_func = self._bleu_rerank
- elif strategy == "priority":
- self.rerank_func = self._priority_rerank
- else:
- raise ValueError(f"Unimplemented rerank strategy {strategy}.")
+ self.method = method
+ self.near_neighbor_first = near_neighbor_first
+ self.custom_related_information = custom_related_information
+ if priority:
+ raise ValueError(f"Unimplemented rerank strategy: priority.")
+ self.switch_to_bleu = False
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
query = context.get("query")
+ if self.custom_related_information:
+ query = query + self.custom_related_information
+ context["graph_ratio"] = self.graph_ratio
+ vector_search = context.get("vector_search", False)
+ graph_search = context.get("graph_search", False)
+ if graph_search and vector_search:
+ graph_length = int(self.topk * self.graph_ratio)
+ vector_length = self.topk - graph_length
+ else:
+ graph_length = self.topk
+ vector_length = self.topk
vector_result = context.get("vector_result", [])
- vector_result = self.rerank_func(query, vector_result)[:self.topk]
+ vector_length = min(len(vector_result), vector_length)
+ vector_result = self._dedup_and_rerank(query, vector_result, vector_length)
graph_result = context.get("graph_result", [])
- graph_result = self.rerank_func(query, graph_result)[:self.topk]
+ graph_length = min(len(graph_result), graph_length)
+ if self.near_neighbor_first:
+ graph_result = self._rerank_with_vertex_degree(
+ query,
+ graph_result,
+ graph_length,
+ context.get("vertex_degree_list"),
+ context.get("knowledge_with_degree"),
+ )
+ if self.switch_to_bleu:
+ context["switch_to_bleu"] = True
+ else:
+ graph_result = self._dedup_and_rerank(query, graph_result, graph_length)
context["vector_result"] = vector_result
context["graph_result"] = graph_result
return context
- def _bleu_rerank(self, query: str, results: List[str]):
+ def _dedup_and_rerank(self, query: str, results: List[str], topn: int) -> List[str]:
results = list(set(results))
- result_score_list = [[res, get_score(query, res)] for res in results]
- result_score_list.sort(key=lambda x: x[1], reverse=True)
- return [res[0] for res in result_score_list]
+ if self.method == "bleu":
+ return _bleu_rerank(query, results)[:topn]
+ if self.method == "reranker":
+ reranker = Rerankers().get_reranker()
+ return reranker.get_rerank_lists(query, results, topn)
- def _priority_rerank(self, query: str, results: List[str]):
- # TODO: implement
- # 1. Precise recall > Fuzzy recall
- # 2. 1-degree neighbors > 2-degree neighbors
- # 3. The priority of a certain type of point is higher than others,
- # such as Law being higher than vehicles/people/locations
- raise NotImplementedError()
+ def _rerank_with_vertex_degree(
+ self,
+ query: str,
+ results: List[str],
+ topn: int,
+ vertex_degree_list: Optional[List[List[str]]],
+ knowledge_with_degree: Dict[str, List[str]],
+ ) -> List[str]:
+ if vertex_degree_list is None or len(vertex_degree_list) == 0:
+ return self._dedup_and_rerank(query, results, topn)
+
+ if self.method == "reranker":
+ reranker = Rerankers().get_reranker()
+ try:
+ vertex_rerank_res = [
+ reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list
+ ]
+ except requests.exceptions.RequestException as e:
+ log.warning(f"Online reranker fails, automatically switches to local bleu method: {e}")
+ self.method = "bleu"
+ self.switch_to_bleu = True
+
+ if self.method == "bleu":
+ vertex_rerank_res = [_bleu_rerank(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list]
+
+ depth = len(vertex_degree_list)
+ for result in results:
+ if result not in knowledge_with_degree:
+ knowledge_with_degree[result] = [result] + [""] * (depth - 1)
+ if len(knowledge_with_degree[result]) < depth:
+ knowledge_with_degree[result] += [""] * (depth - len(knowledge_with_degree[result]))
+
+ def sort_key(res: str) -> Tuple[int, ...]:
+ return tuple(vertex_rerank_res[i].index(knowledge_with_degree[res][i]) for i in range(depth))
+
+ sorted_results = sorted(results, key=sort_key)
+ return sorted_results[:topn]
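
The sort_key above orders each knowledge chain by the reranked position of its node at every depth; Python compares the tuples lexicographically, so the first-hop rank dominates and deeper hops only break ties. A worked sketch with hypothetical data:

    # two chains, depth = 2; each inner list is already reranked, best first,
    # and padded with "" exactly as _rerank_with_vertex_degree does
    vertex_rerank_res = [["A", "B", ""], ["X", "Y", ""]]
    knowledge_with_degree = {
        "A -[knows]-> Y": ["A", "Y"],   # sort key (0, 1)
        "B -[knows]-> X": ["B", "X"],   # sort key (1, 0)
    }
    ranked = sorted(
        knowledge_with_degree,
        key=lambda res: tuple(
            vertex_rerank_res[i].index(knowledge_with_degree[res][i])
            for i in range(2)
        ),
    )
    # (0, 1) < (1, 0), so "A -[knows]-> Y" comes first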
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
index 5ac3209..91bc7b3 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -100,10 +100,10 @@
return self
def match_keyword_to_id(
- self,
- by: Literal["query", "keywords"] = "keywords",
- topk_per_keyword: int = 1,
- topk_per_query: int = 10
+ self,
+ by: Literal["query", "keywords"] = "keywords",
+ topk_per_keyword: int = 1,
+ topk_per_query: int = 10,
):
"""
Add a semantic ID query operator to the pipeline.
@@ -116,16 +116,16 @@
embedding=self._embedding,
by=by,
topk_per_keyword=topk_per_keyword,
- topk_per_query=topk_per_query
+ topk_per_query=topk_per_query,
)
)
return self
def query_graph_for_rag(
- self,
- max_deep: int = 2,
- max_items: int = 30,
- prop_to_match: Optional[str] = None,
+ self,
+ max_deep: int = 2,
+ max_items: int = 30,
+ prop_to_match: Optional[str] = None,
):
"""
Add a graph RAG query operator to the pipeline.
@@ -144,10 +144,7 @@
)
return self
- def query_vector_index_for_rag(
- self,
- max_items: int = 3
- ):
+ def query_vector_index_for_rag(self, max_items: int = 3):
"""
Add a vector index query operator to the pipeline.
@@ -162,13 +159,27 @@
)
return self
- def merge_dedup_rerank(self):
+ def merge_dedup_rerank(
+ self,
+ graph_ratio: float = 0.5,
+ rerank_method: Literal["bleu", "reranker"] = "bleu",
+ near_neighbor_first: bool = False,
+ custom_related_information: str = "",
+ ):
"""
Add a merge, deduplication, and rerank operator to the pipeline.
:return: Self-instance for chaining.
"""
- self._operators.append(MergeDedupRerank(embedding=self._embedding))
+ self._operators.append(
+ MergeDedupRerank(
+ embedding=self._embedding,
+ graph_ratio=graph_ratio,
+ method=rerank_method,
+ near_neighbor_first=near_neighbor_first,
+ custom_related_information=custom_related_information,
+ )
+ )
return self
def synthesize_answer(
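
The new keyword arguments flow straight through to MergeDedupRerank, so a pipeline can be chained as in rag_web_demo.py; a sketch (how the pipeline object is constructed is not shown in this diff, and the query is illustrative):

    # 'searcher' is the same pipeline object used in rag_web_demo.py above
    searcher.extract_keyword().match_keyword_to_id().query_graph_for_rag()
    searcher.merge_dedup_rerank(
        graph_ratio=0.6,
        rerank_method="reranker",        # falls back to "bleu" on request errors
        near_neighbor_first=True,
        custom_related_information="movie domain",  # appended to the query
    ).synthesize_answer(graph_only_answer=True)
    context = searcher.run(verbose=True, query="example query",
                           vector_search=False, graph_search=True)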
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
index 5f18d3c..fe225c2 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py
@@ -24,9 +24,7 @@
class GraphRAGQuery:
- VERTEX_GREMLIN_QUERY_TEMPL = (
- "g.V().hasId({keywords}).as('subj').toList()"
- )
+ VERTEX_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').toList()"
# ID_RAG_GREMLIN_QUERY_TEMPL = "g.V().hasId({keywords}).as('subj').repeat(bothE({edge_labels}).as('rel').otherV(
# ).as('obj')).times({max_deep}).path().by(project('label', 'id', 'props').by(label()).by(id()).by(valueMap().by(
# unfold()))).by(project('label', 'inV', 'outV', 'props').by(label()).by(inV().id()).by(outV().id()).by(valueMap(
@@ -75,10 +73,10 @@
"""
def __init__(
- self,
- max_deep: int = 2,
- max_items: int = 30,
- prop_to_match: Optional[str] = None,
+ self,
+ max_deep: int = 2,
+ max_items: int = 30,
+ prop_to_match: Optional[str] = None,
):
self._client = PyHugeClient(
settings.graph_ip,
@@ -133,14 +131,16 @@
edge_labels=edge_labels_str,
)
result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
- knowledge: Set[str] = self._format_knowledge_from_query_result(query_result=result)
+ graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+ query_result=result
+ )
else:
assert entrance_vids is not None, "No entrance vertices for query."
rag_gremlin_query = self.VERTEX_GREMLIN_QUERY_TEMPL.format(
keywords=entrance_vids,
)
result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
- knowledge: Set[str] = self._format_knowledge_from_vertex(query_result=result)
+ vertex_knowledge = self._format_knowledge_from_vertex(query_result=result)
rag_gremlin_query = self.ID_RAG_GREMLIN_QUERY_TEMPL.format(
keywords=entrance_vids,
max_deep=self._max_deep,
@@ -148,21 +148,27 @@
edge_labels=edge_labels_str,
)
result: List[Any] = self._client.gremlin().exec(gremlin=rag_gremlin_query)["data"]
- knowledge.update(self._format_knowledge_from_query_result(query_result=result))
+ graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_knowledge_from_query_result(
+ query_result=result
+ )
+ graph_chain_knowledge.update(vertex_knowledge)
+ vertex_degree_list[0].update(vertex_knowledge)
- context["graph_result"] = list(knowledge)
- context["synthesize_context_head"] = (
+ context["graph_result"] = list(graph_chain_knowledge)
+ context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list]
+ context["knowledge_with_degree"] = knowledge_with_degree
+ context["graph_context_head"] = (
f"The following are knowledge sequence in max depth {self._max_deep} "
f"in the form of directed graph like:\n"
- "`subject -[predicate]-> object <-[predicate_next_hop]- object_next_hop ...` "
- "extracted based on key entities as subject:"
+ "`subject -[predicate]-> object <-[predicate_next_hop]- object_next_hop ...`"
+ "extracted based on key entities as subject:\n"
)
# TODO: replace print to log
verbose = context.get("verbose") or False
if verbose:
print("\033[93mKnowledge from Graph:")
- print("\n".join(rel for rel in context["graph_result"]) + "\033[0m")
+ print("\n".join(chain for chain in context["graph_result"]) + "\033[0m")
return context
@@ -174,20 +180,24 @@
knowledge.add(node_str)
return knowledge
- def _format_knowledge_from_query_result(self, query_result: List[Any]) -> Set[str]:
+ def _format_knowledge_from_query_result(
+ self, query_result: List[Any]
+ ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
use_id_to_match = self._prop_to_match is None
knowledge = set()
+ knowledge_with_degree = {}
+ vertex_degree_list: List[Set[str]] = []
for line in query_result:
flat_rel = ""
raw_flat_rel = line["objects"]
assert len(raw_flat_rel) % 2 == 1
node_cache = set()
prior_edge_str_len = 0
+ depth = 0
+ nodes_with_degree = []
for i, item in enumerate(raw_flat_rel):
if i % 2 == 0:
- matched_str = (
- item["id"] if use_id_to_match else item["props"][self._prop_to_match]
- )
+ matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match]
if matched_str in node_cache:
flat_rel = flat_rel[:-prior_edge_str_len]
break
@@ -195,8 +205,14 @@
props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items())
node_str = f"{item['id']}{{{props_str}}}"
flat_rel += node_str
+ nodes_with_degree.append(node_str)
if flat_rel in knowledge:
knowledge.remove(flat_rel)
+ knowledge_with_degree.pop(flat_rel)
+ if depth >= len(vertex_degree_list):
+ vertex_degree_list.append(set())
+ vertex_degree_list[depth].add(node_str)
+ depth += 1
else:
props_str = ", ".join(f"{k}: {v}" for k, v in item["props"].items())
props_str = f"{{{props_str}}}" if len(props_str) > 0 else ""
@@ -212,22 +228,23 @@
flat_rel += edge_str
prior_edge_str_len = len(edge_str)
knowledge.add(flat_rel)
- return knowledge
+ knowledge_with_degree[flat_rel] = nodes_with_degree
+ return knowledge, vertex_degree_list, knowledge_with_degree
def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
schema = self._get_graph_schema()
node_props_str, edge_props_str = schema.split("\n")[:2]
- node_props_str = node_props_str[len("Node properties: "):].strip("[").strip("]")
- edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]")
+ node_props_str = node_props_str[len("Node properties: ") :].strip("[").strip("]")
+ edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]")
node_labels = self._extract_label_names(node_props_str)
edge_labels = self._extract_label_names(edge_props_str)
return node_labels, edge_labels
@staticmethod
def _extract_label_names(
- source: str,
- head: str = "name: ",
- tail: str = ", ",
+ source: str,
+ head: str = "name: ",
+ tail: str = ", ",
) -> List[str]:
result = []
for s in source.split(head):
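
For reference, the triple now returned by _format_knowledge_from_query_result carries everything the near-neighbor rerank needs; hypothetical values illustrating the three shapes:

    # Set[str]: flat, human-readable knowledge chains
    knowledge = {"A{name: a} -[knows]-> B{name: b}"}
    # List[Set[str]]: node strings grouped by depth (index 0 = entry vertices)
    vertex_degree_list = [{"A{name: a}"}, {"B{name: b}"}]
    # Dict[str, List[str]]: each chain mapped to its node string per depth
    knowledge_with_degree = {
        "A{name: a} -[knows]-> B{name: b}": ["A{name: a}", "B{name: b}"],
    }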
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
index 8019143..2d05160 100644
--- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
+++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py
@@ -31,10 +31,6 @@
---------------------
{{context_str}}
---------------------
-Please refer to the context based on the following priority:
-1. Graph data > vector data
-2. Precise data > fuzzy data
-3. One-depth neighbors > two-depth neighbors
Given the context information and without using fictive knowledge,
answer the following query in a concise and professional manner.
@@ -95,19 +91,20 @@
vector_result = context.get("vector_result", [])
if len(vector_result) == 0:
- vector_result_context = "There are no paragraphs related to the query."
+ vector_result_context = "No (vector)phrase related to the query."
else:
- vector_result_context = ("The following are paragraphs related to the query:\n"
- + "\n".join([f"{i + 1}. {res}"
- for i, res in enumerate(vector_result)]))
+ vector_result_context = "Phrases related to the query:\n" + "\n".join(
+ f"{i + 1}. {res}" for i, res in enumerate(vector_result)
+ )
graph_result = context.get("graph_result", [])
if len(graph_result) == 0:
- graph_result_context = "There are no knowledge from HugeGraph related to the query."
+ graph_result_context = "No knowledge found in HugeGraph for the query."
else:
- graph_result_context = (
- "The following are knowledge from HugeGraph related to the query:\n"
- + "\n".join([f"{i + 1}. {res}"
- for i, res in enumerate(graph_result)]))
+ graph_context_head = context.get("graph_context_head",
+ "The following are knowledge from HugeGraph related to the query:\n")
+ graph_result_context = graph_context_head + "\n".join(
+ f"{i + 1}. {res}" for i, res in enumerate(graph_result)
+ )
context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
vector_result_context, graph_result_context))
@@ -117,6 +114,7 @@
context_tail_str: str, vector_result_context: str,
graph_result_context: str):
verbose = context.get("verbose") or False
+ # TODO: replace task_cache with a better name
task_cache = {}
if self._raw_answer:
prompt = self._question
@@ -126,31 +124,24 @@
f"{vector_result_context}\n"
f"{context_tail_str}".strip("\n"))
- prompt = self._prompt_template.format(
- context_str=context_str,
- query_str=self._question,
- )
+ prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
task_cache["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt))
if self._graph_only_answer:
context_str = (f"{context_head_str}\n"
f"{graph_result_context}\n"
f"{context_tail_str}".strip("\n"))
- prompt = self._prompt_template.format(
- context_str=context_str,
- query_str=self._question,
- )
+ prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
task_cache["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=prompt))
if self._graph_vector_answer:
context_body_str = f"{vector_result_context}\n{graph_result_context}"
+ if context.get("graph_ratio", 0.5) < 0.5:
+ context_body_str = f"{graph_result_context}\n{vector_result_context}"
context_str = (f"{context_head_str}\n"
f"{context_body_str}\n"
f"{context_tail_str}".strip("\n"))
- prompt = self._prompt_template.format(
- context_str=context_str,
- query_str=self._question,
- )
+ prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
task_cache["graph_vector_task"] = asyncio.create_task(
self._llm.agenerate(prompt=prompt)
)