# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
from typing import Any, Dict, List, Optional

from hugegraph_llm.document.chunk_split import ChunkSplitter
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.utils.log import log

SCHEMA_EXAMPLE_PROMPT = """## Main Task
Extract Triples from the given text and graph schema
## Basic Rules
1. The output format must be: (X,Y,Z) - LABEL
In this format, Y must be a value from "properties" or "edge_label",
and LABEL must be X's vertex_label or Y's edge_label.
2. Don't extract attribute/property fields that do not exist in the given schema
3. Ensure the extracted property has the same type as defined in the schema (e.g. 'age' should be a number)
4. Translate the given schema fields into Chinese if the given text is Chinese but the schema is in English (Optional)
## Example (Note: Update the example to correspond to the given text and schema)
### Input example:
Graph schema:
{"vertices":[{"vertex_label":"person","properties":["name","age","occupation"]}], "edges":[{"edge_label":"roommate",
"source_vertex_label":"person","target_vertex_label":"person","properties":["date"]]}
Text:
Meet Sarah, a 30-year-old attorney, and her roommate,
James, whom she's shared a home with since 2010. James,
in his professional life, works as a journalist.
### Output example:
(Sarah, name, Sarah) - person
(Sarah, age, 30) - person
(Sarah, occupation, attorney) - person
(James, name, James) - person
(James, occupation, journalist) - person
(Sarah, roommate, James) - roommate
(James, roommate, Sarah) - roommate
(Sarah, date, 2010) - roommate
"""
def generate_extract_triple_prompt(text, schema=None) -> str:
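    """Build the LLM prompt body for triple extraction.

    Returns the schema-aware prompt when a graph schema is provided, otherwise a plain
    SPO-triple prompt; callers may prefix it with an example prompt such as
    SCHEMA_EXAMPLE_PROMPT (see InfoExtract.extract_triples_by_llm).
    """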
text_based_prompt = f"""
Extract subject-predicate-object (SPO) triples from the text strictly according to the
following format; each structure has exactly three elements: ("vertex_1", "edge", "vertex_2").
For example:
Alice is a lawyer and is 25 years old, and Bob has been her roommate since 2001. Bob works as a journalist.
Alice owns the webpage www.alice.com and Bob owns the webpage www.bob.com
Output: [("Alice", "Age", "25"),("Alice", "Profession", "lawyer"),("Bob", "Job", "journalist"),
("Alice", "Roommate of", "Bob"),("Alice", "Owns", "https://www.alice.com"),
("Bob", "Owns", "https://www.bob.com")]
The extracted text is: {text}"""
schema_real_prompt = f"""## Real result
1. The extracted text is: {text}
2. The graph schema is: {schema}
"""
if schema:
return schema_real_prompt
log.warning("Recommend to provide a graph schema to improve the extraction accuracy. "
"Now using the default schema.")
return text_based_prompt


def split_text(text: str) -> List[str]:
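    """Split the input text into paragraph-level chunks using the English splitter."""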
chunk_splitter = ChunkSplitter(split_type="paragraph", language="en")
chunks = chunk_splitter.split(text)
return chunks


def extract_triples_by_regex(text, triples):
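    """Parse "(s, p, o)" tuples from the LLM output and append them to triples["triples"]."""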
text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ")
pattern = r"\((.*?), (.*?), (.*?)\)"
triples["triples"] += re.findall(pattern, text)


def extract_triples_by_regex_with_schema(schema, text, graph):
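    """Parse "(s, p, o) - LABEL" lines from the LLM output and merge them into the graph dict.

    Vertex properties are kept only when both the vertex label and the property exist in the
    schema; edges are appended when the edge label exists, creating missing endpoint vertices
    on the fly. Vertices are de-duplicated by their "{label}-{name}" id.
    """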
text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ")
pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)"
matches = re.findall(pattern, text)
vertices_dict = {v["id"]: v for v in graph["vertices"]}
for match in matches:
s, p, o, label = [item.strip() for item in match]
        # re.findall() returns strings (never None), so skip matches with empty captures instead
        if not all((s, p, o, label)):
            continue
        # TODO: use a more efficient way to compare the extracted & input properties
p_lower = p.lower()
for vertex in schema["vertices"]:
if vertex["vertex_label"] == label and any(pp.lower() == p_lower
for pp in vertex["properties"]):
                vertex_id = f"{label}-{s}"  # avoid shadowing the built-in id()
                if vertex_id not in vertices_dict:
                    vertices_dict[vertex_id] = {"id": vertex_id, "name": s, "label": label,
                                                "properties": {p: o}}
                else:
                    vertices_dict[vertex_id]["properties"].update({p: o})
break
for edge in schema["edges"]:
if edge["edge_label"] == label:
source_label = edge["source_vertex_label"]
source_id = f"{source_label}-{s}"
if source_id not in vertices_dict:
vertices_dict[source_id] = {"id": source_id, "name": s, "label": source_label,
"properties": {}}
target_label = edge["target_vertex_label"]
target_id = f"{target_label}-{o}"
if target_id not in vertices_dict:
vertices_dict[target_id] = {"id": target_id, "name": o, "label": target_label,
"properties": {}}
graph["edges"].append({"start": source_id, "end": target_id, "type": label,
"properties": {}})
break
graph["vertices"] = vertices_dict.values()


class InfoExtract:
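    """Legacy information-extraction operator.

    Calls the LLM once per text chunk and parses the returned triples, either against a
    graph schema (vertices/edges) or schema-free (raw triples).
    """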
def __init__(
self,
llm: BaseLLM,
example_prompt: Optional[str] = None
) -> None:
self.llm = llm
self.example_prompt = example_prompt

    def run(self, context: Dict[str, Any]) -> Dict[str, List[Any]]:
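        """Extract triples for every chunk in the context and merge the results.

        Expects context["chunks"] and context["schema"]; fills "vertices"/"edges" when a
        schema is given, otherwise "triples", and adds len(chunks) to "call_count".
        """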
chunks = context["chunks"]
schema = context["schema"]
if schema:
context["vertices"] = []
context["edges"] = []
else:
context["triples"] = []
for sentence in chunks:
            processed_chunk = self.extract_triples_by_llm(schema, sentence)
            log.debug("[Legacy] %s input: %s \n output: %s",
                      self.__class__.__name__, sentence, processed_chunk)
            if schema:
                extract_triples_by_regex_with_schema(schema, processed_chunk, context)
            else:
                extract_triples_by_regex(processed_chunk, context)
context["call_count"] = context.get("call_count", 0) + len(chunks)
        # Without a schema only "triples" is populated, so skip the vertex/edge id filtering
        return self._filter_long_id(context) if schema else context

    def extract_triples_by_llm(self, schema, chunk) -> str:
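        """Ask the LLM to extract triples from one chunk, optionally guided by the schema."""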
prompt = generate_extract_triple_prompt(chunk, schema)
if self.example_prompt is not None:
prompt = self.example_prompt + prompt
return self.llm.generate(prompt=prompt)

    # TODO: make 'max_length' a configurable param in settings.py/settings.cfg
def valid(self, element_id: str, max_length: int = 256) -> bool:
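        """Return False (and warn) when the element id reaches max_length bytes in UTF-8."""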
if len(element_id.encode("utf-8")) >= max_length:
log.warning("Filter out GraphElementID too long: %s", element_id)
return False
return True

    def _filter_long_id(self, graph) -> Dict[str, List[Any]]:
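        """Drop vertices and edges whose graph element ids are too long to store."""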
graph["vertices"] = [vertex for vertex in graph["vertices"] if self.valid(vertex["id"])]
graph["edges"] = [edge for edge in graph["edges"]
if self.valid(edge["start"]) and self.valid(edge["end"])]
return graph
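

# A minimal usage sketch (not part of the original module). It assumes `my_llm` is any
# BaseLLM implementation, `raw_text` is the input document, and `graph_schema` follows the
# shape shown in SCHEMA_EXAMPLE_PROMPT; all three names are placeholders for your own setup.
#
#   extractor = InfoExtract(my_llm, example_prompt=SCHEMA_EXAMPLE_PROMPT)
#   context = {"chunks": split_text(raw_text), "schema": graph_schema}
#   result = extractor.run(context)  # -> {"vertices": [...], "edges": [...], "call_count": N, ...}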