blob: 33be1cd779dc17b05c8d070a429a84b76b752b38 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import re
from typing import List, Any, Dict
from hugegraph_llm.config import prompt
from hugegraph_llm.document.chunk_split import ChunkSplitter
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.utils.log import log
# Default instruction/example prompt, read from the project's prompt config.
# PropertyGraphExtract uses it as its default ``example_prompt`` and prepends
# it to every per-chunk extraction prompt (unless another value, or None, is
# passed in).
SCHEMA_EXAMPLE_PROMPT = prompt.extract_graph_prompt
def generate_extract_property_graph_prompt(text, schema=None) -> str:
    """Build the per-chunk extraction request for the LLM.

    The result asks the model to extract a property graph from ``text``
    following ``schema`` and to answer in JSON. Both values are interpolated
    with f-string formatting, so an omitted schema renders as ``None``.
    Callers may prepend an instruction/example prompt in front of it.

    :param text: the chunk of source text to extract from.
    :param schema: the graph schema the extraction must follow.
    :return: the prompt section for this chunk.
    """
    sections = [
        "---",
        "Following the full instructions above, try to extract the following text "
        "from the given schema, output the JSON result:",
        "# Input",
        "## Text:",
        f"{text}",
        "## Graph schema",
        f"{schema}",
        "# Output",
    ]
    return "\n".join(sections)
def split_text(text: str, split_type: str = "paragraph", language: str = "zh") -> List[str]:
    """Split ``text`` into chunks using the project's ChunkSplitter.

    :param text: the raw document text.
    :param split_type: splitting granularity forwarded to ChunkSplitter
        (default ``"paragraph"``; accepted values are defined by ChunkSplitter).
    :param language: language hint forwarded to ChunkSplitter
        (default ``"zh"``; accepted values are defined by ChunkSplitter).
    :return: the chunks returned by ``ChunkSplitter.split``.
    """
    chunk_splitter = ChunkSplitter(split_type=split_type, language=language)
    return chunk_splitter.split(text)
def filter_item(schema, items) -> List[Dict[str, Any]]:
    """Normalize extracted vertices/edges against ``schema`` and drop invalid ones.

    For every item:

    * items whose ``label`` is not declared in the schema for their ``type``
      (``vertexlabels`` for ``"vertex"``, ``edgelabels`` for ``"edge"``), or
      whose ``properties`` is not a dict, are dropped with a warning instead
      of raising ``KeyError``/``AttributeError``;
    * for vertices, each declared property that is not in ``nullable_keys``
      and is missing from the item is filled with the string ``"NULL"``
      (in the schema's property order, so the result is deterministic);
    * every property value that is not already a ``str`` is converted with
      ``str()``.

    Items are modified in place; surviving items keep their input order.

    :param schema: dict with ``vertexlabels`` and ``edgelabels`` lists; each
        label has ``name`` and ``properties`` (vertex labels may also carry
        ``nullable_keys`` and ``primary_keys``).
    :param items: extracted dicts with ``type``, ``label`` and ``properties``.
    :return: the list of valid, normalized items.
    """
    properties_map = {"vertex": {}, "edge": {}}
    for vertex in schema["vertexlabels"]:
        properties_map["vertex"][vertex["name"]] = {
            "primary_keys": vertex.get("primary_keys", []),
            "nullable_keys": vertex.get("nullable_keys", []),
            "properties": vertex["properties"],
        }
    for edge in schema["edgelabels"]:
        properties_map["edge"][edge["name"]] = {
            "properties": edge["properties"],
        }
    log.info("properties_map: %s", properties_map)

    filtered_items = []
    for item in items:
        item_type = item.get("type")
        label = item.get("label")
        label_schema = None
        # Guard the dict lookups: LLM output may put unhashable values here.
        if isinstance(item_type, str) and isinstance(label, str):
            label_schema = properties_map.get(item_type, {}).get(label)
        if label_schema is None:
            log.warning("Invalid '%s' label '%s' has been ignored.", item_type, label)
            continue

        properties = item.get("properties")
        if not isinstance(properties, dict):
            log.warning("Invalid properties '%s' of label '%s' has been ignored.", properties, label)
            continue

        if item_type == "vertex":
            nullable_keys = set(label_schema["nullable_keys"])
            for key in label_schema["properties"]:
                if key not in nullable_keys:
                    properties.setdefault(key, "NULL")

        # Only existing keys are reassigned, so mutating during iteration is safe.
        for key, value in properties.items():
            if not isinstance(value, str):
                properties[key] = str(value)
        filtered_items.append(item)
    return filtered_items
class PropertyGraphExtract:
    """Extract a property graph (vertices and edges) from text chunks with an LLM.

    For each chunk the LLM is prompted with the example/instruction prompt,
    the chunk and the schema. The longest JSON array in its reply is parsed,
    its items are validated against the schema, and the surviving vertices and
    edges are appended to the pipeline ``context``.
    """

    # Keys every extracted item must carry to be considered at all.
    NECESSARY_ITEM_KEYS = frozenset({"label", "type", "properties"})

    def __init__(
        self,
        llm: BaseLLM,
        example_prompt: str = SCHEMA_EXAMPLE_PROMPT
    ) -> None:
        """
        :param llm: model whose ``generate(prompt=...)`` produces the extraction.
        :param example_prompt: text prepended to every per-chunk prompt;
            ``None`` disables the prefix.
        """
        self.llm = llm
        self.example_prompt = example_prompt

    def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Extract vertices/edges from every chunk and store them in ``context``.

        :param context: must contain ``"schema"`` (with ``vertexlabels`` and
            ``edgelabels``) and ``"chunks"`` (a sized iterable of text chunks).
            Existing ``"vertices"``/``"edges"`` lists are appended to.
        :return: the same ``context``, with ``"call_count"`` increased by the
            number of chunks (one LLM call per chunk).
        """
        schema = context["schema"]
        chunks = context["chunks"]
        context.setdefault("vertices", [])
        context.setdefault("edges", [])

        items = []
        for chunk in chunks:
            proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk)
            log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, chunk, proceeded_chunk)
            items.extend(self._extract_and_filter_label(schema, proceeded_chunk))
        items = filter_item(schema, items)
        for item in items:
            if item["type"] == "vertex":
                context["vertices"].append(item)
            elif item["type"] == "edge":
                context["edges"].append(item)

        context["call_count"] = context.get("call_count", 0) + len(chunks)
        return context

    def extract_property_graph_by_llm(self, schema, chunk):
        """Ask the LLM to extract a property graph from a single chunk.

        :return: whatever ``self.llm.generate`` returns (expected to be text).
        """
        # Named ``request`` to avoid shadowing the imported ``prompt`` config module.
        request = generate_extract_property_graph_prompt(chunk, schema)
        if self.example_prompt is not None:
            request = self.example_prompt + request
        return self.llm.generate(prompt=request)

    @staticmethod
    def _find_longest_json_list(text: str):
        """Return the longest decodable JSON array embedded in ``text``, else ``None``.

        Unlike a bracket regex, this handles nested arrays (e.g. list-valued
        properties) because each candidate is parsed by the real JSON decoder.
        """
        decoder = json.JSONDecoder()
        best, best_span = None, -1
        pos = text.find("[")
        while pos != -1:
            try:
                value, end = decoder.raw_decode(text, pos)
            except json.JSONDecodeError:
                pos = text.find("[", pos + 1)
                continue
            if end - pos > best_span:
                best, best_span = value, end - pos
            # Skip past the decoded array so its inner arrays are not re-parsed.
            pos = text.find("[", end)
        return best

    def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]:
        """Parse the LLM reply and keep well-formed items whose label fits their type.

        An item is kept only if it is a dict with all ``NECESSARY_ITEM_KEYS``,
        its ``type`` is ``"vertex"`` or ``"edge"``, its ``label`` is declared in
        the schema for that type, and its ``properties`` is a dict. Everything
        else is logged and skipped. If no JSON array can be decoded, a critical
        message is logged and an empty list is returned.
        """
        property_graph = self._find_longest_json_list(text) if isinstance(text, str) else None
        if property_graph is None:
            log.critical("Invalid property graph! Please check the extracted JSON data carefully")
            return []

        label_sets = {
            "vertex": {vertex["name"] for vertex in schema["vertexlabels"]},
            "edge": {edge["name"] for edge in schema["edgelabels"]},
        }
        items = []
        for item in property_graph:
            if not isinstance(item, dict):
                log.warning("Invalid property graph item type '%s'.", type(item))
                continue
            if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()):
                log.warning("Invalid item keys '%s'.", item.keys())
                continue
            item_type, label = item["type"], item["label"]
            if item_type not in ("vertex", "edge"):
                log.warning("Invalid item type '%s' has been ignored.", item_type)
                continue
            # Check against the label set of the item's own type; a str check
            # first keeps unhashable LLM output from raising in the set lookup.
            if not isinstance(label, str) or label not in label_sets[item_type]:
                log.warning("Invalid '%s' label '%s' has been ignored.", item_type, label)
                continue
            if not isinstance(item["properties"], dict):
                log.warning("Invalid properties of '%s' label '%s' has been ignored.", item_type, label)
                continue
            items.append(item)
        return items