blob: 1ef7e5c55c0966f0e30f1172a070b993b4fd09ed [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from PyCGraph import CStatus
from hugegraph_llm.nodes.base_node import BaseNode
from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState
from hugegraph_llm.models.llms.init_llm import get_chat_llm
from hugegraph_llm.config import llm_settings
from hugegraph_llm.operators.llm_op.schema_build import SchemaBuilder
from hugegraph_llm.utils.log import log
class SchemaBuildNode(BaseNode):
    """Node that gathers schema-building inputs and runs ``SchemaBuilder``.

    ``node_init`` normalises the values found on ``wk_input`` (texts, query
    examples, few-shot schema) and stores them on ``context``;
    ``operator_schedule`` then invokes the schema builder.
    """

    schema_builder: SchemaBuilder
    context: WkFlowState = None
    wk_input: WkFlowInput = None

    @staticmethod
    def _parse_json_input(value):
        """Decode *value* from JSON when it is raw text, else return it as-is.

        Already-parsed values (lists, dicts, ...) are passed through untouched
        so callers can supply either form.

        Raises:
            json.JSONDecodeError: If the text is not valid JSON.
            UnicodeDecodeError: If a bytes payload cannot be decoded.
        """
        if isinstance(value, (str, bytes, bytearray)):
            return json.loads(value)
        return value

    def node_init(self):
        """Build the schema builder and load the workflow inputs into context.

        Returns:
            ``CStatus(-1, message)`` when the query examples or the few-shot
            schema are given as text that is not valid JSON; otherwise the
            result of ``super().node_init()``.
        """
        llm = get_chat_llm(llm_settings)
        self.schema_builder = SchemaBuilder(llm)

        # texts -> raw_texts (a single string becomes a one-element list;
        # non-string list items are dropped)
        texts = self.wk_input.texts
        raw_texts = []
        if texts:
            if isinstance(texts, list):
                raw_texts = [t for t in texts if isinstance(t, str)]
            elif isinstance(texts, str):
                raw_texts = [texts]

        # query_examples: already parsed list[dict] or raw JSON string
        query_examples = []
        qe_src = self.wk_input.query_examples
        if qe_src:
            try:
                parsed_examples = self._parse_json_input(qe_src)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                return CStatus(
                    -1, f"Query Examples is not in a valid JSON format: {e}"
                )
            if isinstance(parsed_examples, list):
                # Validate and retain the description and gremlin fields
                query_examples = [
                    {
                        "description": ex.get("description", ""),
                        "gremlin": ex.get("gremlin", ""),
                    }
                    for ex in parsed_examples
                    if isinstance(ex, dict)
                    and "description" in ex
                    and "gremlin" in ex
                ]
            else:
                # Anything other than a list cannot hold example objects;
                # ignore it instead of failing while iterating a scalar.
                log.error(
                    "Query Examples should be a list of objects, got %s; ignoring",
                    type(parsed_examples).__name__,
                )

        # few_shot_schema: already parsed dict or raw JSON string
        few_shot_schema = {}
        fss_src = self.wk_input.few_shot_schema
        if fss_src:
            try:
                few_shot_schema = self._parse_json_input(fss_src)
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                return CStatus(
                    -1, f"Few Shot Schema is not in a valid JSON format: {e}"
                )

        _context_payload = {
            "raw_texts": raw_texts,
            "query_examples": query_examples,
            "few_shot_schema": few_shot_schema,
        }
        self.context.assign_from_json(_context_payload)
        return super().node_init()

    def operator_schedule(self, data_json):
        """Run the schema builder on *data_json*.

        Returns:
            ``{"schema": <builder result>}`` on success. If the builder raises
            ``ValueError`` or ``RuntimeError`` the error is logged and the
            ``"schema"`` entry carries a failure message instead.
        """
        try:
            schema_result = self.schema_builder.run(data_json)
            return {"schema": schema_result}
        except (ValueError, RuntimeError) as e:
            log.error("Failed to generate schema: %s", e)
            return {"schema": f"Schema generation failed: {e}"}