# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Dict, Optional

from pycgraph import GNode, CStatus

from hugegraph_llm.nodes.util import init_context
from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState
from hugegraph_llm.utils.log import log


class BaseNode(GNode):
"""
Base class for workflow nodes, providing context management and operation scheduling.
    All custom nodes should inherit from this class and implement the operator_schedule method.

    Attributes:
context: Shared workflow state
wk_input: Workflow input parameters
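
    Example:
        A minimal subclass sketch (illustrative only; ``EchoNode`` and the
        ``query``/``result`` keys are assumptions, not part of this module)::

            class EchoNode(BaseNode):
                def operator_schedule(self, data_json):
                    query = data_json.get("query", "")
                    return {"result": f"echo: {query}"}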
"""
context: Optional[WkFlowState] = None
    wk_input: Optional[WkFlowInput] = None

    def init(self):
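        """Initialization hook; delegates to init_context (expected to attach the
        shared context and workflow input to this node)."""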
        return init_context(self)

    def node_init(self):
"""
        Node initialization method; it can be overridden by subclasses.
Returns a CStatus object indicating whether initialization succeeded.
"""
if self.wk_input is None or self.context is None:
return CStatus(-1, "wk_input or context not initialized")
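        # Merge any one-shot input payload into the shared state, then clear it
        # so it is only applied once.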
if self.wk_input.data_json is not None:
self.context.assign_from_json(self.wk_input.data_json)
self.wk_input.data_json = None
        return CStatus()

    def run(self):
"""
        Main logic for node execution; it can be overridden by subclasses.
Returns a CStatus object indicating whether execution succeeded.
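
        Flow: node_init() -> snapshot the context -> operator_schedule() -> merge
        the returned dict back into the context.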
"""
sts = self.node_init()
if sts.isErr():
return sts
if self.context is None:
return CStatus(-1, "Context not initialized")
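        # Take a serialized snapshot of the shared state under the lock.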
self.context.lock()
try:
data_json = self.context.to_json()
finally:
self.context.unlock()
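        # Execute the node's own logic outside the lock, using the snapshot.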
try:
res = self.operator_schedule(data_json)
except (ValueError, TypeError, KeyError, NotImplementedError) as exc:
node_info = f"Node type: {type(self).__name__}, Node object: {self}"
err_msg = f"Node failed: {exc}\n{node_info}\n{traceback.format_exc()}"
return CStatus(-1, err_msg)
        # Unexpected exceptions are intentionally not caught here; they propagate to the caller.
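        # Merge the node's results back into the shared state under the lock.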
self.context.lock()
try:
            if isinstance(res, dict):
                self.context.assign_from_json(res)
            elif res is not None:
                log.warning("operator_schedule returned non-dict type: %s", type(res))
finally:
self.context.unlock()
        return CStatus()

    def operator_schedule(self, data_json) -> Optional[Dict]:
"""
Operation scheduling method that must be implemented by subclasses.
Args:
data_json: Context serialized as JSON data
Returns:
Dictionary of processing results, or None to indicate no update
Raises:
NotImplementedError: If the subclass has not implemented this method
"""
raise NotImplementedError("Subclasses must implement operator_schedule")