blob: 6012b75344912effbb054d3efc5ca90978ad2503 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any, Dict, Optional, List, Set, Tuple
from hugegraph_llm.config import huge_settings, prompt
from hugegraph_llm.models.embeddings.base import BaseEmbedding
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator
from hugegraph_llm.utils.log import log
from pyhugegraph.client import PyHugeClient
# TODO: remove 'as('subj)' step
VERTEX_QUERY_TPL = "g.V({keywords}).limit(8).as('subj').toList()"
# TODO: we could use a simpler query (like kneighbor-api to get the edges)
# TODO: test with profile()/explain() to speed up the query
VID_QUERY_NEIGHBOR_TPL = """\
g.V({keywords})
.repeat(
bothE({edge_labels}).limit({edge_limit}).otherV().dedup()
).times({max_deep}).emit()
.simplePath()
.path()
.by(project('label', 'id', 'props')
.by(label())
.by(id())
.by(valueMap().by(unfold()))
)
.by(project('label', 'inV', 'outV', 'props')
.by(label())
.by(inV().id())
.by(outV().id())
.by(valueMap().by(unfold()))
)
.limit({max_items})
.toList()
"""
PROPERTY_QUERY_NEIGHBOR_TPL = """\
g.V().has('{prop}', within({keywords}))
.repeat(
bothE({edge_labels}).limit({edge_limit}).otherV().dedup()
).times({max_deep}).emit()
.simplePath()
.path()
.by(project('label', 'props')
.by(label())
.by(valueMap().by(unfold()))
)
.by(project('label', 'inV', 'outV', 'props')
.by(label())
.by(inV().values('{prop}'))
.by(outV().values('{prop}'))
.by(valueMap().by(unfold()))
)
.limit({max_items})
.toList()
"""
class GraphRAGQuery:
def __init__(
self,
max_deep: int = 2,
max_graph_items: int = huge_settings.max_graph_items,
prop_to_match: Optional[str] = None,
llm: Optional[BaseLLM] = None,
embedding: Optional[BaseEmbedding] = None,
max_v_prop_len: Optional[int] = 2048,
max_e_prop_len: Optional[int] = 256,
num_gremlin_generate_example: Optional[int] = -1,
gremlin_prompt: Optional[str] = None,
):
self._client = PyHugeClient(
url=huge_settings.graph_url,
graph=huge_settings.graph_name,
user=huge_settings.graph_user,
pwd=huge_settings.graph_pwd,
graphspace=huge_settings.graph_space,
)
self._max_deep = max_deep
self._max_items = max_graph_items
self._prop_to_match = prop_to_match
self._schema = ""
self._limit_property = huge_settings.limit_property.lower() == "true"
self._max_v_prop_len = max_v_prop_len
self._max_e_prop_len = max_e_prop_len
self._gremlin_generator = GremlinGenerator(
llm=llm,
embedding=embedding,
)
self._num_gremlin_generate_example = num_gremlin_generate_example
self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
self.init_client(context)
# initial flag: -1 means no result, 0 means subgraph query, 1 means gremlin query
context["graph_result_flag"] = -1
# 1. Try to perform a query based on the generated gremlin
if self._num_gremlin_generate_example >= 0:
context = self._gremlin_generate_query(context)
# 2. Try to perform a query based on subgraph-search if the previous query failed
if not context.get("graph_result"):
context = self._subgraph_query(context)
if context.get("graph_result"):
log.debug("Knowledge from Graph:\n%s", "\n".join(context["graph_result"]))
else:
log.debug("No Knowledge Extracted from Graph")
return context
def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
query = context["query"]
vertices = context.get("match_vids")
query_embedding = context.get("query_embedding")
self._gremlin_generator.clear()
self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example)
gremlin_response = self._gremlin_generator.gremlin_generate_synthesize(
context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt
).run(query=query, query_embedding=query_embedding)
if self._num_gremlin_generate_example > 0:
gremlin = gremlin_response["result"]
else:
gremlin = gremlin_response["raw_result"]
log.info("Generated gremlin: %s", gremlin)
context["gremlin"] = gremlin
try:
result = self._client.gremlin().exec(gremlin=gremlin)["data"]
if result == [None]:
result = []
context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result]
if context["graph_result"]:
context["graph_result_flag"] = 1
context["graph_context_head"] = (
f"The following are graph query result " f"from gremlin query `{gremlin}`.\n"
)
except Exception as e: # pylint: disable=broad-except
log.error(e)
context["graph_result"] = ""
return context
def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]:
# 1. Extract params from context
matched_vids = context.get("match_vids")
if isinstance(context.get("max_deep"), int):
self._max_deep = context["max_deep"]
if isinstance(context.get("max_items"), int):
self._max_items = context["max_items"]
if isinstance(context.get("prop_to_match"), str):
self._prop_to_match = context["prop_to_match"]
# 2. Extract edge_labels from graph schema
_, edge_labels = self._extract_labels_from_schema()
edge_labels_str = ",".join("'" + label + "'" for label in edge_labels)
# TODO: enhance the limit logic later
edge_limit_amount = len(edge_labels) * huge_settings.edge_limit_pre_label
use_id_to_match = self._prop_to_match is None
if use_id_to_match:
if not matched_vids:
return context
gremlin_query = VERTEX_QUERY_TPL.format(keywords=matched_vids)
vertexes = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
log.debug("Vids gremlin query: %s", gremlin_query)
vertex_knowledge = self._format_graph_from_vertex(query_result=vertexes)
paths: List[Any] = []
# TODO: use generator or asyncio to speed up the query logic
for matched_vid in matched_vids:
gremlin_query = VID_QUERY_NEIGHBOR_TPL.format(
keywords=f"'{matched_vid}'",
max_deep=self._max_deep,
edge_labels=edge_labels_str,
edge_limit=edge_limit_amount,
max_items=self._max_items,
)
log.debug("Kneighbor gremlin query: %s", gremlin_query)
paths.extend(self._client.gremlin().exec(gremlin=gremlin_query)["data"])
graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_query_result(
query_paths=paths
)
# TODO: we may need to optimize the logic here with global deduplication (may lack some single vertex)
if not graph_chain_knowledge:
graph_chain_knowledge.update(vertex_knowledge)
if vertex_degree_list:
vertex_degree_list[0].update(vertex_knowledge)
else:
vertex_degree_list.append(vertex_knowledge)
else:
# WARN: When will the query enter here?
keywords = context.get("keywords")
assert keywords, "No related property(keywords) for graph query."
keywords_str = ",".join("'" + kw + "'" for kw in keywords)
gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format(
prop=self._prop_to_match,
keywords=keywords_str,
edge_labels=edge_labels_str,
edge_limit=edge_limit_amount,
max_deep=self._max_deep,
max_items=self._max_items,
)
log.warning("Unable to find vid, downgraded to property query, please confirm if it meets expectation.")
paths: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_query_result(
query_paths=paths
)
context["graph_result"] = list(graph_chain_knowledge)
if context["graph_result"]:
context["graph_result_flag"] = 0
context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list]
context["knowledge_with_degree"] = knowledge_with_degree
context["graph_context_head"] = (
f"The following are graph knowledge in {self._max_deep} depth, e.g:\n"
"`vertexA--[links]-->vertexB<--[links]--vertexC ...`"
"extracted based on key entities as subject:\n"
)
return context
# TODO: move this method to a util file for reuse (remove self param)
def init_client(self, context):
"""Initialize the HugeGraph client from context or default settings."""
# pylint: disable=R0915 (too-many-statements)
if self._client is None:
if isinstance(context.get("graph_client"), PyHugeClient):
self._client = context["graph_client"]
else:
url = context.get("url") or "http://localhost:8080"
graph = context.get("graph") or "hugegraph"
user = context.get("user") or "admin"
pwd = context.get("pwd") or "admin"
gs = context.get("graphspace") or None
self._client = PyHugeClient(url, graph, user, pwd, gs)
assert self._client is not None, "No valid graph to search."
def get_vertex_details(self, vertex_ids: List[str]) -> List[Dict[str, Any]]:
if not vertex_ids:
return []
formatted_ids = ", ".join(f"'{vid}'" for vid in vertex_ids)
gremlin_query = f"g.V({formatted_ids}).limit(20)"
result = self._client.gremlin().exec(gremlin=gremlin_query)["data"]
return result
def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]:
knowledge = set()
for item in query_result:
props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items())
node_str = f"{item['id']}{{{props_str}}}"
knowledge.add(node_str)
return knowledge
def _format_graph_query_result(self, query_paths) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]:
use_id_to_match = self._prop_to_match is None
subgraph = set()
subgraph_with_degree = {}
vertex_degree_list: List[Set[str]] = []
v_cache: Set[str] = set()
e_cache: Set[Tuple[str, str, str]] = set()
for path in query_paths:
# 1. Process each path
path_str, vertex_with_degree = self._process_path(path, use_id_to_match, v_cache, e_cache)
subgraph.add(path_str)
subgraph_with_degree[path_str] = vertex_with_degree
# 2. Update vertex degree list
self._update_vertex_degree_list(vertex_degree_list, vertex_with_degree)
return subgraph, vertex_degree_list, subgraph_with_degree
def _process_path(
self, path: Any, use_id_to_match: bool, v_cache: Set[str], e_cache: Set[Tuple[str, str, str]]
) -> Tuple[str, List[str]]:
flat_rel = ""
raw_flat_rel = path["objects"]
assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd."
node_cache = set()
prior_edge_str_len = 0
depth = 0
nodes_with_degree = []
for i, item in enumerate(raw_flat_rel):
if i % 2 == 0:
# Process each vertex
flat_rel, prior_edge_str_len, depth = self._process_vertex(
item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match, v_cache
)
else:
# Process each edge
flat_rel, prior_edge_str_len = self._process_edge(
item, flat_rel, raw_flat_rel, i, use_id_to_match, e_cache
)
return flat_rel, nodes_with_degree
def _process_vertex(
self,
item: Any,
flat_rel: str,
node_cache: Set[str],
prior_edge_str_len: int,
depth: int,
nodes_with_degree: List[str],
use_id_to_match: bool,
v_cache: Set[str],
) -> Tuple[str, int, int]:
matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match]
if matched_str in node_cache:
flat_rel = flat_rel[:-prior_edge_str_len]
return flat_rel, prior_edge_str_len, depth
node_cache.add(matched_str)
props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}" for k, v in item["props"].items() if v)
# TODO: we may remove label id or replace with label name
if matched_str in v_cache:
node_str = matched_str
else:
v_cache.add(matched_str)
node_str = f"{item['id']}{{{props_str}}}"
flat_rel += node_str
nodes_with_degree.append(node_str)
depth += 1
return flat_rel, prior_edge_str_len, depth
def _process_edge(
self,
item: Any,
path_str: str,
raw_flat_rel: List[Any],
i: int,
use_id_to_match: bool,
e_cache: Set[Tuple[str, str, str]],
) -> Tuple[str, int]:
props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}" for k, v in item["props"].items() if v)
props_str = f"{{{props_str}}}" if props_str else ""
prev_matched_str = (
raw_flat_rel[i - 1]["id"] if use_id_to_match else (raw_flat_rel)[i - 1]["props"][self._prop_to_match]
)
edge_key = (item["inV"], item["label"], item["outV"])
if edge_key not in e_cache:
e_cache.add(edge_key)
edge_label = f"{item['label']}{props_str}"
else:
edge_label = item["label"]
edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--"
path_str += edge_str
prior_edge_str_len = len(edge_str)
return path_str, prior_edge_str_len
def _update_vertex_degree_list(self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str]) -> None:
for depth, node_str in enumerate(nodes_with_degree):
if depth >= len(vertex_degree_list):
vertex_degree_list.append(set())
vertex_degree_list[depth].add(node_str)
def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]:
schema = self._get_graph_schema()
vertex_props_str, edge_props_str = schema.split("\n")[:2]
# TODO: rename to vertex (also need update in the schema)
vertex_props_str = vertex_props_str[len("Vertex properties: "):].strip("[").strip("]")
edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]")
vertex_labels = self._extract_label_names(vertex_props_str)
edge_labels = self._extract_label_names(edge_props_str)
return vertex_labels, edge_labels
@staticmethod
def _extract_label_names(source: str, head: str = "name: ", tail: str = ", ") -> List[str]:
result = []
for s in source.split(head):
end = s.find(tail)
label = s[:end]
if label:
result.append(label)
return result
def _get_graph_schema(self, refresh: bool = False) -> str:
if self._schema and not refresh:
return self._schema
schema = self._client.schema()
vertex_schema = schema.getVertexLabels()
edge_schema = schema.getEdgeLabels()
relationships = schema.getRelations()
self._schema = (
f"Vertex properties: {vertex_schema}\n"
f"Edge properties: {edge_schema}\n"
f"Relationships: {relationships}\n"
)
log.debug("Link(Relation): %s", relationships)
return self._schema
def _limit_property_query(self, value: Optional[str], item_type: str) -> Optional[str]:
# NOTE: we skip the filter for list/set type (e.g., list of string, add it if needed)
if not self._limit_property or not isinstance(value, str):
return value
max_len = self._max_v_prop_len if item_type == "v" else self._max_e_prop_len
return value[:max_len] if value else value