diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 140c69018e9b8a..4da3355b894360 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,8 +1,10 @@ from collections.abc import Generator -from typing import cast +from typing import Any, Sequence, cast +from core.agent.plugin_entities import AgentParameter from core.plugin.manager.exc import PluginDaemonClientSideError from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.agent.entities import AgentNodeData from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.event import RunCompletedEvent @@ -46,14 +48,14 @@ def _run(self) -> Generator: # get parameters parameters = self._generate_parameters( - tool_parameters=agent_parameters, + agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=node_data, ) parameters_for_log = self._generate_parameters( - tool_parameters=agent_parameters, + agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=node_data, for_log=True, ) @@ -84,3 +86,46 @@ def _run(self) -> Generator: error=f"Failed to transform agent message: {str(e)}", ) ) + + def _generate_parameters( + self, + *, + agent_parameters: Sequence[AgentParameter], + variable_pool: VariablePool, + node_data: AgentNodeData, + for_log: bool = False, + ) -> dict[str, Any]: + """ + Generate parameters based on the given tool parameters, variable pool, and node data. + + Args: + tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + variable_pool (VariablePool): The variable pool containing the variables. + node_data (ToolNodeData): The data associated with the tool node. + + Returns: + Mapping[str, Any]: A dictionary containing the generated parameters. + + """ + agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} + + result = {} + for parameter_name in node_data.agent_parameters: + parameter = agent_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + agent_input = node_data.agent_parameters[parameter_name] + if agent_input.type == "variable": + variable = variable_pool.get(agent_input.value) + if variable is None: + raise ValueError(f"Variable {agent_input.value} does not exist") + parameter_value = variable.value + elif agent_input.type in {"mixed", "constant"}: + segment_group = variable_pool.convert_template(str(agent_input.value)) + parameter_value = segment_group.log if for_log else segment_group.text + else: + raise ValueError(f"Unknown agent input type '{agent_input.type}'") + result[parameter_name] = parameter_value + + return result