# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import asyncio
from typing import Any, Dict, Optional
from hugegraph_llm.models.llms.base import BaseLLM
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.config import prompt
from hugegraph_llm.utils.log import log

DEFAULT_ANSWER_TEMPLATE = prompt.answer_prompt


class AnswerSynthesize:
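    """Synthesize a natural-language answer for a query from retrieved context using an LLM.

    Four answer modes can be enabled independently: raw (question only), vector-only,
    graph-only, and combined graph+vector; each result is written back into the context dict.
    """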
def __init__(
self,
llm: Optional[BaseLLM] = None,
prompt_template: Optional[str] = None,
question: Optional[str] = None,
context_body: Optional[str] = None,
context_head: Optional[str] = None,
context_tail: Optional[str] = None,
raw_answer: bool = False,
vector_only_answer: bool = True,
graph_only_answer: bool = False,
graph_vector_answer: bool = False,
):
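        """Configure the LLM, prompt template, optional pre-built context, and answer-mode flags."""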
self._llm = llm
self._prompt_template = prompt_template or DEFAULT_ANSWER_TEMPLATE
self._question = question
self._context_body = context_body
self._context_head = context_head
self._context_tail = context_tail
self._raw_answer = raw_answer
self._vector_only_answer = vector_only_answer
self._graph_only_answer = graph_only_answer
        self._graph_vector_answer = graph_vector_answer

def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
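        """Resolve the LLM and question from the context, then synthesize the requested answers."""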
if self._llm is None:
self._llm = context.get("llm") or LLMs().get_llm()
if context.get("llm") is None:
context["llm"] = self._llm
if self._question is None:
self._question = context.get("query") or None
assert self._question is not None, "No question for synthesizing."
context_head_str = context.get("synthesize_context_head") or self._context_head or ""
context_tail_str = context.get("synthesize_context_tail") or self._context_tail or ""
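        # If a pre-built context body was supplied, synthesize a single answer synchronously.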
if self._context_body is not None:
context_str = (f"{context_head_str}\n"
f"{self._context_body}\n"
f"{context_tail_str}".strip("\n"))
            final_prompt = self._prompt_template.format(
                context_str=context_str,
                query_str=self._question,
            )
            response = self._llm.generate(prompt=final_prompt)
return {"answer": response}
vector_result = context.get("vector_result")
if vector_result:
vector_result_context = "Phrases related to the query:\n" + "\n".join(
f"{i + 1}. {res}" for i, res in enumerate(vector_result)
)
else:
vector_result_context = "No (vector)phrase related to the query."
graph_result = context.get("graph_result")
if graph_result:
graph_context_head = context.get("graph_context_head", "Knowledge from graphdb for the query:\n")
graph_result_context = graph_context_head + "\n".join(
f"{i + 1}. {res}" for i, res in enumerate(graph_result)
)
else:
graph_result_context = "No related graph data found for current query."
log.warning(graph_result_context)
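        # Generate all requested answer variants concurrently and merge them into the context.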
context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str,
vector_result_context, graph_result_context))
        return context

async def async_generate(self, context: Dict[str, Any], context_head_str: str,
context_tail_str: str, vector_result_context: str,
graph_result_context: str):
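        """Build a prompt for each enabled answer mode, run the LLM calls concurrently, and collect results."""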
verbose = context.get("verbose") or False
# TODO: replace task_cache with a better name
task_cache = {}
if self._raw_answer:
            final_prompt = self._question
            task_cache["raw_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._vector_only_answer:
context_str = (f"{context_head_str}\n"
f"{vector_result_context}\n"
f"{context_tail_str}".strip("\n"))
            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
            task_cache["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._graph_only_answer:
context_str = (f"{context_head_str}\n"
f"{graph_result_context}\n"
f"{context_tail_str}".strip("\n"))
            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
            task_cache["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt))
if self._graph_vector_answer:
context_body_str = f"{vector_result_context}\n{graph_result_context}"
if context.get("graph_ratio", 0.5) < 0.5:
context_body_str = f"{graph_result_context}\n{vector_result_context}"
context_str = (f"{context_head_str}\n"
f"{context_body_str}\n"
f"{context_tail_str}".strip("\n"))
            final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question)
            task_cache["graph_vector_task"] = asyncio.create_task(
                self._llm.agenerate(prompt=final_prompt)
            )
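        # Await each scheduled task and store its answer in the context.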
# TODO: use log.debug instead of print
if task_cache.get("raw_task"):
response = await task_cache["raw_task"]
context["raw_answer"] = response
if verbose:
print(f"\033[91mANSWER: {response}\033[0m")
if task_cache.get("vector_only_task"):
response = await task_cache["vector_only_task"]
context["vector_only_answer"] = response
if verbose:
print(f"\033[91mANSWER: {response}\033[0m")
if task_cache.get("graph_only_task"):
response = await task_cache["graph_only_task"]
context["graph_only_answer"] = response
if verbose:
print(f"\033[91mANSWER: {response}\033[0m")
if task_cache.get("graph_vector_task"):
response = await task_cache["graph_vector_task"]
context["graph_vector_answer"] = response
if verbose:
print(f"\033[91mANSWER: {response}\033[0m")
ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer])
        context["call_count"] = context.get("call_count", 0) + ops
return context
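

# Illustrative usage sketch (assumes an LLM backend is already configured so that
# LLMs().get_llm() can resolve a model; the query and retrieval results are examples):
#
#     context = {
#         "query": "Who created HugeGraph?",
#         "vector_result": ["HugeGraph is an open-source graph database."],
#         "graph_result": [],
#     }
#     result = AnswerSynthesize(vector_only_answer=True).run(context)
#     print(result["vector_only_answer"])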