Merge branch 'main' into dev
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index 03ac9ae..836af52 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py
@@ -17,7 +17,7 @@ from typing import Dict, Any, Optional, List, Literal - +from hugegraph_llm.config import huge_settings from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.models.llms.base import BaseLLM @@ -26,6 +26,7 @@ from hugegraph_llm.operators.common_op.print_result import PrintResult from hugegraph_llm.operators.document_op.word_extract import WordExtract from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery +from hugegraph_llm.operators.hugegraph_op.graph_rag_query_acg import GraphRAGACGQuery from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery @@ -96,10 +97,10 @@ return self def keywords_to_vid( - self, - by: Literal["query", "keywords"] = "keywords", - topk_per_keyword: int = 1, - topk_per_query: int = 10, + self, + by: Literal["query", "keywords"] = "keywords", + topk_per_keyword: int = 1, + topk_per_query: int = 10, ): """ Add a semantic ID query operator to the pipeline. @@ -139,18 +140,29 @@ :param prop_to_match: Property to match in the graph. :return: Self-instance for chaining. 
""" - self._operators.append( - GraphRAGQuery( - max_deep=max_deep, - max_items=max_items, - max_v_prop_len=max_v_prop_len, - max_e_prop_len=max_e_prop_len, - prop_to_match=prop_to_match, - with_gremlin_template=with_gremlin_template, - num_gremlin_generate_example=num_gremlin_generate_example, - gremlin_prompt=gremlin_prompt, + if huge_settings.graph_space == "acgraggs": + self._operators.append( + GraphRAGACGQuery( + max_deep=max_deep, + max_v_prop_len=max_v_prop_len, + max_e_prop_len=max_e_prop_len, + prop_to_match=prop_to_match, + with_gremlin_template=with_gremlin_template + ) ) - ) + else: + self._operators.append( + GraphRAGQuery( + max_deep=max_deep, + max_items=max_items, + max_v_prop_len=max_v_prop_len, + max_e_prop_len=max_e_prop_len, + prop_to_match=prop_to_match, + with_gremlin_template=with_gremlin_template, + num_gremlin_generate_example=num_gremlin_generate_example, + gremlin_prompt=gremlin_prompt, + ) + ) return self def query_vector_index(self, max_items: int = 3): @@ -169,11 +181,11 @@ return self def merge_dedup_rerank( - self, - graph_ratio: float = 0.5, - rerank_method: Literal["bleu", "reranker"] = "bleu", - near_neighbor_first: bool = False, - custom_related_information: str = "", + self, + graph_ratio: float = 0.5, + rerank_method: Literal["bleu", "reranker"] = "bleu", + near_neighbor_first: bool = False, + custom_related_information: str = "", ): """ Add a merge, deduplication, and rerank operator to the pipeline.
diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query_acg.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query_acg.py new file mode 100644 index 0000000..bd42895 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query_acg.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import json
from typing import Any, Dict, Optional, List, Set, Tuple

import requests

from hugegraph_llm.config import huge_settings
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
from hugegraph_llm.utils.log import log
from pyhugegraph.client import PyHugeClient

# TODO: remove 'as('subj)' step
# Fetches the vertices whose ids are given in `keywords`.
VERTEX_QUERY_TPL = "g.V({keywords}).toList()"

# TODO: we could use a simpler query (like kneighbor-api to get the edges)
# TODO: test with profile()/explain() to speed up the query
# One-hop neighbourhood of `keywords` over `edge_labels`; every path element is
# projected to a dict so that vertices expose 'label'/'id'/'props' and edges
# expose 'label'/'inV'/'outV'/'props'.
VID_QUERY_MODULE_STATION_TPL = """\
g.V({keywords}).bothE({edge_labels}).otherV().dedup()
.simplePath()
.path()
.by(project('label', 'id', 'props')
    .by(label())
    .by(id())
    .by(valueMap().by(unfold()))
).by(project('label', 'inV', 'outV', 'props')
    .by(label())
    .by(inV().id())
    .by(outV().id())
    .by(valueMap().by(unfold()))
)
.toList()
"""


def get_paths_vertex_id(sources, targets, depth=2, capacity=100, limit=100, timeout=30):
    """
    Collect the ids of all vertices lying on paths between two vertex sets.

    Posts to the ``traversers/paths`` REST endpoint of the graph configured in
    ``huge_settings`` (direction ``BOTH``, ``with_vertex`` enabled) and returns the
    ``id`` of every entry in the ``vertices`` field of the response.

    The request authenticates with ``huge_settings.graph_user`` /
    ``huge_settings.graph_pwd`` (the same credentials the Gremlin client of this
    module uses) instead of a credential embedded in the source code.

    :param sources: List of source vertex ids.
    :param targets: List of target vertex ids.
    :param depth: Value sent as ``max_depth``.
    :param capacity: Value sent as ``capacity``.
    :param limit: Value sent as ``limit``.
    :param timeout: Seconds to wait for the HTTP request before giving up.
    :return: List of vertex ids; empty when either id list is empty, the request
        fails, or the server does not answer with HTTP 200.
    """
    log.debug("Get_Paths: %s, %s", sources, targets)
    # Nothing to connect: skip the round trip entirely.
    if not sources or not targets:
        return []

    # Request URL and headers
    url = (
        f"http://{huge_settings.graph_ip}:{huge_settings.graph_port}"
        f"/graphspaces/{huge_settings.graph_space}"
        f"/graphs/{huge_settings.graph_name}/traversers/paths"
    )
    headers = {
        "Content-Type": "application/json",
        "Connection": "close",
    }
    payload = {
        "sources": {"ids": sources},
        "targets": {"ids": targets},
        "step": {"direction": "BOTH"},
        "max_depth": depth,
        "capacity": capacity,
        "limit": limit,
        "with_vertex": True,
    }
    body = json.dumps(payload, ensure_ascii=False)
    log.debug("json: %s", body)

    try:
        response = requests.post(
            url,
            headers=headers,
            # Encode explicitly: a `str` body holding characters outside Latin-1
            # (e.g. CJK vertex ids) cannot be sent as-is.
            data=body.encode("utf-8"),
            auth=(huge_settings.graph_user, huge_settings.graph_pwd),
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Treat transport failures like a non-200 answer: log and return nothing.
        log.error("Get_Paths Error: %s", e)
        return []

    if response.status_code != 200:
        log.error("Get_Paths Error: %s", response.status_code)
        return []

    result = response.json()
    log.debug("Get_Paths Response: %s", result)
    return [vertex_info["id"] for vertex_info in result.get("vertices", [])]


class GraphRAGACGQuery:
    """
    Pipeline operator that retrieves graph knowledge for a RAG query.

    ``run`` first executes an LLM-generated Gremlin query; when that yields no
    result it falls back to a subgraph search driven by ``context["match_vids"]``.
    """

    def __init__(
        self,
        max_deep: int = 2,
        prop_to_match: Optional[str] = None,
        llm: Optional[BaseLLM] = None,
        embedding: Optional[BaseEmbedding] = None,
        max_v_prop_len: int = 2048,
        max_e_prop_len: int = 256,
        with_gremlin_template: bool = True,
        num_gremlin_generate_example: int = 1
    ):
        """
        :param max_deep: Depth reported in the context head of subgraph results.
        :param prop_to_match: Vertex property used to identify vertices in paths;
            the vertex id is used when None.
        :param llm: LLM handed to the Gremlin generator.
        :param embedding: Embedding model handed to the Gremlin generator.
        :param max_v_prop_len: Max length kept for string vertex properties.
        :param max_e_prop_len: Max length kept for string edge properties.
        :param with_gremlin_template: Use the generator's "result" entry when True,
            its "raw_result" entry otherwise.
        :param num_gremlin_generate_example: Number of examples requested from the
            Gremlin example index.
        """
        self._client = PyHugeClient(
            huge_settings.graph_ip,
            huge_settings.graph_port,
            huge_settings.graph_name,
            huge_settings.graph_user,
            huge_settings.graph_pwd,
            huge_settings.graph_space,
        )
        self._max_deep = max_deep
        self._prop_to_match = prop_to_match
        self._schema = ""
        self._limit_property = huge_settings.limit_property.lower() == "true"
        self._max_v_prop_len = max_v_prop_len
        self._max_e_prop_len = max_e_prop_len
        self._gremlin_generator = GremlinGenerator(
            llm=llm,
            embedding=embedding,
        )
        self._num_gremlin_generate_example = num_gremlin_generate_example
        self._with_gremlin_template = with_gremlin_template

    def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Fill ``context["graph_result"]`` with knowledge queried from the graph.

        :param context: Pipeline context; reads "query", "simple_schema" and
            optionally "match_vids" / "query_embedding".
        :return: The same context, with "graph_result" and "graph_result_flag" set.
        """
        self._init_client(context)

        # initial flag: -1 means no result, 0 means subgraph query, 1 means gremlin query
        context["graph_result_flag"] = -1
        # 1. Try to perform a query based on the generated gremlin
        context = self._gremlin_generate_query(context)
        # 2. Try to perform a query based on subgraph-search if the previous query failed
        if not context.get("graph_result"):
            context = self._subgraph_query(context)

        if context.get("graph_result"):
            log.debug("Knowledge from Graph:\n%s", "\n".join(context["graph_result"]))
        else:
            log.debug("No Knowledge Extracted from Graph")
        return context

    def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate a Gremlin statement for the user query and execute it.

        On success the JSON-serialized rows are stored in ``context["graph_result"]``;
        any execution error is logged and leaves an empty result list so the caller
        can fall back to the subgraph search.
        """
        query = context["query"]
        vertices = context.get("match_vids")
        query_embedding = context.get("query_embedding")

        self._gremlin_generator.clear()
        self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
        gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
            context["simple_schema"],
            vertices=vertices,
        ).run(
            query=query,
            query_embedding=query_embedding
        )
        if self._with_gremlin_template:
            gremlin = gremlin_response["result"]
        else:
            gremlin = gremlin_response["raw_result"]
        log.info("Generated gremlin: %s", gremlin)
        context["gremlin"] = gremlin
        try:
            result = self._client.gremlin().exec(gremlin=gremlin)["data"]
            if result == [None]:
                result = []
            context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result]
            if context["graph_result"]:
                context["graph_result_flag"] = 1
                context["graph_context_head"] = (
                    f"The following are graph query result "
                    f"from gremlin query `{gremlin}`.\n"
                )
        except Exception as e:  # pylint: disable=broad-except
            log.error(e)
            # Keep the documented list type; an empty list is still falsy for run().
            context["graph_result"] = []
        return context

    def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build graph knowledge from the vertices in ``context["match_vids"]``.

        Combines three sources: the matched vertices themselves, the one-hop
        neighbourhood of module vertices, and the vertices on paths between module
        vertices and knowledge/concept vertices. Returns the context untouched when
        there are no matched vertex ids.
        """
        # 1. Extract params from context
        matched_vids = context.get("match_vids")
        if not matched_vids:
            # Bail out before touching the schema or the graph.
            return context

        # 2. Extract edge_labels from graph schema
        _, edge_labels = self._extract_labels_from_schema()
        edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)

        gremlin_query = VERTEX_QUERY_TPL.format(keywords=matched_vids)
        log.debug("Vids gremlin query: %s", gremlin_query)
        vertexes = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
        vertex_knowledge = self._format_graph_from_vertex(query_result=vertexes)

        paths: List[Any] = []
        module_set: Set[str] = set()
        concept_knowledge_set: Set[str] = set()
        # TODO: use generator or asyncio to speed up the query logic
        # Classify the matched vertices by the first character of their id:
        # "1" is a module vertex, "3"/"4" are knowledge/concept vertices; every
        # other type (e.g. station vertices) is ignored here.
        for matched_vid in matched_vids:
            vid_prefix = str(matched_vid)[:1]
            if vid_prefix == "1":
                # For each module vertex, fetch the related vertices and edges via gremlin.
                gremlin_query = VID_QUERY_MODULE_STATION_TPL.format(
                    keywords="'{}'".format(matched_vid),
                    edge_labels=edge_labels_str,
                )
                log.debug("Kneighbor gremlin query: %s", gremlin_query)
                paths.extend(self._client.gremlin().exec(gremlin=gremlin_query)["data"])
                module_set.add(matched_vid)
            elif vid_prefix in ("3", "4"):
                concept_knowledge_set.add(matched_vid)

        graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_query_result(
            query_paths=paths
        )
        # Paths API: sources are the module vertices, targets the knowledge/concept
        # vertices; collect every vertex on every path (default depth 2), deduplicated.
        paths_vertex_ids = get_paths_vertex_id(sources=list(module_set), targets=list(concept_knowledge_set))
        if paths_vertex_ids:
            # Only query when there is something to fetch; an empty id list would
            # render as `g.V([])`, which is at best wasted work.
            paths_gremlin_query = VERTEX_QUERY_TPL.format(keywords=paths_vertex_ids)
            paths_vertexes = self._client.gremlin().exec(gremlin=paths_gremlin_query)["data"]
            graph_chain_knowledge.update(self._format_graph_from_vertex(query_result=paths_vertexes))

        # The matched vertices always belong to degree 0.
        if vertex_degree_list:
            vertex_degree_list[0].update(vertex_knowledge)
        else:
            vertex_degree_list.append(vertex_knowledge)

        # TODO: we may need to optimize the logic here with global deduplication (may lack some single vertex)
        if not graph_chain_knowledge:
            graph_chain_knowledge.update(vertex_knowledge)

        context["graph_result"] = list(graph_chain_knowledge)
        if context["graph_result"]:
            context["graph_result_flag"] = 0
            context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list]
            context["knowledge_with_degree"] = knowledge_with_degree
            context["graph_context_head"] = (
                f"The following are graph knowledge in {self._max_deep} depth, e.g:\n"
                "`vertexA--[links]-->vertexB<--[links]--vertexC ...`"
                "extracted based on key entities as subject:\n"
            )
        return context

    def _init_client(self, context):
        """
        Make sure a graph client is available.

        The constructor already creates a client, so the fallback below only runs
        if ``self._client`` has been reset to None: it then reuses
        ``context["graph_client"]`` or builds a client from connection fields in
        the context.
        """
        # pylint: disable=R0915 (too-many-statements)
        if self._client is None:
            if isinstance(context.get("graph_client"), PyHugeClient):
                self._client = context["graph_client"]
            else:
                ip = context.get("ip") or "localhost"
                port = context.get("port") or "8080"
                graph = context.get("graph") or "hugegraph"
                user = context.get("user") or "admin"
                pwd = context.get("pwd") or "admin"
                gs = context.get("graphspace") or None
                self._client = PyHugeClient(ip, port, graph, user, pwd, gs)
        assert self._client is not None, "No valid graph to search."

    def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
        """
        Render each vertex of a Gremlin result as ``id{prop: value, ...}``.

        :param query_result: Vertex dicts exposing "id" and "properties".
        :return: Set of rendered vertex strings.
        """
        knowledge = set()
        for item in query_result:
            props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items())
            knowledge.add(f"{item['id']}{{{props_str}}}")
        return knowledge

    def _format_graph_query_result(self, query_paths) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
        """
        Turn path results into textual knowledge.

        :param query_paths: Path dicts as produced by ``VID_QUERY_MODULE_STATION_TPL``.
        :return: Tuple of (set of path strings, per-depth sets of vertex strings,
            mapping from path string to its vertex strings in path order).
        """
        use_id_to_match = self._prop_to_match is None
        subgraph = set()
        subgraph_with_degree = {}
        vertex_degree_list: List[Set[str]] = []
        # Shared across all paths so each vertex/edge is printed with its
        # properties only the first time it appears.
        v_cache: Set[str] = set()
        e_cache: Set[Tuple[str, str, str]] = set()

        for path in query_paths:
            # 1. Process each path
            path_str, vertex_with_degree = self._process_path(path, use_id_to_match, v_cache, e_cache)
            subgraph.add(path_str)
            subgraph_with_degree[path_str] = vertex_with_degree
            # 2. Update vertex degree list
            self._update_vertex_degree_list(vertex_degree_list, vertex_with_degree)

        return subgraph, vertex_degree_list, subgraph_with_degree

    def _process_path(self, path: Any, use_id_to_match: bool, v_cache: Set[str],
                      e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, List[str]]:
        """
        Render one path (alternating vertex, edge, vertex, ...) as a string.

        :return: Tuple of (path string, vertex strings in path order).
        """
        flat_rel = ""
        raw_flat_rel = path["objects"]

        assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd."

        node_cache = set()
        prior_edge_str_len = 0
        depth = 0
        nodes_with_degree = []

        for i, item in enumerate(raw_flat_rel):
            if i % 2 == 0:
                # Process each vertex
                flat_rel, prior_edge_str_len, depth = self._process_vertex(
                    item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match,
                    v_cache
                )
            else:
                # Process each edge
                flat_rel, prior_edge_str_len = self._process_edge(
                    item, flat_rel, raw_flat_rel, i, use_id_to_match, e_cache
                )

        return flat_rel, nodes_with_degree

    def _process_vertex(self, item: Any, flat_rel: str, node_cache: Set[str],
                        prior_edge_str_len: int, depth: int, nodes_with_degree: List[str],
                        use_id_to_match: bool, v_cache: Set[str]) -> Tuple[str, int, int]:
        """
        Append one vertex to the path string being built.

        A vertex already seen in this path is not repeated; the edge text that was
        just appended in front of it is removed instead.

        :return: Tuple of (updated path string, prior edge length, updated depth).
        """
        matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match]
        if matched_str in node_cache:
            # `s[:-0]` is the empty string, so only trim when an edge was really appended.
            if prior_edge_str_len:
                flat_rel = flat_rel[:-prior_edge_str_len]
            return flat_rel, prior_edge_str_len, depth

        node_cache.add(matched_str)
        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}"
                              for k, v in item["props"].items() if v)

        # TODO: we may remove label id or replace with label name
        if matched_str in v_cache:
            # Already printed in full for an earlier path: print the short form only.
            node_str = str(matched_str)
        else:
            v_cache.add(matched_str)
            node_str = f"{item['id']}{{{props_str}}}"

        flat_rel += node_str
        nodes_with_degree.append(node_str)
        depth += 1
        return flat_rel, prior_edge_str_len, depth

    def _process_edge(self, item: Any, path_str: str, raw_flat_rel: List[Any], i: int, use_id_to_match: bool,
                      e_cache: Set[Tuple[str, str, str]]) -> Tuple[str, int]:
        """
        Append one edge to the path string being built.

        The arrow points right when the edge's ``outV`` equals the identifier of the
        previous vertex in the path, left otherwise.

        :return: Tuple of (updated path string, length of the appended edge text).
        """
        props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}"
                              for k, v in item["props"].items() if v)
        props_str = f"{{{props_str}}}" if props_str else ""
        prev_vertex = raw_flat_rel[i - 1]
        prev_matched_str = prev_vertex["id"] if use_id_to_match else prev_vertex["props"][self._prop_to_match]

        edge_key = (item['inV'], item['label'], item['outV'])
        if edge_key not in e_cache:
            e_cache.add(edge_key)
            edge_label = f"{item['label']}{props_str}"
        else:
            # Properties were already printed for this edge; keep the label only.
            edge_label = item['label']

        edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--"
        path_str += edge_str
        prior_edge_str_len = len(edge_str)
        return path_str, prior_edge_str_len

    def _update_vertex_degree_list(self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str]) -> None:
        """Add each vertex string to the set of its depth, growing the list as needed."""
        for depth, node_str in enumerate(nodes_with_degree):
            if depth >= len(vertex_degree_list):
                vertex_degree_list.append(set())
            vertex_degree_list[depth].add(node_str)

    def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
        """
        Parse vertex and edge label names out of the cached schema text.

        :return: Tuple of (vertex label names, edge label names).
        """
        schema = self._get_graph_schema()
        vertex_props_str, edge_props_str = schema.split("\n")[:2]
        # TODO: rename to vertex (also need update in the schema)
        vertex_props_str = vertex_props_str[len("Vertex properties: "):].strip("[").strip("]")
        edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]")
        vertex_labels = self._extract_label_names(vertex_props_str)
        edge_labels = self._extract_label_names(edge_props_str)
        return vertex_labels, edge_labels

    @staticmethod
    def _extract_label_names(source: str, head: str = "name: ", tail: str = ", ") -> List[str]:
        """
        Return every substring of ``source`` found between ``head`` and the next ``tail``.

        :param source: Text to scan.
        :param head: Marker that introduces a label name.
        :param tail: Marker that ends a label name; a name that runs to the end of
            ``source`` without a ``tail`` is returned whole.
        :return: Non-empty label names in order of appearance.
        """
        result = []
        # The piece before the first `head` cannot contain a label name.
        for s in source.split(head)[1:]:
            end = s.find(tail)
            # find() returns -1 when `tail` is absent; slicing with it would
            # silently drop the last character of the label.
            label = s if end == -1 else s[:end]
            if label:
                result.append(label)
        return result

    def _get_graph_schema(self, refresh: bool = False) -> str:
        """
        Return a textual description of the graph schema, cached after the first call.

        :param refresh: Re-read the schema from the server even if it is cached.
        :return: Three lines: vertex labels, edge labels and relationships.
        """
        if self._schema and not refresh:
            return self._schema

        schema = self._client.schema()
        vertex_schema = schema.getVertexLabels()
        edge_schema = schema.getEdgeLabels()
        relationships = schema.getRelations()

        self._schema = (
            f"Vertex properties: {vertex_schema}\n"
            f"Edge properties: {edge_schema}\n"
            f"Relationships: {relationships}\n"
        )
        log.debug("Link(Relation): %s", relationships)
        return self._schema

    def _limit_property_query(self, value: Optional[str], item_type: str) -> Optional[str]:
        """
        Truncate a string property value when property limiting is enabled.

        :param value: Property value; returned unchanged unless it is a string.
        :param item_type: "v" selects the vertex limit, anything else the edge limit.
        :return: The (possibly truncated) value.
        """
        # NOTE: we skip the filter for list/set type (e.g., list of string, add it if needed)
        if not self._limit_property or not isinstance(value, str):
            return value

        max_len = self._max_v_prop_len if item_type == "v" else self._max_e_prop_len
        return value[:max_len] if value else value