blob: 3029b62595988808b6039477a1997095bf400be1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Literal, cast
from PyCGraph import GPipeline, GRegion, GCondition
from hugegraph_llm.flows.common import BaseFlow
from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode
from hugegraph_llm.nodes.index_node.semantic_id_query_node import SemanticIdQueryNode
from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode
from hugegraph_llm.nodes.hugegraph_node.graph_query_node import GraphQueryNode
from hugegraph_llm.nodes.common_node.merge_rerank_node import MergeRerankNode
from hugegraph_llm.nodes.llm_node.answer_synthesize_node import AnswerSynthesizeNode
from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState
from hugegraph_llm.config import huge_settings, prompt
from hugegraph_llm.utils.log import log
class GraphRecallCondition(GCondition):
    """Branch on whether the request is a graph-RAG recall.

    ``choose`` returns the index of the child element to run: ``0`` when
    ``wkflow_input.is_graph_rag_recall`` is truthy, otherwise ``1``.
    """

    def choose(self):
        wk_input = cast(WkFlowInput, self.getGParamWithNoEmpty("wkflow_input"))
        if wk_input.is_graph_rag_recall:
            return 0
        return 1
class VectorOnlyCondition(GCondition):
    """Branch on whether the request runs in vector-only mode.

    ``choose`` returns the index of the child element to run: ``0`` when
    ``wkflow_input.is_vector_only`` is truthy, otherwise ``1``.
    """

    def choose(self):
        wk_input = cast(WkFlowInput, self.getGParamWithNoEmpty("wkflow_input"))
        if wk_input.is_vector_only:
            return 0
        return 1
# pylint: disable=arguments-differ,keyword-arg-before-vararg
class RAGGraphOnlyFlow(BaseFlow):
    """
    Workflow for graph-only answering (graph_only_answer).

    Pipeline stages, as wired in ``build_flow``:
      1. keyword extraction followed by a semantic id query ("vector_fetch");
      2. schema loading, replaced by an empty region in vector-only mode;
      3. graph query + merge/rerank, replaced by an empty region in
         vector-only mode; depends on stages 1 and 2;
      4. answer synthesis, replaced by an empty region in graph-RAG recall
         mode; depends on stage 3.
    """

    def prepare(
        self,
        prepared_input: WkFlowInput,
        query: str,
        vector_search: bool = False,
        graph_search: bool = True,
        raw_answer: bool = False,
        vector_only_answer: bool = False,
        graph_only_answer: bool = True,
        graph_vector_answer: bool = False,
        rerank_method: Literal["bleu", "reranker"] = "bleu",
        near_neighbor_first: bool = False,
        custom_related_information: str = "",
        answer_prompt: Optional[str] = None,
        keywords_extract_prompt: Optional[str] = None,
        gremlin_tmpl_num: Optional[int] = -1,
        gremlin_prompt: Optional[str] = None,
        max_graph_items: Optional[int] = None,
        topk_return_results: Optional[int] = None,
        vector_dis_threshold: Optional[float] = None,
        topk_per_keyword: Optional[int] = None,
        is_graph_rag_recall: bool = False,
        is_vector_only: bool = False,
        **kwargs,
    ):
        """Populate ``prepared_input`` with the parameters of one request.

        Prompt and numeric overrides that are falsy (``None``, ``0``, ``""``)
        fall back to the defaults in ``prompt`` / ``huge_settings``. Extra
        keyword arguments are accepted and ignored.
        """
        # Resolve every "use the configured default when unset" value once.
        # NOTE: truthiness-based, so 0 / 0.0 / "" also select the default.
        resolved_gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
        resolved_max_items = max_graph_items or huge_settings.max_graph_items
        resolved_topk_keyword = topk_per_keyword or huge_settings.topk_per_keyword
        resolved_topk_return = (
            topk_return_results or huge_settings.topk_return_results
        )
        resolved_keywords_prompt = (
            keywords_extract_prompt or prompt.keywords_extract_prompt
        )
        resolved_answer_prompt = answer_prompt or prompt.answer_prompt
        resolved_dis_threshold = (
            vector_dis_threshold or huge_settings.vector_dis_threshold
        )

        # Query text and enabled retrieval / answer modes.
        prepared_input.query = query
        prepared_input.vector_search = vector_search
        prepared_input.graph_search = graph_search
        prepared_input.raw_answer = raw_answer
        prepared_input.vector_only_answer = vector_only_answer
        prepared_input.graph_only_answer = graph_only_answer
        prepared_input.graph_vector_answer = graph_vector_answer

        # Gremlin generation and graph retrieval limits.
        prepared_input.gremlin_tmpl_num = gremlin_tmpl_num
        prepared_input.gremlin_prompt = resolved_gremlin_prompt
        prepared_input.max_graph_items = resolved_max_items
        prepared_input.topk_per_keyword = resolved_topk_keyword
        prepared_input.topk_return_results = resolved_topk_return

        # Reranking, prompts and extra context.
        prepared_input.rerank_method = rerank_method
        prepared_input.near_neighbor_first = near_neighbor_first
        prepared_input.keywords_extract_prompt = resolved_keywords_prompt
        prepared_input.answer_prompt = resolved_answer_prompt
        prepared_input.custom_related_information = custom_related_information
        prepared_input.vector_dis_threshold = resolved_dis_threshold

        # The schema slot is filled with the configured graph name.
        prepared_input.schema = huge_settings.graph_name
        prepared_input.is_graph_rag_recall = is_graph_rag_recall
        prepared_input.is_vector_only = is_vector_only
        prepared_input.data_json = {
            "query": query,
            "vector_search": vector_search,
            "graph_search": graph_search,
            "max_graph_items": resolved_max_items,
            "is_graph_rag_recall": is_graph_rag_recall,
        }

    def build_flow(self, **kwargs):
        """Build the ``GPipeline`` for graph-only answering.

        ``kwargs`` are forwarded to ``prepare``. The populated ``WkFlowInput``
        and a fresh ``WkFlowState`` are registered as the ``"wkflow_input"``
        and ``"wkflow_state"`` params.
        """
        flow_pipeline = GPipeline()
        wk_input = WkFlowInput()
        self.prepare(wk_input, **kwargs)
        flow_pipeline.createGParam(wk_input, "wkflow_input")
        flow_pipeline.createGParam(WkFlowState(), "wkflow_state")

        # Stage 1: keyword extraction, then semantic id query on its output.
        keyword_node = KeywordExtractNode("only_keyword")
        semantic_node = SemanticIdQueryNode({keyword_node}, "only_semantic")
        keyword_region: GRegion = GRegion([keyword_node, semantic_node])

        # Stage 2: schema node; child 0 (empty region) runs in vector-only mode.
        schema_loader = SchemaNode()
        schema_gate = VectorOnlyCondition([GRegion(), schema_loader])

        # Stage 3: graph query feeding merge/rerank; also gated by vector-only.
        graph_query = GraphQueryNode("only_graph")
        reranker = MergeRerankNode({graph_query}, "merge_rerank")
        graph_region: GRegion = GRegion([graph_query, reranker])
        graph_gate = VectorOnlyCondition([GRegion(), graph_region])

        # Stage 4: answer synthesis; child 0 (empty region) runs in recall mode.
        synthesizer = AnswerSynthesizeNode()
        answer_gate = GraphRecallCondition([GRegion(), synthesizer])

        # Register elements with their dependencies (second argument).
        flow_pipeline.registerGElement(keyword_region, set(), "vector_fetch")
        flow_pipeline.registerGElement(schema_gate, set(), "schema_condition")
        flow_pipeline.registerGElement(
            graph_gate, {schema_gate, keyword_region}, "graph_condition"
        )
        flow_pipeline.registerGElement(
            answer_gate, {graph_gate}, "answer_condition"
        )
        log.info("RAGGraphOnlyFlow pipeline built successfully")
        return flow_pipeline

    def post_deal(self, pipeline=None, **kwargs):
        """Extract the result from the pipeline's ``"wkflow_state"`` param.

        Returns ``{"error": "No pipeline provided"}`` when ``pipeline`` is None.
        Otherwise, if the state JSON has a truthy ``is_graph_rag_recall`` the
        whole state JSON is returned; else only the four answer fields, each
        defaulting to ``""``.
        """
        if pipeline is None:
            return {"error": "No pipeline provided"}
        state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
        log.info("RAGGraphOnlyFlow post processing success")
        if state_json.get("is_graph_rag_recall", False):
            return state_json
        answer_keys = (
            "raw_answer",
            "vector_only_answer",
            "graph_only_answer",
            "graph_vector_answer",
        )
        return {key: state_json.get(key, "") for key in answer_keys}