| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################# |
| from typing import Any, Dict, List, cast |
| |
| from pydantic import BaseModel, field_serializer, model_validator |
| |
| from flink_agents.api.agent import Agent |
| from flink_agents.api.resource import Resource, ResourceType |
| from flink_agents.api.tools.mcp import MCPServer |
| from flink_agents.plan.actions.action import Action |
| from flink_agents.plan.actions.chat_model_action import CHAT_MODEL_ACTION |
| from flink_agents.plan.actions.context_retrieval_action import CONTEXT_RETRIEVAL_ACTION |
| from flink_agents.plan.actions.tool_call_action import TOOL_CALL_ACTION |
| from flink_agents.plan.configuration import AgentConfiguration |
| from flink_agents.plan.function import PythonFunction |
| from flink_agents.plan.resource_provider import ( |
| JavaResourceProvider, |
| JavaSerializableResourceProvider, |
| PythonResourceProvider, |
| PythonSerializableResourceProvider, |
| ResourceProvider, |
| ) |
| from flink_agents.plan.tools.function_tool import from_callable |
| |
# Actions appended to the user-defined actions of every plan built by
# AgentPlan.from_agent.
BUILT_IN_ACTIONS = [CHAT_MODEL_ACTION, TOOL_CALL_ACTION, CONTEXT_RETRIEVAL_ACTION]
| |
| |
class AgentPlan(BaseModel):
    """Agent plan compiled from user defined agent.

    Attributes:
    ----------
    actions : Dict[str, Action]
        Mapping of action names to actions.
    actions_by_event : Dict[str, List[str]]
        Mapping of event type names to the names of the actions that listen
        to them.
    resource_providers : Dict[ResourceType, Dict[str, ResourceProvider]] | None
        Two level mapping of resource type to resource name to resource
        provider.
    config : AgentConfiguration | None
        Configuration passed to a resource provider when it creates a
        resource.
    """

    actions: Dict[str, Action]
    actions_by_event: Dict[str, List[str]]
    resource_providers: Dict[ResourceType, Dict[str, ResourceProvider]] | None = None
    config: AgentConfiguration | None = None
    # Resources already created by get_resource, keyed by type and then name.
    __resources: Dict[ResourceType, Dict[str, Resource]] = {}

    @field_serializer("resource_providers")
    def __serialize_resource_providers(
        self,
        providers: Dict[ResourceType, Dict[str, ResourceProvider]] | None,
    ) -> dict | None:
        """Serialize the resource providers with a type tag on each of them.

        Every provider is dumped with ``model_dump`` and, when it is an
        instance of one of the known provider classes, gets a
        ``__resource_provider_type__`` entry naming that class. The tag is
        what ``__custom_deserialize`` reads to rebuild the provider.

        Parameters
        ----------
        providers : Dict[ResourceType, Dict[str, ResourceProvider]] | None
            The value of the ``resource_providers`` field.

        Returns:
        -------
        dict | None
            The tagged two level mapping, or None when the field is unset.
        """
        # The field is optional, so there may be nothing to iterate over.
        if providers is None:
            return None

        # Checked in order; the first matching class decides the tag.
        tagged_classes = (
            (PythonResourceProvider, "PythonResourceProvider"),
            (PythonSerializableResourceProvider, "PythonSerializableResourceProvider"),
            (JavaResourceProvider, "JavaResourceProvider"),
            (JavaSerializableResourceProvider, "JavaSerializableResourceProvider"),
        )

        data = {}
        for resource_type, named_providers in providers.items():
            data[resource_type] = {}
            for name, provider in named_providers.items():
                dumped = provider.model_dump()
                for provider_cls, tag in tagged_classes:
                    if isinstance(provider, provider_cls):
                        # append meta info to help deserialize resource providers
                        dumped["__resource_provider_type__"] = tag
                        break
                data[resource_type][name] = dumped
        return data

    @model_validator(mode="before")
    def __custom_deserialize(self) -> "AgentPlan":
        """Rebuild resource providers from their serialized form.

        This runs as a "before" validator, so ``self`` is the raw input
        handed to the model (the code treats it as a dict), not an AgentPlan
        instance. Every provider that is still a plain dict is replaced, in
        place, by the provider class named in its
        ``__resource_provider_type__`` entry. Dicts carrying an unknown tag
        and providers that are not dicts are left untouched.

        Returns:
        -------
        The same input object, with the known providers validated.

        Raises:
        ------
        KeyError
            If a provider dict has no ``__resource_provider_type__`` entry.
        """
        if "resource_providers" in self:
            providers = self["resource_providers"]
            # restore providers from serialized json.
            if isinstance(providers, dict):
                provider_classes = {
                    "PythonResourceProvider": PythonResourceProvider,
                    "PythonSerializableResourceProvider": (
                        PythonSerializableResourceProvider
                    ),
                    "JavaResourceProvider": JavaResourceProvider,
                    "JavaSerializableResourceProvider": (
                        JavaSerializableResourceProvider
                    ),
                }
                for named_providers in providers.values():
                    for name, provider in named_providers.items():
                        if not isinstance(provider, dict):
                            continue
                        provider_cls = provider_classes.get(
                            provider["__resource_provider_type__"]
                        )
                        if provider_cls is not None:
                            # Only the value of an existing key changes, so
                            # this is safe while iterating over the items.
                            named_providers[name] = provider_cls.model_validate(
                                provider
                            )
        return self

    @staticmethod
    def from_agent(agent: Agent, config: AgentConfiguration) -> "AgentPlan":
        """Build a AgentPlan from user defined agent.

        Parameters
        ----------
        agent : Agent
            The agent whose actions and resources are collected.
        config : AgentConfiguration
            The configuration stored on the plan.

        Returns:
        -------
        AgentPlan
            A plan holding the agent's actions, the built-in actions and the
            agent's resource providers.

        Raises:
        ------
        AssertionError
            If two actions share a name, or two resources of the same type
            share a name. These checks are ``assert`` statements and are
            therefore skipped when Python runs with ``-O``.
        """
        actions = {}
        actions_by_event = {}
        for action in _get_actions(agent) + BUILT_IN_ACTIONS:
            assert action.name not in actions, f"Duplicate action name: {action.name}"
            actions[action.name] = action
            for event_type in action.listen_event_types:
                actions_by_event.setdefault(event_type, []).append(action.name)

        resource_providers = {}
        for provider in _get_resource_providers(agent):
            providers_of_type = resource_providers.setdefault(provider.type, {})
            name = provider.name
            assert name not in providers_of_type, f"Duplicate resource name: {name}"
            providers_of_type[name] = provider
        return AgentPlan(
            actions=actions,
            actions_by_event=actions_by_event,
            resource_providers=resource_providers,
            config=config,
        )

    def get_actions(self, event_type: str) -> List[Action]:
        """Get actions that listen to the specified event type.

        Parameters
        ----------
        event_type : str
            The event type name to query, as used for the keys of
            ``actions_by_event``.

        Returns:
        -------
        List[Action]
            List of Actions that will respond to this event type.

        Raises:
        ------
        KeyError
            If no action is registered for the event type.
        """
        return [self.actions[name] for name in self.actions_by_event[event_type]]

    def get_action_config(self, action_name: str) -> Dict[str, Any]:
        """Get config of the action.

        Parameters
        ----------
        action_name : str
            The name of the action.

        Returns:
        -------
        Dict[str, Any]
            The config of action.

        Raises:
        ------
        KeyError
            If the plan has no action with this name.
        """
        return self.actions[action_name].config

    def get_action_config_value(self, action_name: str, key: str) -> Any:
        """Get the value of a single option from the config of the action.

        Parameters
        ----------
        action_name : str
            The name of the action.
        key : str
            The name of the option.

        Returns:
        -------
        Any
            The option value of the action config, or None when the option
            is missing or the action has no config.

        Raises:
        ------
        KeyError
            If the plan has no action with this name.
        """
        action_config = self.actions[action_name].config
        # Guard against an action that carries no config at all.
        if action_config is None:
            return None
        return action_config.get(key)

    def get_resource(self, name: str, type: ResourceType) -> Resource:
        """Get resource from agent plan.

        The resource is created by its provider the first time it is
        requested and then cached, so later calls return the same object.

        Parameters
        ----------
        name : str
            The name of the resource.
        type : ResourceType
            The type of the resource.

        Returns:
        -------
        Resource
            The resource registered under the given type and name.

        Raises:
        ------
        KeyError
            If no provider is registered under the given type and name.
        """
        cached = self.__resources.setdefault(type, {})
        if name not in cached:
            resource_provider = self.resource_providers[type][name]
            # The provider receives get_resource so that it can look up
            # other resources of this plan while building its own.
            cached[name] = resource_provider.provide(
                get_resource=self.get_resource, config=self.config
            )
        return cached[name]
| |
| |
def _get_actions(agent: Agent) -> List[Action]:
    """Collect every action declared on an agent.

    Two sources are read, in this order: attributes of the agent's class
    that carry a ``_listen_events`` attribute, and the entries of
    ``agent.actions``.

    Parameters
    ----------
    agent : Agent
        The agent to be analyzed.

    Returns:
    -------
    List[Action]
        List of Action defined in the agent.
    """

    def qualified_names(event_types: Any) -> List[str]:
        # Event types are referenced by "<module>.<class name>".
        return [
            f"{event_type.__module__}.{event_type.__name__}"
            for event_type in event_types
        ]

    collected = []

    # Actions declared on the class itself.
    for attr_name, attr in agent.__class__.__dict__.items():
        if not hasattr(attr, "_listen_events"):
            continue
        if isinstance(attr, staticmethod):
            # Wrap the underlying function, not the staticmethod object.
            func = attr.__func__
        elif callable(attr):
            func = attr
        else:
            continue
        collected.append(
            Action(
                name=attr_name,
                exec=PythonFunction.from_callable(func),
                listen_event_types=qualified_names(attr._listen_events),
            )
        )

    # Actions registered on the instance; each entry is indexed as
    # (event types, callable, config).
    for action_name, entry in agent.actions.items():
        collected.append(
            Action(
                name=action_name,
                exec=PythonFunction.from_callable(entry[1]),
                listen_event_types=qualified_names(entry[0]),
                config=entry[2],
            )
        )
    return collected
| |
| |
def _get_resource_providers(agent: Agent) -> List[ResourceProvider]:
    """Collect a resource provider for every resource declared on an agent.

    Resources marked on attributes of the agent's class are read first,
    followed by the ones stored in ``agent.resources``.

    Parameters
    ----------
    agent : Agent
        The agent to be analyzed.

    Returns:
    -------
    List[ResourceProvider]
        The providers, in the order in which the resources were found.
    """
    descriptor_markers = (
        "_is_chat_model_setup",
        "_is_chat_model_connection",
        "_is_embedding_model_setup",
        "_is_embedding_model_connection",
        "_is_vector_store",
    )
    providers = []

    # retrieve resource declared by decorator
    for attr_name, attr in agent.__class__.__dict__.items():
        # Markers are looked up on the attribute as stored in the class; the
        # object that gets called is the function behind a staticmethod.
        target = attr.__func__ if isinstance(attr, staticmethod) else attr

        if any(hasattr(attr, marker) for marker in descriptor_markers):
            if callable(target):
                providers.append(
                    PythonResourceProvider.get(name=attr_name, descriptor=target())
                )
        elif hasattr(attr, "_is_tool"):
            if callable(target):
                # TODO: support other tool type.
                providers.append(
                    PythonSerializableResourceProvider.from_resource(
                        name=attr_name, resource=from_callable(func=target)
                    )
                )
        elif hasattr(attr, "_is_prompt"):
            providers.append(
                PythonSerializableResourceProvider.from_resource(
                    name=attr_name, resource=target()
                )
            )
        elif hasattr(attr, "_is_mcp_server"):
            _add_mcp_server(attr_name, providers, target())

    # retrieve resource declared by add interface
    providers.extend(
        PythonSerializableResourceProvider.from_resource(
            name=prompt_name, resource=prompt
        )
        for prompt_name, prompt in agent.resources[ResourceType.PROMPT].items()
    )

    providers.extend(
        PythonSerializableResourceProvider.from_resource(
            name=tool_name, resource=from_callable(tool.func)
        )
        for tool_name, tool in agent.resources[ResourceType.TOOL].items()
    )

    for server_name, server in agent.resources[ResourceType.MCP_SERVER].items():
        _add_mcp_server(server_name, providers, cast("MCPServer", server))

    descriptor_types = (
        ResourceType.CHAT_MODEL,
        ResourceType.CHAT_MODEL_CONNECTION,
        ResourceType.EMBEDDING_MODEL,
        ResourceType.EMBEDDING_MODEL_CONNECTION,
        ResourceType.VECTOR_STORE,
    )
    for resource_type in descriptor_types:
        providers.extend(
            PythonResourceProvider.get(name=resource_name, descriptor=descriptor)
            for resource_name, descriptor in agent.resources[resource_type].items()
        )

    return providers
| |
| |
def _add_mcp_server(
    name: str, resource_providers: List[ResourceProvider], mcp_server: MCPServer
) -> None:
    """Register an MCP server together with the prompts and tools it lists.

    A provider is appended for the server itself, then one for every prompt
    returned by ``list_prompts()`` and one for every tool returned by
    ``list_tools()``, each under the prompt's or tool's own name. The server
    is closed afterwards, even when listing fails, so that it is not leaked.

    Parameters
    ----------
    name : str
        The name under which the server itself is registered.
    resource_providers : List[ResourceProvider]
        The list the new providers are appended to, in place.
    mcp_server : MCPServer
        The server to register. It is closed before this function returns.
    """
    try:
        resource_providers.append(
            PythonSerializableResourceProvider.from_resource(
                name=name, resource=mcp_server
            )
        )
        resource_providers.extend(
            [
                PythonSerializableResourceProvider.from_resource(
                    name=prompt.name, resource=prompt
                )
                for prompt in mcp_server.list_prompts()
            ]
        )
        resource_providers.extend(
            [
                PythonSerializableResourceProvider.from_resource(
                    name=tool.name, resource=tool
                )
                for tool in mcp_server.list_tools()
            ]
        )
    finally:
        # Release the server whether or not the listing calls succeeded.
        mcp_server.close()