From 1ec278d49251ce69c2042cf4e42b1e29f4fa313a Mon Sep 17 00:00:00 2001 From: Nov1c444 <857526207@qq.com> Date: Fri, 15 Nov 2024 11:20:07 +0800 Subject: [PATCH 01/16] feat: nodes continue on error --- .../advanced_chat/generate_task_pipeline.py | 3 +- api/core/app/apps/workflow_app_runner.py | 32 ++ api/core/app/entities/queue_entities.py | 32 ++ .../task_pipeline/workflow_cycle_manage.py | 16 +- .../workflow/graph_engine/entities/event.py | 4 + .../entities/runtime_route_state.py | 10 +- .../workflow/graph_engine/graph_engine.py | 33 +- api/core/workflow/nodes/base/entities.py | 6 +- api/core/workflow/nodes/base/node.py | 33 +- api/core/workflow/nodes/enums.py | 8 + api/core/workflow/utils/condition/entities.py | 6 + api/models/workflow.py | 1 + .../workflow/nodes/test_code.py | 39 +++ .../workflow/nodes/test_continue_on_error.py | 284 ++++++++++++++++++ 14 files changed, 487 insertions(+), 20 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 1d4c0ea0fa6f4d..7e25df439e7c2d 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -19,6 +19,7 @@ QueueIterationNextEvent, QueueIterationStartEvent, QueueMessageReplaceEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -315,7 +316,7 @@ def _process_stream_response( if response: yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 2872390d4662db..04cd47e51f216a 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -8,6 +8,7 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -32,6 +33,7 @@ IterationRunStartedEvent, IterationRunSucceededEvent, NodeInIterationFailedEvent, + NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -255,6 +257,36 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) in_iteration_id=event.in_iteration_id, ) ) + elif isinstance(event, NodeRunExceptionEvent): + self._publish_event( + QueueNodeExceptionEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.route_node_state.node_run_result.error + if 
event.route_node_state.node_run_result and event.route_node_state.node_run_result.error + else "Unknown error", + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + ) + ) elif isinstance(event, NodeInIterationFailedEvent): self._publish_event( QueueNodeInIterationFailedEvent( diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 69bc0d7f9ec102..0c2329abd3a266 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -31,6 +31,7 @@ class QueueEvent(str, Enum): NODE_STARTED = "node_started" NODE_SUCCEEDED = "node_succeeded" NODE_FAILED = "node_failed" + NODE_EXCEPTION = "node_exception" RETRIEVER_RESOURCES = "retriever_resources" ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" @@ -343,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): error: str +class QueueNodeExceptionEvent(AppQueueEvent): + """ + QueueNodeExceptionEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_EXCEPTION + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + + class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 042339969fb8ad..751b34b25b4fa1 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -11,6 +11,7 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -314,7 +315,7 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -347,8 +348,12 @@ def _handle_workflow_node_execution_failed( db.session.commit() db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + status = ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.EXCEPTION.value + ) + workflow_node_execution.status = status workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if 
process_data else None @@ -479,7 +484,10 @@ def _workflow_node_start_to_stream_response( def _workflow_node_finish_to_stream_response( self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, + event: QueueNodeSucceededEvent + | QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeExceptionEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 3736e632c3f1eb..45dc416b1ff433 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -82,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeRunExceptionEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + class NodeInIterationFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index bb24b511127395..1f36568d7d73ee 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -15,6 +15,7 @@ class Status(Enum): SUCCESS = "success" FAILED = "failed" PAUSED = "paused" + EXCEPTION = "exception" id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" @@ -51,7 +52,11 @@ def set_finished(self, run_result: NodeRunResult) -> None: :param run_result: run result """ - if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: + if self.status in { + RouteNodeState.Status.SUCCESS, + RouteNodeState.Status.FAILED, + RouteNodeState.Status.EXCEPTION, + }: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -59,6 +64,9 @@ def set_finished(self, run_result: NodeRunResult) -> None: elif run_result.status == WorkflowNodeExecutionStatus.FAILED: self.status = RouteNodeState.Status.FAILED self.failed_reason = run_result.error + elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + self.status = RouteNodeState.Status.EXCEPTION + self.failed_reason = run_result.error else: raise Exception(f"Invalid route status {run_result.status}") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index f07ad4de11bdfe..ee07b7a74dd05c 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -20,6 +20,7 @@ GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, + NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -599,7 +600,10 @@ def _run_node( parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + elif run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ): if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): # plus state total_tokens self.graph_runtime_state.total_tokens += int( @@ -632,18 +636,23 @@ def _run_node( run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( parent_parallel_start_node_id ) - - yield NodeRunSucceededEvent( - id=node_instance.id, - 
node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, + event_args = { + "id": node_instance.id, + "node_id": node_instance.node_id, + "node_type": node_instance.node_type, + "node_data": node_instance.node_data, + "route_node_state": route_node_state, + "parallel_id": parallel_id, + "parallel_start_node_id": parallel_start_node_id, + "parent_parallel_id": parent_parallel_id, + "parent_parallel_start_node_id": parent_parallel_start_node_id, + } + event = ( + NodeRunSucceededEvent(**event_args) + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + else NodeRunExceptionEvent(**event_args, error=run_result.error) ) + yield event break elif isinstance(item, RunStreamChunkEvent): diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 2a864dd7a84c8b..a78802e4554486 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,12 +1,16 @@ from abc import ABC -from typing import Optional +from typing import Any, Optional from pydantic import BaseModel +from core.workflow.nodes.enums import ErrorStrategy + class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + error_strategy: Optional[ErrorStrategy] = None + default_value: Optional[Any] = None class BaseIterationNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 1433c8eaed6d4d..59814d2b289ba1 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,8 +4,9 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent +from core.workflow.utils.condition.entities import ContinueOnErrorCondition from models.workflow import WorkflowNodeExecutionStatus from .entities import BaseNodeData @@ -76,6 +77,10 @@ def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: ) if isinstance(result, NodeRunResult): + if self.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH: + result.edge_source_handle = ContinueOnErrorCondition.SUCCESS + if result.status == WorkflowNodeExecutionStatus.FAILED and self._should_continue_on_error: + result = self.__handle_continue_on_error(result) yield RunCompletedEvent(run_result=result) else: yield from result @@ -135,3 +140,29 @@ def node_type(self) -> NodeType: :return: """ return self._node_type + + @property + def _should_continue_on_error(self) -> bool: + """judge if should continue on error + + Returns: + bool: if should continue on error + """ + return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE + + def __handle_continue_on_error(self, error: NodeRunResult) -> NodeRunResult: + if self.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + error=error.error, + inputs=error.inputs, + outputs=self.node_data.default_value, + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + 
error=error.error, + inputs=error.inputs, + outputs=None, + edge_source_handle=ContinueOnErrorCondition.EXCEPTION, + ) diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 208144655b5a59..af5df230b58433 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -22,3 +22,11 @@ class NodeType(str, Enum): CONVERSATION_VARIABLE_ASSIGNER = "assigner" DOCUMENT_EXTRACTOR = "document-extractor" LIST_OPERATOR = "list-operator" + + +class ErrorStrategy(str, Enum): + FAIL_BRANCH = "fail-branch" + DEFAULT_VALUE = "default-value" + + +CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index 799c735f5409ee..49545e29b139a8 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from enum import Enum from typing import Literal from pydantic import BaseModel, Field @@ -47,3 +48,8 @@ class Condition(BaseModel): comparison_operator: SupportedComparisonOperator value: str | Sequence[str] | None = None sub_variable_condition: SubVariableCondition | None = None + + +class ContinueOnErrorCondition(str, Enum): + SUCCESS = "success" + EXCEPTION = "exception" diff --git a/api/models/workflow.py b/api/models/workflow.py index 4f0e9a5e03705f..2d4f433c03cf67 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -520,6 +520,7 @@ class WorkflowNodeExecutionStatus(Enum): RUNNING = "running" SUCCEEDED = "succeeded" FAILED = "failed" + EXCEPTION = "exception" @classmethod def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 4de985ae7c9dea..7e8fde42c44b8b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -14,6 +14,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.nodes.event.event import RunCompletedEvent from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -353,3 +354,41 @@ def main(args1: int, args2: int) -> dict: # validate with pytest.raises(ValueError): node._transform_result(result, node.node_data.outputs) + + +def test_excute_code_continue_on_error(): + code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + code = "\n".join([line[4:] for line in code.split("\n")]) + + code_config = { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "error_strategy": "default-value", + "title": "123", + "variables": [], + "answer": "123", + "code_language": "python3", + "default_value": {"result": 132123}, + "code": code, + }, + } + + node = init_code_node(code_config) + + # execute node + result = node.run() + for r in result: + assert isinstance(r, RunCompletedEvent) + run_ruslt = r.run_result + assert run_ruslt.status == WorkflowNodeExecutionStatus.EXCEPTION + assert run_ruslt.outputs == {"result": 132123} diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py 
b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py new file mode 100644 index 00000000000000..25302fcca9d83f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -0,0 +1,284 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow import graph_engine +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import GraphRunSucceededEvent, NodeRunExceptionEvent +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.graph_engine import GraphEngine +from models.enums import UserFrom +from models.workflow import WorkflowType +from tests.unit_tests.core.workflow.graph_engine.test_graph_engine import VariablePool + + +def test_default_value_continue_on_error(): + # LLM, Tool, HTTP Request, Code in the Grpah error handle + code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + code = "\n".join([line[4:] for line in code.split("\n")]) + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + "sourceHandle": "source", + "targetHandle": "target", + }, + { + "id": "code-source-answer-target", + "source": "code", + "target": "answer", + "sourceHandle": "source", + "targetHandle": "target", + }, + ], + "nodes": [ + {"data": {"title": "开始", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "直接回复", "type": "answer", "answer": "{{#code.result#}}"}, "id": "answer"}, + { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "error_strategy": "default-value", + "title": "123", + "variables": [], + "code_language": "python3", + "default_value": {"result": 132123}, + "code": code, + "type": "code", + }, + }, + ], + } + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "清空对话", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "Novice"}, + ) + + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + rst = graph_engine.run() + arr = [] + for r in rst: + if isinstance(r, GraphRunSucceededEvent): + assert r.outputs == {"answer": "132123"} + arr.append(r) + assert isinstance(arr[4], NodeRunExceptionEvent) + + +def test_fail_branch_continue_on_error(): + code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + code = "\n".join([line[4:] for line in code.split("\n")]) + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + "sourceHandle": "source", + "targetHandle": "target", + }, + { + "id": "code-true-code_success-target", + "source": "code", + "target": "code_success", + "sourceHandle": "success", + "targetHandle": "target", + }, + { + "id": "code-false-code_error-target", + "source": "code", + "target": "code_error", + "sourceHandle": "exception", + "targetHandle": "target", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "id": "code", + "data": { + 
"outputs": { + "result": { + "type": "number", + }, + }, + "error_strategy": "fail-branch", + "title": "code", + "variables": [], + "code_language": "python3", + "code": code, + "type": "code", + }, + }, + { + "data": {"title": "code_success", "type": "answer", "answer": "code node run successfully"}, + "id": "code_success", + }, + { + "data": {"title": "code_error", "type": "answer", "answer": "code node run failed"}, + "id": "code_error", + }, + ], + } + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "清空对话", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "takato"}, + ) + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + rst = graph_engine.run() + arr = [] + for r in rst: + arr.append(r) + if isinstance(r, GraphRunSucceededEvent): + assert r.outputs == {"answer": "code node run failed"} + print(arr) + + +def test_success_branch_continue_on_error(): + code = """ + def main() -> dict: + return { + "result": 1 / 1, + } + """ + code = "\n".join([line[4:] for line in code.split("\n")]) + graph_config = { + "edges": [ + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + "sourceHandle": "source", + "targetHandle": "target", + }, + { + "id": "code-true-code_success-target", + "source": "code", + "target": "code_success", + "sourceHandle": "success", + "targetHandle": "target", + }, + { + "id": "code-false-code_error-target", + "source": "code", + "target": "code_error", + "sourceHandle": "exception", + "targetHandle": "target", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "id": "code", + "data": { + "outputs": { + "result": { + "type": "number", + }, + }, + "error_strategy": "fail-branch", + "title": "code", + "variables": [], + "code_language": "python3", + "code": code, + "type": "code", + }, + }, + { + "data": {"title": "code_success", "type": "answer", "answer": "code node run successfully"}, + "id": "code_success", + }, + { + "data": {"title": "code_error", "type": "answer", "answer": "code node run failed"}, + "id": "code_error", + }, + ], + } + graph = Graph.init(graph_config=graph_config) + + variable_pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "清空对话", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + user_inputs={"uid": "takato"}, + ) + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + rst = graph_engine.run() + arr = [] + for r in rst: + arr.append(r) + if isinstance(r, GraphRunSucceededEvent): + assert r.outputs == {"answer": "code node run successfully"} + print(arr) From 4405dff337eccd17e96523a4e52de774139c14e8 Mon Sep 17 00:00:00 2001 From: Nov1c444 <857526207@qq.com> Date: Fri, 15 Nov 2024 15:05:51 +0800 Subject: [PATCH 02/16] 
feat: add error type to variable pool --- api/core/workflow/entities/node_entities.py | 1 + api/core/workflow/nodes/base/node.py | 41 +++++++++++++------- api/core/workflow/nodes/code/code_node.py | 4 +- api/core/workflow/nodes/http_request/node.py | 1 + api/core/workflow/nodes/llm/node.py | 1 + api/core/workflow/nodes/tool/tool_node.py | 2 + 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index a7472666614fab..eac3939790784b 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -43,3 +43,4 @@ class NodeRunResult(BaseModel): edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed + error_type: Optional[str] = None # error type if status is failed diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 59814d2b289ba1..d7f9b42dfb638c 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.utils.condition.entities import ContinueOnErrorCondition @@ -71,16 +72,16 @@ def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: result = self._run() except Exception as e: logger.exception(f"Node {self.node_id} failed to run: {e}") - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError") if isinstance(result, NodeRunResult): - if self.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH: + if ( + result.status == WorkflowNodeExecutionStatus.SUCCEEDED + and self.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + ): result.edge_source_handle = ContinueOnErrorCondition.SUCCESS if result.status == WorkflowNodeExecutionStatus.FAILED and self._should_continue_on_error: - result = self.__handle_continue_on_error(result) + result = self.__handle_continue_on_error(result, self.graph_runtime_state.variable_pool) yield RunCompletedEvent(run_result=result) else: yield from result @@ -150,19 +151,33 @@ def _should_continue_on_error(self) -> bool: """ return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE - def __handle_continue_on_error(self, error: NodeRunResult) -> NodeRunResult: + def __handle_continue_on_error(self, error_result: NodeRunResult, variable_pool: VariablePool) -> NodeRunResult: + """ + handle continue on error when self._should_continue_on_error is True + + Args: + error_result (NodeRunResult): error run result + variable_pool (VariablePool): variable pool + Returns: + NodeRunResult: excption run result + """ + # add error message and error type to variable pool + variable_pool.add((self.node_id, "error_message"), error_result.error) + variable_pool.add((self.node_id, "error_type"), error_result.error_type) + + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": error_result.error, + "inputs": error_result.inputs, + } if 
self.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - error=error.error, - inputs=error.inputs, + **node_error_args, outputs=self.node_data.default_value, ) return NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - error=error.error, - inputs=error.inputs, + **node_error_args, outputs=None, edge_source_handle=ContinueOnErrorCondition.EXCEPTION, ) diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index ce283e38ec9b12..19b9078a5ce4a4 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -61,7 +61,9 @@ def _run(self) -> NodeRunResult: # Transform result result = self._transform_result(result, self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ + ) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 5b399bed63df97..f62f1f39884550 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -85,6 +85,7 @@ def _run(self) -> NodeRunResult: status=WorkflowNodeExecutionStatus.FAILED, error=str(e), process_data=process_data, + error_type=type(e).__name__, ) @staticmethod diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index eb4d1c9d87aa6a..87c25dca5cc036 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -178,6 +178,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] error=str(e), inputs=node_inputs, process_data=process_data, + error_type=type(e).__name__, ) ) return diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 6870b7467d11a4..635776cf0863ea 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -57,6 +57,7 @@ def _run(self) -> NodeRunResult: NodeRunMetadataKey.TOOL_INFO: tool_info, }, error=f"Failed to get tool runtime: {str(e)}", + error_type=type(e).__name__, ) # get parameters @@ -90,6 +91,7 @@ def _run(self) -> NodeRunResult: NodeRunMetadataKey.TOOL_INFO: tool_info, }, error=f"Failed to invoke tool: {str(e)}", + error_type=type(e).__name__, ) # convert tool messages From 09da57c3c73b92b7f7d2fd334038e4b3d31ac8c2 Mon Sep 17 00:00:00 2001 From: Nov1c444 <857526207@qq.com> Date: Tue, 19 Nov 2024 10:58:44 +0800 Subject: [PATCH 03/16] feat: catch the llm node error --- .../workflow/graph_engine/graph_engine.py | 102 +++- api/core/workflow/nodes/base/node.py | 44 +- .../workflow/nodes/test_code.py | 38 -- .../workflow/nodes/test_continue_on_error.py | 561 +++++++++++------- 4 files changed, 420 insertions(+), 325 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ee07b7a74dd05c..303667b6110c2e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -11,7 +11,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from 
core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -38,8 +38,10 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor +from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.utils.condition.entities import ContinueOnErrorCondition from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -588,22 +590,49 @@ def _run_node( route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: - yield NodeRunFailedEvent( - error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) - elif run_result.status in ( - WorkflowNodeExecutionStatus.SUCCEEDED, - WorkflowNodeExecutionStatus.EXCEPTION, - ): + if node_instance.should_continue_on_error: + # if run failed, handle error + run_result = self._handle_continue_on_error( + node_instance, item.run_result, self.graph_runtime_state.variable_pool + ) + route_node_state.node_run_result = run_result + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + yield NodeRunExceptionEvent( + error=run_result.error or "System Error", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + else: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or "Unknown error.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if node_instance.should_continue_on_error: + run_result.edge_source_handle = ContinueOnErrorCondition.SUCCESS if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): # plus state total_tokens self.graph_runtime_state.total_tokens += int( @@ -647,11 +676,7 @@ def _run_node( 
"parent_parallel_id": parent_parallel_id, "parent_parallel_start_node_id": parent_parallel_start_node_id, } - event = ( - NodeRunSucceededEvent(**event_args) - if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - else NodeRunExceptionEvent(**event_args, error=run_result.error) - ) + event = NodeRunSucceededEvent(**event_args) yield event break @@ -744,6 +769,39 @@ def create_copy(self): new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) return new_instance + def _handle_continue_on_error( + self, node_instance: BaseNode, error_result: NodeRunResult, variable_pool: VariablePool + ) -> NodeRunResult: + """ + handle continue on error when self._should_continue_on_error is True + + Args: + error_result (NodeRunResult): error run result + variable_pool (VariablePool): variable pool + Returns: + NodeRunResult: excption run result + """ + # add error message and error type to variable pool + variable_pool.add([node_instance.node_id, "error_message"], error_result.error) + variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": error_result.error, + "inputs": error_result.inputs, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + return NodeRunResult( + **node_error_args, + outputs=node_instance.node_data.default_value, + ) + + return NodeRunResult( + **node_error_args, + outputs=None, + edge_source_handle=ContinueOnErrorCondition.EXCEPTION, + ) + class GraphRunFailedError(Exception): def __init__(self, error: str): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index d7f9b42dfb638c..13bb23909d4380 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, ErrorStrategy, NodeType +from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from core.workflow.utils.condition.entities import ContinueOnErrorCondition from models.workflow import WorkflowNodeExecutionStatus from .entities import BaseNodeData @@ -75,13 +73,6 @@ def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError") if isinstance(result, NodeRunResult): - if ( - result.status == WorkflowNodeExecutionStatus.SUCCEEDED - and self.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - ): - result.edge_source_handle = ContinueOnErrorCondition.SUCCESS - if result.status == WorkflowNodeExecutionStatus.FAILED and self._should_continue_on_error: - result = self.__handle_continue_on_error(result, self.graph_runtime_state.variable_pool) yield RunCompletedEvent(run_result=result) else: yield from result @@ -143,41 +134,10 @@ def node_type(self) -> NodeType: return self._node_type @property - def _should_continue_on_error(self) -> bool: + def should_continue_on_error(self) -> bool: """judge if should continue on error Returns: bool: if should continue on error """ return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE - - def 
__handle_continue_on_error(self, error_result: NodeRunResult, variable_pool: VariablePool) -> NodeRunResult: - """ - handle continue on error when self._should_continue_on_error is True - - Args: - error_result (NodeRunResult): error run result - variable_pool (VariablePool): variable pool - Returns: - NodeRunResult: excption run result - """ - # add error message and error type to variable pool - variable_pool.add((self.node_id, "error_message"), error_result.error) - variable_pool.add((self.node_id, "error_type"), error_result.error_type) - - node_error_args = { - "status": WorkflowNodeExecutionStatus.EXCEPTION, - "error": error_result.error, - "inputs": error_result.inputs, - } - if self.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: - return NodeRunResult( - **node_error_args, - outputs=self.node_data.default_value, - ) - - return NodeRunResult( - **node_error_args, - outputs=None, - edge_source_handle=ContinueOnErrorCondition.EXCEPTION, - ) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 7e8fde42c44b8b..362ab127472696 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -354,41 +354,3 @@ def main(args1: int, args2: int) -> dict: # validate with pytest.raises(ValueError): node._transform_result(result, node.node_data.outputs) - - -def test_excute_code_continue_on_error(): - code = """ - def main() -> dict: - return { - "result": 1 / 0, - } - """ - code = "\n".join([line[4:] for line in code.split("\n")]) - - code_config = { - "id": "code", - "data": { - "outputs": { - "result": { - "type": "number", - }, - }, - "error_strategy": "default-value", - "title": "123", - "variables": [], - "answer": "123", - "code_language": "python3", - "default_value": {"result": 132123}, - "code": code, - }, - } - - node = init_code_node(code_config) - - # execute node - result = node.run() - for r in result: - assert isinstance(r, RunCompletedEvent) - run_ruslt = r.run_result - assert run_ruslt.status == WorkflowNodeExecutionStatus.EXCEPTION - assert run_ruslt.outputs == {"result": 132123} diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 25302fcca9d83f..189b1931121be0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -1,284 +1,399 @@ from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow import graph_engine from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import GraphRunSucceededEvent, NodeRunExceptionEvent from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.graph_engine import GraphEngine from models.enums import UserFrom from models.workflow import WorkflowType -from tests.unit_tests.core.workflow.graph_engine.test_graph_engine import VariablePool -def test_default_value_continue_on_error(): - # LLM, Tool, HTTP Request, Code in the Grpah error handle - code = """ +class ContinueOnErrorTestHelper: + @staticmethod + def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a code node configuration""" + node = { + "id": "node", + "data": { + "outputs": 
{"result": {"type": "number"}}, + "error_strategy": error_strategy, + "title": "code", + "variables": [], + "code_language": "python3", + "code": "\n".join([line[4:] for line in code.split("\n")]), + "type": "code", + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a http node configuration""" + node = { + "id": "node", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": { + "type": "api-key", + # missing config field + }, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + "type": "http-request", + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a tool node configuration""" + node = { + "id": "node", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], + } + }, + "type": "tool", + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a llm node configuration""" + node = { + "id": "node", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): + """Helper method to create a graph engine instance for testing""" + graph = Graph.init(graph_config=graph_config) + variable_pool = { + "system_variables": { + SystemVariableKey.QUERY: "clear", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + "user_inputs": user_inputs or {"uid": "takato"}, + } + + return GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + "targetHandle": "target", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + "targetHandle": "target", + }, +] + +FAIL_BRANCH_EDGES = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": 
"source", + "targetHandle": "target", + }, + { + "id": "node-true-success-target", + "source": "node", + "target": "success", + "sourceHandle": "success", + "targetHandle": "target", + }, + { + "id": "node-false-error-target", + "source": "node", + "target": "error", + "sourceHandle": "exception", + "targetHandle": "target", + }, +] + + +def test_code_default_value_continue_on_error(): + error_code = """ def main() -> dict: return { "result": 1 / 0, } """ - code = "\n".join([line[4:] for line in code.split("\n")]) + graph_config = { - "edges": [ - { - "id": "start-source-code-target", - "source": "start", - "target": "code", - "sourceHandle": "source", - "targetHandle": "target", - }, - { - "id": "code-source-answer-target", - "source": "code", - "target": "answer", - "sourceHandle": "source", - "targetHandle": "target", - }, - ], + "edges": DEFAULT_VALUE_EDGE, "nodes": [ - {"data": {"title": "开始", "type": "start", "variables": []}, "id": "start"}, - {"data": {"title": "直接回复", "type": "answer", "answer": "{{#code.result#}}"}, "id": "answer"}, - { - "id": "code", - "data": { - "outputs": { - "result": { - "type": "number", - }, - }, - "error_strategy": "default-value", - "title": "123", - "variables": [], - "code_language": "python3", - "default_value": {"result": 132123}, - "code": code, - "type": "code", - }, - }, + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_code_node(error_code, "default-value", {"result": 132123}), ], } - graph = Graph.init(graph_config=graph_config) - - variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "清空对话", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={"uid": "Novice"}, - ) - graph_engine = GraphEngine( - tenant_id="111", - app_id="222", - workflow_type=WorkflowType.CHAT, - workflow_id="333", - graph_config=graph_config, - user_id="444", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - graph=graph, - variable_pool=variable_pool, - max_execution_steps=500, - max_execution_time=1200, - ) - rst = graph_engine.run() - arr = [] - for r in rst: - if isinstance(r, GraphRunSucceededEvent): - assert r.outputs == {"answer": "132123"} - arr.append(r) - assert isinstance(arr[4], NodeRunExceptionEvent) + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) -def test_fail_branch_continue_on_error(): - code = """ +def test_code_fail_branch_continue_on_error(): + error_code = """ def main() -> dict: return { "result": 1 / 0, } """ - code = "\n".join([line[4:] for line in code.split("\n")]) + graph_config = { - "edges": [ - { - "id": "start-source-code-target", - "source": "start", - "target": "code", - "sourceHandle": "source", - "targetHandle": "target", - }, - { - "id": "code-true-code_success-target", - "source": "code", - "target": "code_success", - "sourceHandle": "success", - "targetHandle": "target", - }, - { - "id": "code-false-code_error-target", - "source": "code", - "target": "code_error", - "sourceHandle": "exception", - "targetHandle": "target", - }, - ], + "edges": FAIL_BRANCH_EDGES, "nodes": [ {"data": {"title": "Start", 
"type": "start", "variables": []}, "id": "start"}, { - "id": "code", - "data": { - "outputs": { - "result": { - "type": "number", - }, - }, - "error_strategy": "fail-branch", - "title": "code", - "variables": [], - "code_language": "python3", - "code": code, - "type": "code", - }, - }, - { - "data": {"title": "code_success", "type": "answer", "answer": "code node run successfully"}, - "id": "code_success", + "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, + "id": "success", }, { - "data": {"title": "code_error", "type": "answer", "answer": "code node run failed"}, - "id": "code_error", + "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, + "id": "error", }, + ContinueOnErrorTestHelper.get_code_node(error_code), ], } - graph = Graph.init(graph_config=graph_config) - - variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "清空对话", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={"uid": "takato"}, - ) - graph_engine = GraphEngine( - tenant_id="111", - app_id="222", - workflow_type=WorkflowType.CHAT, - workflow_id="333", - graph_config=graph_config, - user_id="444", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - graph=graph, - variable_pool=variable_pool, - max_execution_steps=500, - max_execution_time=1200, + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events ) - rst = graph_engine.run() - arr = [] - for r in rst: - arr.append(r) - if isinstance(r, GraphRunSucceededEvent): - assert r.outputs == {"answer": "code node run failed"} - print(arr) -def test_success_branch_continue_on_error(): - code = """ +def test_code_success_branch_continue_on_error(): + success_code = """ def main() -> dict: return { "result": 1 / 1, } """ - code = "\n".join([line[4:] for line in code.split("\n")]) + graph_config = { - "edges": [ + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, { - "id": "start-source-code-target", - "source": "start", - "target": "code", - "sourceHandle": "source", - "targetHandle": "target", + "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, + "id": "success", }, { - "id": "code-true-code_success-target", - "source": "code", - "target": "code_success", - "sourceHandle": "success", - "targetHandle": "target", + "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, + "id": "error", }, + ContinueOnErrorTestHelper.get_code_node(success_code), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any( + isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run successfully"} for e in events + ) + + +def test_http_node_default_value_continue_on_error(): + """Test HTTP node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, + 
ContinueOnErrorTestHelper.get_http_node("default-value", {"response": "http node got error response"}), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "http node got error response"} + for e in events + ) + + +def test_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, { - "id": "code-false-code_error-target", - "source": "code", - "target": "code_error", - "sourceHandle": "exception", - "targetHandle": "target", + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", }, + { + "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_http_node(), ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events) + + +def test_tool_node_default_value_continue_on_error(): + """Test tool node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_tool_node("default-value", {"result": "default tool result"}), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events) + + +def test_tool_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, "nodes": [ {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, { - "id": "code", - "data": { - "outputs": { - "result": { - "type": "number", - }, - }, - "error_strategy": "fail-branch", - "title": "code", - "variables": [], - "code_language": "python3", - "code": code, - "type": "code", - }, + "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, + "id": "success", }, { - "data": {"title": "code_success", "type": "answer", "answer": "code node run successfully"}, - "id": "code_success", + "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, + "id": "error", }, + ContinueOnErrorTestHelper.get_tool_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events) + + +def test_llm_node_default_value_continue_on_error(): + """Test LLM node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": 
{"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_llm_node("default-value", {"answer": "default LLM response"}), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events + ) + + +def test_llm_node_fail_branch_continue_on_error(): + """Test LLM node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, { - "data": {"title": "code_error", "type": "answer", "answer": "code node run failed"}, - "id": "code_error", + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", }, + { + "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_llm_node(), ], } - graph = Graph.init(graph_config=graph_config) - - variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "清空对话", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, - user_inputs={"uid": "takato"}, - ) - graph_engine = GraphEngine( - tenant_id="111", - app_id="222", - workflow_type=WorkflowType.CHAT, - workflow_id="333", - graph_config=graph_config, - user_id="444", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - graph=graph, - variable_pool=variable_pool, - max_execution_steps=500, - max_execution_time=1200, - ) - rst = graph_engine.run() - arr = [] - for r in rst: - arr.append(r) - if isinstance(r, GraphRunSucceededEvent): - assert r.outputs == {"answer": "code node run successfully"} - print(arr) + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events) From ae5890037eeb9572dac3eb2023dd5a20c28a720f Mon Sep 17 00:00:00 2001 From: Nov1c444 <857526207@qq.com> Date: Wed, 20 Nov 2024 16:45:52 +0800 Subject: [PATCH 04/16] fix: Excessive stream output --- .../workflow/graph_engine/entities/graph.py | 3 +- .../workflow/graph_engine/graph_engine.py | 13 +++------ .../answer/answer_stream_generate_router.py | 21 +++++++++----- .../nodes/answer/answer_stream_processor.py | 3 +- api/core/workflow/utils/condition/entities.py | 6 ---- .../workflow/nodes/test_continue_on_error.py | 29 ++++++++++++------- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index d87c039409d62e..f55f9ff2bd4819 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -92,7 +92,8 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non run_condition = None if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) - + if 
edge_config.get("errorHandle"): + run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("errorHandle")) graph_edge = GraphEdge( source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 303667b6110c2e..a5fc593d56bace 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from typing import Any, Optional +from typing import Any, Optional, cast from flask import Flask, current_app @@ -41,7 +41,6 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import node_type_classes_mapping -from core.workflow.utils.condition.entities import ContinueOnErrorCondition from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -337,7 +336,7 @@ def _run( if len(sub_edge_mappings) == 0: continue - edge = sub_edge_mappings[0] + edge = cast(GraphEdge, sub_edge_mappings[0]) result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -632,7 +631,7 @@ def _run_node( elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if node_instance.should_continue_on_error: - run_result.edge_source_handle = ContinueOnErrorCondition.SUCCESS + run_result.edge_source_handle = "false" if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): # plus state total_tokens self.graph_runtime_state.total_tokens += int( @@ -796,11 +795,7 @@ def _handle_continue_on_error( outputs=node_instance.node_data.default_value, ) - return NodeRunResult( - **node_error_args, - outputs=None, - edge_source_handle=ContinueOnErrorCondition.EXCEPTION, - ) + return NodeRunResult(**node_error_args, outputs=None, edge_source_handle="true") class GraphRunFailedError(Exception): diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 96e24a7db3725e..9c7703e86a3c2a 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -6,7 +6,7 @@ TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -148,13 +148,18 @@ def _recursive_fetch_answer_dependencies( for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in { - NodeType.ANSWER, - NodeType.IF_ELSE, - NodeType.QUESTION_CLASSIFIER, - NodeType.ITERATION, - NodeType.CONVERSATION_VARIABLE_ASSIGNER, - }: + source_node_data = node_id_config_mapping[source_node_id].get("data", {}) + if ( + source_node_type + in { + NodeType.ANSWER, + NodeType.IF_ELSE, + NodeType.QUESTION_CLASSIFIER, + NodeType.ITERATION, + NodeType.CONVERSATION_VARIABLE_ASSIGNER, + } + or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH + ): 
answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 8a768088da660e..d94f0590584842 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -6,6 +6,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, + NodeRunExceptionEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -50,7 +51,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat for _ in stream_out_answer_node_ids: yield event - elif isinstance(event, NodeRunSucceededEvent): + elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent): yield event if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # update self.route_position after all stream event finished diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index 49545e29b139a8..799c735f5409ee 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from enum import Enum from typing import Literal from pydantic import BaseModel, Field @@ -48,8 +47,3 @@ class Condition(BaseModel): comparison_operator: SupportedComparisonOperator value: str | Sequence[str] | None = None sub_variable_condition: SubVariableCondition | None = None - - -class ContinueOnErrorCondition(str, Enum): - SUCCESS = "success" - EXCEPTION = "exception" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 189b1931121be0..99e6e257b8f1ee 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -1,6 +1,10 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.enums import SystemVariableKey -from core.workflow.graph_engine.entities.event import GraphRunSucceededEvent, NodeRunExceptionEvent +from core.workflow.graph_engine.entities.event import ( + GraphRunSucceededEvent, + NodeRunExceptionEvent, + NodeRunStreamChunkEvent, +) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine from models.enums import UserFrom @@ -140,14 +144,12 @@ def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None "source": "start", "target": "node", "sourceHandle": "source", - "targetHandle": "target", }, { "id": "node-source-answer-target", "source": "node", "target": "answer", "sourceHandle": "source", - "targetHandle": "target", }, ] @@ -157,21 +159,20 @@ def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None "source": "start", "target": "node", "sourceHandle": "source", - "targetHandle": "target", }, { "id": "node-true-success-target", "source": "node", "target": "success", - "sourceHandle": "success", - "targetHandle": "target", + "sourceHandle": "source", + "errorHandle": "false", }, { "id": "node-false-error-target", "source": "node", "target": "error", - "sourceHandle": "exception", - "targetHandle": "target", + "sourceHandle": "source", + "errorHandle": "true", }, ] @@ -198,6 
+199,7 @@ def main() -> dict: assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_code_fail_branch_continue_on_error(): @@ -226,7 +228,7 @@ def main() -> dict: graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) events = list(graph_engine.run()) - + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any( isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events @@ -245,6 +247,7 @@ def main() -> dict: "edges": FAIL_BRANCH_EDGES, "nodes": [ {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + ContinueOnErrorTestHelper.get_code_node(success_code), { "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, "id": "success", @@ -253,7 +256,6 @@ def main() -> dict: "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, "id": "error", }, - ContinueOnErrorTestHelper.get_code_node(success_code), ], } @@ -263,6 +265,7 @@ def main() -> dict: assert any( isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run successfully"} for e in events ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_http_node_default_value_continue_on_error(): @@ -284,6 +287,7 @@ def test_http_node_default_value_continue_on_error(): isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "http node got error response"} for e in events ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_http_node_fail_branch_continue_on_error(): @@ -309,6 +313,7 @@ def test_http_node_fail_branch_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_tool_node_default_value_continue_on_error(): @@ -327,6 +332,7 @@ def test_tool_node_default_value_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_tool_node_fail_branch_continue_on_error(): @@ -352,6 +358,7 @@ def test_tool_node_fail_branch_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_llm_node_default_value_continue_on_error(): @@ -372,6 +379,7 @@ def test_llm_node_default_value_continue_on_error(): assert any( isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 def test_llm_node_fail_branch_continue_on_error(): @@ -397,3 +405,4 @@ def test_llm_node_fail_branch_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunSucceededEvent) 
and e.outputs == {"answer": "LLM request failed"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 From 32d807dcce8212795315250df054127f70f08605 Mon Sep 17 00:00:00 2001 From: Nov1c444 <857526207@qq.com> Date: Thu, 21 Nov 2024 14:35:57 +0800 Subject: [PATCH 05/16] feat: error code status continue on error --- api/core/helper/ssrf_proxy.py | 8 +- .../workflow/nodes/http_request/executor.py | 7 +- api/core/workflow/nodes/http_request/node.py | 15 ++ .../workflow/nodes/test_continue_on_error.py | 133 ++++++++++++++++++ 4 files changed, 160 insertions(+), 3 deletions(-) diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 80f01fa12b3544..2ab087113d65ec 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,6 +24,12 @@ STATUS_FORCELIST = [429, 500, 502, 503, 504] +class MaxRetriesExceededError(Exception): + """Raised when the maximum number of retries is exceeded.""" + + pass + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -66,7 +72,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) - raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") + raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 80b322b068ec50..dfa071672f28a0 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -21,6 +21,7 @@ from .exc import ( AuthorizationConfigError, FileFetchError, + HttpRequestNodeError, InvalidHttpMethodError, ResponseSizeError, ) @@ -208,8 +209,10 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: "follow_redirects": True, } # request_args = {k: v for k, v in request_args.items() if v is not None} - - response = getattr(ssrf_proxy, self.method)(**request_args) + try: + response = getattr(ssrf_proxy, self.method)(**request_args) + except ssrf_proxy.MaxRetriesExceededError as e: + raise HttpRequestNodeError(str(e)) return response def invoke(self) -> Response: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index f62f1f39884550..aab16c51f2aea2 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -67,6 +67,21 @@ def _run(self) -> NodeRunResult: response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) + if not response.response.is_success and self.should_continue_on_error: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + error=f"Request failed with status code {response.status_code}", + error_type="HTTPResponseCodeError", + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 
99e6e257b8f1ee..2344413d28db8b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -56,6 +56,30 @@ def get_http_node(error_strategy: str = "fail-branch", default_value: dict | Non node["data"]["default_value"] = default_value return node + @staticmethod + def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a http node configuration""" + node = { + "id": "node", + "data": { + "type": "http-request", + "title": "HTTP Request", + "desc": "", + "variables": [], + "method": "get", + "url": "https://api.github.com/issues", + "authorization": {"type": "no-auth", "config": None}, + "headers": "", + "params": "", + "body": {"type": "none", "data": []}, + "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + @staticmethod def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): """Helper method to create a tool node configuration""" @@ -406,3 +430,112 @@ def test_llm_node_fail_branch_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_status_code_error_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_variable_pool_error_type_variable(): + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + list(graph_engine.run()) + error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) + error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) + assert error_message.value == "Request failed with status code 404" + assert error_type.value == "HTTPResponseCodeError" + + +def test_continue_on_error_link_fail_branch(): + success_code = """ + def main() -> dict: + return { + "result": 1 / 1, + } + """ + 
graph_config = { + "edges": [ + *FAIL_BRANCH_EDGES, + { + "id": "start-source-code-target", + "source": "start", + "target": "code", + "sourceHandle": "source", + }, + { + "id": "code-source-error-target", + "source": "code", + "target": "error", + "sourceHandle": "source", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_code_node(code=success_code), + { + "id": "code", + "data": { + "outputs": {"result": {"type": "number"}}, + "title": "code", + "variables": [], + "code_language": "python3", + "code": "\n".join([line[4:] for line in success_code.split("\n")]), + "type": "code", + }, + }, + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert any( + isinstance(e, GraphRunSucceededEvent) + and e.outputs == {"answer": "http execute successful\nhttp execute failed"} + for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 2 From 6c125fb4366b05bb7ec4e45407fe9edae825ef4a Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Wed, 27 Nov 2024 15:14:44 +0800 Subject: [PATCH 06/16] feat: Adjust the strategy for fail-branch --- .../workflow/graph_engine/entities/graph.py | 18 ++++++++++++++---- api/core/workflow/graph_engine/graph_engine.py | 6 +++--- api/core/workflow/nodes/enums.py | 5 +++++ .../workflow/nodes/test_continue_on_error.py | 5 +---- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index f55f9ff2bd4819..d27dd6b6360d3f 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -71,6 +71,9 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non edge_mapping: dict[str, list[GraphEdge]] = {} reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() + fail_branch_source_node_id = [ + edge["source"] for edge in edge_configs if edge.get("sourceHandle") == "fail-branch" + ] for edge_config in edge_configs: source_node_id = edge_config.get("source") if not source_node_id: @@ -90,10 +93,17 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non # parse run condition run_condition = None - if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": - run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) - if edge_config.get("errorHandle"): - run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("errorHandle")) + if edge_config.get("sourceHandle"): + if ( + edge_config.get("source") in fail_branch_source_node_id + and edge_config.get("sourceHandle") != "fail-branch" + ): + run_condition = RunCondition(type="branch_identify", branch_identify="success-branch") + elif edge_config.get("sourceHandle") != "source": + run_condition = RunCondition( + type="branch_identify", branch_identify=edge_config.get("sourceHandle") + ) + graph_edge = GraphEdge( source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition ) diff --git a/api/core/workflow/graph_engine/graph_engine.py 
b/api/core/workflow/graph_engine/graph_engine.py index e6b8002a149735..6279be19df38c5 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor -from core.workflow.nodes.enums import ErrorStrategy +from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import node_type_classes_mapping from extensions.ext_database import db @@ -631,7 +631,7 @@ def _run_node( elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if node_instance.should_continue_on_error: - run_result.edge_source_handle = "false" + run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): # plus state total_tokens self.graph_runtime_state.total_tokens += int( @@ -795,7 +795,7 @@ def _handle_continue_on_error( outputs=node_instance.node_data.default_value, ) - return NodeRunResult(**node_error_args, outputs=None, edge_source_handle="true") + return NodeRunResult(**node_error_args, outputs=None, edge_source_handle=FailBranchSourceHandle.FAILED) class GraphRunFailedError(Exception): diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index af5df230b58433..4487f4d8f7c7be 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -29,4 +29,9 @@ class ErrorStrategy(str, Enum): DEFAULT_VALUE = "default-value" +class FailBranchSourceHandle(str, Enum): + FAILED = "fail-branch" + SUCCESS = "success-branch" + + CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 2344413d28db8b..bf7dda01059347 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -189,14 +189,12 @@ def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None "source": "node", "target": "success", "sourceHandle": "source", - "errorHandle": "false", }, { "id": "node-false-error-target", "source": "node", "target": "error", - "sourceHandle": "source", - "errorHandle": "true", + "sourceHandle": "fail-branch", }, ] @@ -285,7 +283,6 @@ def main() -> dict: graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) events = list(graph_engine.run()) - assert any( isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run successfully"} for e in events ) From d22ab298c66450118d587b4c699bacf9c6582db1 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Wed, 4 Dec 2024 11:02:53 +0800 Subject: [PATCH 07/16] feat: add workflow exception logs --- .../advanced_chat/generate_task_pipeline.py | 24 ++++ .../app/apps/workflow/app_queue_manager.py | 4 +- .../apps/workflow/generate_task_pipeline.py | 42 +++++-- api/core/app/apps/workflow_app_runner.py | 8 +- api/core/app/entities/queue_entities.py | 12 ++ api/core/app/entities/task_entities.py | 1 + .../task_pipeline/workflow_cycle_manage.py | 59 +++++++++- 
.../callbacks/workflow_logging_callback.py | 3 + .../workflow/graph_engine/entities/event.py | 6 + .../workflow/graph_engine/graph_engine.py | 53 ++++++--- api/core/workflow/nodes/base/entities.py | 108 +++++++++++++++++- api/core/workflow/nodes/base/exc.py | 10 ++ api/core/workflow/nodes/enums.py | 4 +- api/fields/workflow_run_fields.py | 4 + ...4fc45278_add_exceptions_count_field_to_.py | 33 ++++++ api/models/workflow.py | 6 +- .../workflow/nodes/test_continue_on_error.py | 46 +++++--- 17 files changed, 375 insertions(+), 48 deletions(-) create mode 100644 api/core/workflow/nodes/base/exc.py create mode 100644 api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 456dd974ecd5ae..8b9d0a76939ab2 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -32,6 +32,7 @@ QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -383,6 +384,29 @@ def _process_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_partial_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 76371f800ba1e5..349b8eb51b1546 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -6,6 +6,7 @@ QueueMessageEndEvent, QueueStopEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, ) @@ -34,7 +35,8 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | QueueErrorEvent | QueueMessageEndEvent | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent, + | QueueWorkflowFailedEvent + | QueueWorkflowPartialSuccessEvent, ): self.stop_listen() diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9e4921d6a22c5a..2e75ba4ac7fd08 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -15,6 +15,7 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -26,6 +27,7 @@ QueueStopEvent, QueueTextChunkEvent, 
QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -276,7 +278,7 @@ def _process_stream_response( if response: yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( @@ -345,22 +347,20 @@ def _process_stream_response( yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: raise Exception("Workflow run not initialized.") if not graph_runtime_state: raise Exception("Graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( + workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + outputs=event.outputs, + exceptions_count=event.exceptions_count, conversation_id=None, trace_manager=trace_manager, ) @@ -368,6 +368,34 @@ def _process_stream_response( # save workflow app log self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + handle_args = { + "workflow_run": workflow_run, + "start_at": graph_runtime_state.start_at, + "total_tokens": graph_runtime_state.total_tokens, + "total_steps": graph_runtime_state.node_run_steps, + "status": WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + "conversation_id": None, + "trace_manager": trace_manager, + } + if isinstance(event, QueueWorkflowFailedEvent): + handle_args["exceptions_count"] = event.exceptions_count + workflow_run = self._handle_workflow_run_failed(**handle_args) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index ae0d37ac9631bc..c696205fa71963 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -19,6 +19,7 @@ QueueRetrieverResourcesEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -26,6 +27,7 @@ from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, 
IterationRunFailedEvent, @@ -177,8 +179,12 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) ) elif isinstance(event, GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) + elif isinstance(event, GraphRunPartialSucceededEvent): + self._publish_event( + QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count) + ) elif isinstance(event, GraphRunFailedEvent): - self._publish_event(QueueWorkflowFailedEvent(error=event.error)) + self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 2d5abcada7794c..6d12119a0fba24 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -25,6 +25,7 @@ class QueueEvent(StrEnum): WORKFLOW_STARTED = "workflow_started" WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" + WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded" ITERATION_START = "iteration_start" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" @@ -250,6 +251,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str + exceptions_count: int + + +class QueueWorkflowPartialSuccessEvent(AppQueueEvent): + """ + QueueWorkflowPartialSuccessEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED + exceptions_count: int + outputs: Optional[dict[str, Any]] = None class QueueNodeStartedEvent(AppQueueEvent): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 03cc6941a84623..7fe06b3af8bbb4 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -213,6 +213,7 @@ class Data(BaseModel): created_by: Optional[dict] = None created_at: int finished_at: int + exceptions_count: Optional[int] = 0 files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 80428c875d34aa..1e7ada20538040 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -165,6 +165,55 @@ def _handle_workflow_run_success( return workflow_run + def _handle_workflow_run_partial_success( + self, + workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Mapping[str, Any] | None = None, + exceptions_count: int = 0, + conversation_id: Optional[str] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> WorkflowRun: + """ + Workflow run partial success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :param conversation_id: conversation id + :return: + """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + + outputs = WorkflowEntry.handle_special_values(outputs) + + workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value + workflow_run.outputs = json.dumps(outputs or {}) + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at 
= datetime.now(UTC).replace(tzinfo=None) + workflow_run.exceptions_count = exceptions_count + db.session.commit() + db.session.refresh(workflow_run) + + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_run=workflow_run, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + ) + ) + + db.session.close() + + return workflow_run + def _handle_workflow_run_failed( self, workflow_run: WorkflowRun, @@ -338,7 +387,11 @@ def _handle_workflow_node_execution_failed( ) db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, + WorkflowNodeExecution.status: ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.EXCEPTION.value + ), WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, @@ -352,12 +405,11 @@ def _handle_workflow_node_execution_failed( db.session.commit() db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) - status = ( + workflow_node_execution.status = ( WorkflowNodeExecutionStatus.FAILED.value if not isinstance(event, QueueNodeExceptionEvent) else WorkflowNodeExecutionStatus.EXCEPTION.value ) - workflow_node_execution.status = status workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None @@ -438,6 +490,7 @@ def _workflow_finish_to_stream_response( created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), + exceptions_count=workflow_run.exceptions_count, ), ) diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index 17913de7b0d2ce..ed737e7316973c 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, IterationRunFailedEvent, @@ -39,6 +40,8 @@ def on_event(self, event: GraphEngineEvent) -> None: self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): self.print_text("\n[GraphRunSucceededEvent]", color="green") + elif isinstance(event, GraphRunPartialSucceededEvent): + self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink") elif isinstance(event, GraphRunFailedEvent): self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") elif isinstance(event, NodeRunStartedEvent): diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 45dc416b1ff433..c10280776e162a 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -32,6 +32,12 @@ class GraphRunSucceededEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent): error: str = Field(..., description="failed reason") + exceptions_count: int = Field(..., 
description="exception count") + + +class GraphRunPartialSucceededEvent(BaseGraphEvent): + exceptions_count: int = Field(..., description="exception count") + outputs: Optional[dict[str, Any]] = None ########################################### diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 7e1caab1f485f0..e7cb52acfe4a04 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -18,6 +18,7 @@ BaseIterationEvent, GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunExceptionEvent, @@ -37,6 +38,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent @@ -142,8 +144,10 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ) # run graph - generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) - + handle_exceptions = [] + generator = stream_processor.process( + self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) + ) for item in generator: try: yield item @@ -176,17 +180,23 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: logger.exception("Graph run failed") yield GraphRunFailedEvent(error=str(e)) return - - # trigger graph run success event - yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + # count exceptions to determine partial success + exceptions_count = len(handle_exceptions) + if exceptions_count > 0: + yield GraphRunPartialSucceededEvent( + exceptions_count=exceptions_count, outputs=self.graph_runtime_state.outputs + ) + else: + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) self._release_thread() except GraphRunFailedError as e: - yield GraphRunFailedEvent(error=e.error) + yield GraphRunFailedEvent(error=e.error, exceptions_count=exceptions_count) self._release_thread() return except Exception as e: logger.exception("Unknown Error when graph running") - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent(error=str(e), handle_exceptions=handle_exceptions) self._release_thread() raise e @@ -200,6 +210,7 @@ def _run( in_parallel_id: Optional[str] = None, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None if in_parallel_id: @@ -253,6 +264,7 @@ def _run( parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in generator: @@ -549,6 +561,7 @@ def _run_node( parallel_start_node_id: Optional[str] = None, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: """ Run node @@ -591,7 +604,10 @@ def _run_node( if node_instance.should_continue_on_error: # if run failed, handle error 
run_result = self._handle_continue_on_error( - node_instance, item.run_result, self.graph_runtime_state.variable_pool + node_instance, + item.run_result, + self.graph_runtime_state.variable_pool, + handle_exceptions=handle_exceptions, ) route_node_state.node_run_result = run_result if run_result.outputs: @@ -768,30 +784,35 @@ def create_copy(self): return new_instance def _handle_continue_on_error( - self, node_instance: BaseNode, error_result: NodeRunResult, variable_pool: VariablePool + self, + node_instance: BaseNode[BaseNodeData], + error_result: NodeRunResult, + variable_pool: VariablePool, + handle_exceptions: list[str] = [], ) -> NodeRunResult: """ handle continue on error when self._should_continue_on_error is True - Args: - error_result (NodeRunResult): error run result - variable_pool (VariablePool): variable pool - Returns: - NodeRunResult: excption run result + + :param error_result (NodeRunResult): error run result + :param variable_pool (VariablePool): variable pool + :return: excption run result """ # add error message and error type to variable pool variable_pool.add([node_instance.node_id, "error_message"], error_result.error) variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) - + # add error message to handle_exceptions + handle_exceptions.append(error_result.error) node_error_args = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": error_result.error, "inputs": error_result.inputs, } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, - outputs=node_instance.node_data.default_value, + outputs=node_instance.node_data.default_value_dict, ) return NodeRunResult(**node_error_args, outputs=None, edge_source_handle=FailBranchSourceHandle.FAILED) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index a78802e4554486..a2a61a9306a2c6 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,16 +1,118 @@ +import json from abc import ABC -from typing import Any, Optional +from enum import StrEnum +from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, model_validator +from core.workflow.nodes.base.exc import DefaultValueTypeError from core.workflow.nodes.enums import ErrorStrategy +class DefaultValueType(StrEnum): + STRING = "String" + NUMBER = "Number" + OBJECT = "Object" + ARRAY_NUMBER = "Array[Number]" + ARRAY_STRING = "Array[String]" + ARRAY_OBJECT = "Array[Object]" + ARRAY_FILES = "Array[Files]" + + +NumberType = Union[int, float] +ObjectType = dict[str, Any] + + +class DefaultValue(BaseModel): + value: Union[ + str, + NumberType, + ObjectType, + list[NumberType], + list[str], + list[ObjectType], + ] + type: DefaultValueType + key: str + + @model_validator(mode="after") + def validate_value_type(self) -> Any: + value_type = self.type + value = self.value + if value_type is None: + raise DefaultValueTypeError("type field is required") + + # validate string type + if value_type == DefaultValueType.STRING: + if not isinstance(value, str): + raise DefaultValueTypeError(f"Value must be string type for {value}") + + # validate number type + elif value_type == DefaultValueType.NUMBER: + if not isinstance(value, NumberType): + raise DefaultValueTypeError(f"Value must be number type for {value}") + + # validate object type + elif value_type == DefaultValueType.OBJECT: + if isinstance(value, str): + try: + value = json.loads(value) + 
except json.JSONDecodeError: + raise DefaultValueTypeError(f"Value must be object type for {value}") + if not isinstance(value, ObjectType): + raise DefaultValueTypeError(f"Value must be object type for {value}") + + # validate array[number] type + elif value_type == DefaultValueType.ARRAY_NUMBER: + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Value must be object type for {value}") + if not isinstance(value, list): + raise DefaultValueTypeError(f"Value must be array type for {value}") + if not all(isinstance(x, NumberType) for x in value): + raise DefaultValueTypeError(f"All elements must be numbers for {value}") + + # validate array[string] type + elif value_type == DefaultValueType.ARRAY_STRING: + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Value must be object type for {value}") + if not isinstance(value, list): + raise DefaultValueTypeError(f"Value must be array type for {value}") + if not all(isinstance(x, str) for x in value): + raise DefaultValueTypeError(f"All elements must be strings for {value}") + + # validate array[object] type + elif value_type == DefaultValueType.ARRAY_OBJECT: + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Value must be object type for {value}") + if not isinstance(value, list): + raise DefaultValueTypeError(f"Value must be array type for {value}") + if not all(isinstance(x, ObjectType) for x in value): + raise DefaultValueTypeError(f"All elements must be objects for {value}") + elif value_type == DefaultValueType.ARRAY_FILES: + # handle files type + pass + + class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None error_strategy: Optional[ErrorStrategy] = None - default_value: Optional[Any] = None + default_value: Optional[list[DefaultValue]] = None + + @property + def default_value_dict(self): + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} class BaseIterationNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py new file mode 100644 index 00000000000000..aeecf406403e6d --- /dev/null +++ b/api/core/workflow/nodes/base/exc.py @@ -0,0 +1,10 @@ +class BaseNodeError(ValueError): + """Base class for node errors.""" + + pass + + +class DefaultValueTypeError(BaseNodeError): + """Raised when the default value type is invalid.""" + + pass diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 323ff7f7fb0eae..71a2e70ec10190 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -24,12 +24,12 @@ class NodeType(StrEnum): LIST_OPERATOR = "list-operator" -class ErrorStrategy(str, Enum): +class ErrorStrategy(StrEnum): FAIL_BRANCH = "fail-branch" DEFAULT_VALUE = "default-value" -class FailBranchSourceHandle(str, Enum): +class FailBranchSourceHandle(StrEnum): FAILED = "fail-branch" SUCCESS = "success-branch" diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 1413adf7196879..8390c665561841 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -14,6 +14,7 @@ "total_steps": fields.Integer, "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } workflow_run_for_list_fields = { @@ -27,6 +28,7 @@ "created_by_account": 
fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } advanced_chat_workflow_run_for_list_fields = { @@ -42,6 +44,7 @@ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } advanced_chat_workflow_run_pagination_fields = { @@ -73,6 +76,7 @@ "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } workflow_run_node_execution_fields = { diff --git a/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py new file mode 100644 index 00000000000000..8c576339bae8cf --- /dev/null +++ b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py @@ -0,0 +1,33 @@ +"""add exceptions_count field to WorkflowRun model + +Revision ID: cf8f4fc45278 +Revises: 01d6889832f7 +Create Date: 2024-11-28 05:53:21.576178 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'cf8f4fc45278' +down_revision = '01d6889832f7' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_column('exceptions_count') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index 742b4a748a966d..a3f670c7f41231 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -323,6 +323,7 @@ class WorkflowRunStatus(StrEnum): SUCCEEDED = "succeeded" FAILED = "failed" STOPPED = "stopped" + PARTIAL_SUCCESSED = "partial-succeeded" @classmethod def value_of(cls, value: str) -> "WorkflowRunStatus": @@ -393,7 +394,7 @@ class WorkflowRun(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped + status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[str] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) @@ -403,6 +404,7 @@ class WorkflowRun(db.Model): created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) + exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @property def created_by_account(self): @@ -462,6 +464,7 @@ def to_dict(self): "created_by": self.created_by, "created_at": self.created_at, "finished_at": self.finished_at, + "exceptions_count": self.exceptions_count, } @classmethod @@ -487,6 +490,7 @@ def from_dict(cls, data: dict) -> "WorkflowRun": created_by=data.get("created_by"), created_at=data.get("created_at"), finished_at=data.get("finished_at"), + exceptions_count=data.get("exceptions_count"), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index bf7dda01059347..06a4b114005ed3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -1,6 +1,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( + GraphRunPartialSucceededEvent, GraphRunSucceededEvent, NodeRunExceptionEvent, NodeRunStreamChunkEvent, @@ -212,15 +213,16 @@ def main() -> dict: "nodes": [ {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_code_node(error_code, "default-value", {"result": 132123}), + ContinueOnErrorTestHelper.get_code_node( + error_code, "default-value", [{"key": "result", "type": "Number", "value": 132123}] + ), ], } graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) events = list(graph_engine.run()) - assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -253,7 +255,7 @@ def main() -> dict: assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 assert any(isinstance(e, 
NodeRunExceptionEvent) for e in events) assert any( - isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events ) @@ -296,7 +298,9 @@ def test_http_node_default_value_continue_on_error(): "nodes": [ {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_http_node("default-value", {"response": "http node got error response"}), + ContinueOnErrorTestHelper.get_http_node( + "default-value", [{"key": "response", "type": "String", "value": "http node got error response"}] + ), ], } @@ -305,7 +309,7 @@ def test_http_node_default_value_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any( - isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "http node got error response"} + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} for e in events ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -333,7 +337,9 @@ def test_http_node_fail_branch_continue_on_error(): events = list(graph_engine.run()) assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events + ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -344,7 +350,9 @@ def test_tool_node_default_value_continue_on_error(): "nodes": [ {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, - ContinueOnErrorTestHelper.get_tool_node("default-value", {"result": "default tool result"}), + ContinueOnErrorTestHelper.get_tool_node( + "default-value", [{"key": "result", "type": "String", "value": "default tool result"}] + ), ], } @@ -352,7 +360,9 @@ def test_tool_node_default_value_continue_on_error(): events = list(graph_engine.run()) assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events + ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -378,7 +388,9 @@ def test_tool_node_fail_branch_continue_on_error(): events = list(graph_engine.run()) assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events + ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -389,7 +401,9 @@ def test_llm_node_default_value_continue_on_error(): "nodes": [ {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, - 
ContinueOnErrorTestHelper.get_llm_node("default-value", {"answer": "default LLM response"}), + ContinueOnErrorTestHelper.get_llm_node( + "default-value", [{"key": "answer", "type": "String", "value": "default LLM response"}] + ), ], } @@ -398,7 +412,7 @@ def test_llm_node_default_value_continue_on_error(): assert any(isinstance(e, NodeRunExceptionEvent) for e in events) assert any( - isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -425,7 +439,9 @@ def test_llm_node_fail_branch_continue_on_error(): events = list(graph_engine.run()) assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events + ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 @@ -451,7 +467,9 @@ def test_status_code_error_http_node_fail_branch_continue_on_error(): events = list(graph_engine.run()) assert any(isinstance(e, NodeRunExceptionEvent) for e in events) - assert any(isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events + ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 From 560b7fec50328408700c65dadef07b9610444120 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Wed, 4 Dec 2024 19:10:45 +0800 Subject: [PATCH 08/16] feat: add error type to defualt value --- api/core/workflow/entities/node_entities.py | 1 + .../workflow/graph_engine/entities/event.py | 2 +- .../workflow/graph_engine/graph_engine.py | 27 +++++++++++-- api/core/workflow/nodes/base/entities.py | 4 +- api/services/workflow_service.py | 38 +++++++++++++++++-- .../workflow/nodes/test_continue_on_error.py | 24 +++++++++++- 6 files changed, 85 insertions(+), 11 deletions(-) diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index ad44798a4d37c1..976a5ef74e320d 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum): PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs + ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field class NodeRunResult(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index c10280776e162a..ed816b81ff9578 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -32,7 +32,7 @@ class GraphRunSucceededEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent): error: str = Field(..., description="failed reason") - exceptions_count: int = Field(..., description="exception count") + exceptions_count: Optional[int] = Field(description="exception count", default=0) class GraphRunPartialSucceededEvent(BaseGraphEvent): diff --git 
a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index e1a45714518137..ea5d53d550c754 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -196,7 +196,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: return except Exception as e: logger.exception("Unknown Error when graph running") - yield GraphRunFailedEvent(error=str(e), handle_exceptions=handle_exceptions) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) self._release_thread() raise e @@ -314,6 +314,11 @@ def _run( break if len(edge_mappings) == 1: + if ( + previous_route_node_state.status == RouteNodeState.Status.EXCEPTION + and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + ): + break edge = edge_mappings[0] if edge.run_condition: @@ -332,7 +337,6 @@ def _run( next_node_id = edge.target_node_id else: final_node_id = None - if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results condition_edge_mappings = {} @@ -611,6 +615,7 @@ def _run_node( handle_exceptions=handle_exceptions, ) route_node_state.node_run_result = run_result + route_node_state.status = RouteNodeState.Status.EXCEPTION if run_result.outputs: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively @@ -808,15 +813,29 @@ def _handle_continue_on_error( "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": error_result.error, "inputs": error_result.inputs, + "metadata": { + NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + }, } if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, - outputs=node_instance.node_data.default_value_dict, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": error_result.error, + "error_type": error_result.error_type, + }, ) - return NodeRunResult(**node_error_args, outputs=None, edge_source_handle=FailBranchSourceHandle.FAILED) + return NodeRunResult( + **node_error_args, + outputs={ + "error_message": error_result.error, + "error_type": error_result.error_type, + }, + edge_source_handle=FailBranchSourceHandle.FAILED, + ) class GraphRunFailedError(Exception): diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 6945a76ed35198..35e683877ef1ce 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -16,7 +16,7 @@ class DefaultValueType(StrEnum): ARRAY_NUMBER = "Array[Number]" ARRAY_STRING = "Array[String]" ARRAY_OBJECT = "Array[Object]" - ARRAY_FILES = "Array[Files]" + ARRAY_FILES = "Array[File]" NumberType = Union[int, float] @@ -101,6 +101,8 @@ def validate_value_type(self) -> Any: # handle files type pass + return self + class BaseNodeData(ABC, BaseModel): title: str diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 37d7d0937cd492..ab33095ecea3c4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,6 +11,7 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes import NodeType +from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.node_mapping import 
LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry @@ -237,8 +238,34 @@ def run_draft_workflow_node( if not node_run_result: raise ValueError("Node run failed with no run result") - - run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: node_instance = e.node_instance @@ -260,7 +287,6 @@ def run_draft_workflow_node( workflow_node_execution.created_by = account.id workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - if run_succeeded and node_run_result: # create workflow node execution inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None @@ -277,7 +303,11 @@ def run_draft_workflow_node( workflow_node_execution.execution_metadata = ( json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None ) - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value + workflow_node_execution.error = node_run_result.error else: # create workflow node execution workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 06a4b114005ed3..ec6d067a7165ab 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -494,7 +494,7 @@ def test_variable_pool_error_type_variable(): list(graph_engine.run()) error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) - assert error_message.value == "Request failed with status code 404" + assert error_message != None assert error_type.value == "HTTPResponseCodeError" @@ -554,3 +554,25 @@ def main() -> dict: for e in events ) assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 2 + + +def test_no_node_in_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error 
strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES[:-1], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0 From 6f46c36c41a97537a15abfd5c8fb48fccf6a0741 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Wed, 4 Dec 2024 21:39:25 +0800 Subject: [PATCH 09/16] fix: update default value type strings to lowercase --- api/core/workflow/nodes/base/entities.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 35e683877ef1ce..0a3d288aa0315b 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -10,13 +10,13 @@ class DefaultValueType(StrEnum): - STRING = "String" - NUMBER = "Number" - OBJECT = "Object" - ARRAY_NUMBER = "Array[Number]" - ARRAY_STRING = "Array[String]" - ARRAY_OBJECT = "Array[Object]" - ARRAY_FILES = "Array[File]" + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" NumberType = Union[int, float] From e307f1c403fc910f843896e4f8599fedeb822204 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Thu, 5 Dec 2024 15:05:53 +0800 Subject: [PATCH 10/16] feat: Handle edge cases for the fail branch --- .../apps/workflow/generate_task_pipeline.py | 29 ++++++++++++++++-- .../workflow/graph_engine/graph_engine.py | 30 +++++++++++-------- api/services/workflow_service.py | 23 +++++++------- 3 files changed, 57 insertions(+), 25 deletions(-) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 82bec2e4d8e75b..fd0ac7e814abcc 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -392,13 +392,38 @@ def _process_stream_response( "conversation_id": None, "trace_manager": trace_manager, } - if isinstance(event, QueueWorkflowFailedEvent): - handle_args["exceptions_count"] = event.exceptions_count workflow_run = self._handle_workflow_run_failed(**handle_args) # save workflow app log self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + handle_args = { + "workflow_run": workflow_run, + "start_at": graph_runtime_state.start_at, + "total_tokens": graph_runtime_state.total_tokens, + "total_steps": graph_runtime_state.node_run_steps, + "status": WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + 
"conversation_id": None, + "trace_manager": trace_manager, + "exceptions_count": event.exceptions_count, + } + workflow_run = self._handle_workflow_run_partial_success(**handle_args) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index ea5d53d550c754..55869cccacca13 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -255,7 +255,7 @@ def _run( previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - + node_instance = cast(BaseNode[BaseNodeData], node_instance) try: # run node generator = self._run_node( @@ -314,13 +314,13 @@ def _run( break if len(edge_mappings) == 1: + edge = edge_mappings[0] if ( previous_route_node_state.status == RouteNodeState.Status.EXCEPTION and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and edge.run_condition is None ): break - edge = edge_mappings[0] - if edge.run_condition: result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -651,7 +651,9 @@ def _run_node( ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - if node_instance.should_continue_on_error: + if node_instance.should_continue_on_error and self.graph.edge_mapping.get( + node_instance.node_id + ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): # plus state total_tokens @@ -827,15 +829,17 @@ def _handle_continue_on_error( "error_type": error_result.error_type, }, ) - - return NodeRunResult( - **node_error_args, - outputs={ - "error_message": error_result.error, - "error_type": error_result.error_type, - }, - edge_source_handle=FailBranchSourceHandle.FAILED, - ) + elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node_instance.node_id): + node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED + return NodeRunResult( + **node_error_args, + outputs={ + "error_message": error_result.error, + "error_type": error_result.error_type, + }, + ) + return error_result class GraphRunFailedError(Exception): diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ab33095ecea3c4..84768d5af053e4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ import time from collections.abc import Sequence from datetime import UTC, datetime -from typing import Optional +from typing import Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -11,6 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING @@ -226,7 +228,7 @@ def run_draft_workflow_node( user_inputs=user_inputs, user_id=account.id, ) - + node_instance = 
cast(BaseNode[BaseNodeData], node_instance) node_run_result: NodeRunResult | None = None for event in generator: if isinstance(event, RunCompletedEvent): @@ -238,6 +240,7 @@ def run_draft_workflow_node( if not node_run_result: raise ValueError("Node run failed with no run result") + # single step debug mode error handling return if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: node_error_args = { "status": WorkflowNodeExecutionStatus.EXCEPTION, @@ -254,14 +257,14 @@ def run_draft_workflow_node( "error_type": node_run_result.error_type, }, ) - - node_run_result = NodeRunResult( - **node_error_args, - outputs={ - "error_message": node_run_result.error, - "error_type": node_run_result.error_type, - }, - ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) run_succeeded = node_run_result.status in ( WorkflowNodeExecutionStatus.SUCCEEDED, WorkflowNodeExecutionStatus.EXCEPTION, From e061bacc7773506f7622a5fd508d81dec086bfcf Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Thu, 5 Dec 2024 17:42:14 +0800 Subject: [PATCH 11/16] feat: Handle edge cases for the fail branch --- .../apps/advanced_chat/generate_task_pipeline.py | 1 + .../app/apps/workflow/generate_task_pipeline.py | 1 + .../app/task_pipeline/workflow_cycle_manage.py | 3 ++- api/core/workflow/graph_engine/graph_engine.py | 15 ++++++++++++--- .../integration_tests/workflow/nodes/test_code.py | 2 -- 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 52d4fc0ff2c08b..c3b2f5e4887887 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -429,6 +429,7 @@ def _process_stream_response( error=event.error, conversation_id=self._conversation.id, trace_manager=trace_manager, + exceptions_count=event.exceptions_count, ) yield self._workflow_finish_to_stream_response( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index fd0ac7e814abcc..8bb073dd63e74b 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -391,6 +391,7 @@ def _process_stream_response( "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), "conversation_id": None, "trace_manager": trace_manager, + "exceptions_count": event.exceptions_count, } workflow_run = self._handle_workflow_run_failed(**handle_args) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index ab2536c3b04696..d78f124e3a2690 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -224,6 +224,7 @@ def _handle_workflow_run_failed( error: str, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, + exceptions_count: int = 0, ) -> WorkflowRun: """ Workflow run failed @@ -243,7 +244,7 @@ def _handle_workflow_run_failed( workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - + workflow_run.exceptions_count = exceptions_count db.session.commit() running_workflow_node_executions = ( diff 
--git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 55869cccacca13..3226ce8b0d3587 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -152,7 +152,10 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: try: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") + yield GraphRunFailedEvent( + error=item.route_node_state.failed_reason or "Unknown error.", + exceptions_count=len(handle_exceptions), + ) return elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: @@ -178,7 +181,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ].strip() except Exception as e: logger.exception("Graph run failed") - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) return # count exceptions to determine partial success exceptions_count = len(handle_exceptions) @@ -196,7 +199,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: return except Exception as e: logger.exception("Unknown Error when graph running") - yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) + yield GraphRunFailedEvent(error=str(e), exceptions_count=exceptions_count) self._release_thread() raise e @@ -387,6 +390,12 @@ def _run( break next_node_id = final_node_id + elif ( + node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node_instance.should_continue_on_error + and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION + ): + break else: parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 362ab127472696..9ffd3bc0afc6d9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -14,10 +14,8 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.entities import CodeNodeData -from core.workflow.nodes.event.event import RunCompletedEvent from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType -from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) From bb4b6e709779c2d1731d5d348591bae0cdf4cac7 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Fri, 6 Dec 2024 11:21:46 +0800 Subject: [PATCH 12/16] fix: parallel logs error --- api/core/workflow/graph_engine/graph_engine.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 3226ce8b0d3587..35f5afef300e5f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -376,6 +376,7 @@ def _run( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in parallel_generator: @@ -401,6 +402,7 @@ def _run( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, + 
handle_exceptions=handle_exceptions, ) for item in parallel_generator: @@ -422,6 +424,7 @@ def _run_parallel_branches( edge_mappings: list[GraphEdge], in_parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) @@ -465,6 +468,7 @@ def _run_parallel_branches( "parallel_start_node_id": edge.target_node_id, "parent_parallel_id": in_parallel_id, "parent_parallel_start_node_id": parallel_start_node_id, + "handle_exceptions": handle_exceptions, }, ) @@ -508,6 +512,7 @@ def _run_parallel_node( parallel_start_node_id: str, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> None: """ Run parallel nodes @@ -529,6 +534,7 @@ def _run_parallel_node( in_parallel_id=parallel_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in generator: From 97fdea037fda1027ac4c5baa8125d875c394e224 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Sun, 8 Dec 2024 13:16:54 +0800 Subject: [PATCH 13/16] feat: handle parallel fail branch --- .../app/apps/workflow/generate_task_pipeline.py | 2 +- api/core/workflow/graph_engine/entities/graph.py | 14 ++++++-------- api/core/workflow/graph_engine/graph_engine.py | 11 +++++------ .../question_classifier_node.py | 2 +- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 8bb073dd63e74b..60030d38bde128 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -391,7 +391,7 @@ def _process_stream_response( "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), "conversation_id": None, "trace_manager": trace_manager, - "exceptions_count": event.exceptions_count, + "exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, } workflow_run = self._handle_workflow_run_failed(**handle_args) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index d27dd6b6360d3f..4f7bc60e26b5e2 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -64,15 +64,20 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non edge_configs = graph_config.get("edges") if edge_configs is None: edge_configs = [] + # node configs + node_configs = graph_config.get("nodes") + if not node_configs: + raise ValueError("Graph must have at least one node") edge_configs = cast(list, edge_configs) + node_configs = cast(list, node_configs) # reorganize edges mapping edge_mapping: dict[str, list[GraphEdge]] = {} reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() fail_branch_source_node_id = [ - edge["source"] for edge in edge_configs if edge.get("sourceHandle") == "fail-branch" + node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch" ] for edge_config in edge_configs: source_node_id = edge_config.get("source") @@ -111,13 +116,6 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non 
edge_mapping[source_node_id].append(graph_edge) reverse_edge_mapping[target_node_id].append(graph_edge) - # node configs - node_configs = graph_config.get("nodes") - if not node_configs: - raise ValueError("Graph must have at least one node") - - node_configs = cast(list, node_configs) - # fetch nodes that have no predecessor node root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 35f5afef300e5f..0730a2732008d0 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -132,6 +132,7 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() + handle_exceptions = [] try: if self.init_params.workflow_type == WorkflowType.CHAT: @@ -144,7 +145,6 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ) # run graph - handle_exceptions = [] generator = stream_processor.process( self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) ) @@ -184,22 +184,21 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) return # count exceptions to determine partial success - exceptions_count = len(handle_exceptions) - if exceptions_count > 0: + if len(handle_exceptions) > 0: yield GraphRunPartialSucceededEvent( - exceptions_count=exceptions_count, outputs=self.graph_runtime_state.outputs + exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs ) else: # trigger graph run success event yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) self._release_thread() except GraphRunFailedError as e: - yield GraphRunFailedEvent(error=e.error, exceptions_count=exceptions_count) + yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions)) self._release_thread() return except Exception as e: logger.exception("Unknown Error when graph running") - yield GraphRunFailedEvent(error=str(e), exceptions_count=exceptions_count) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) self._release_thread() raise e diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index e855ab2d2b0659..7594036b50abb4 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -139,7 +139,7 @@ def _run(self): "usage": jsonable_encoder(usage), "finish_reason": finish_reason, } - outputs = {"class_name": category_name} + outputs = {"class_name": category_name, "class_id": category_id} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, From 1d118a0c4c26676d805a56fbf2dce408496746b5 Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Sun, 8 Dec 2024 14:56:08 +0800 Subject: [PATCH 14/16] fix: test cases error --- .../workflow/graph_engine/graph_engine.py | 26 ++++---- .../workflow/nodes/test_continue_on_error.py | 66 +++++++------------ 2 files changed, 37 insertions(+), 55 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0730a2732008d0..e03d4a7194a11e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -339,6 
+339,7 @@ def _run( next_node_id = edge.target_node_id else: final_node_id = None + if any(edge.run_condition for edge in edge_mappings): # if nodes has run conditions, get node id which branch to take based on the run condition results condition_edge_mappings = {} @@ -701,19 +702,18 @@ def _run_node( run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( parent_parallel_start_node_id ) - event_args = { - "id": node_instance.id, - "node_id": node_instance.node_id, - "node_type": node_instance.node_type, - "node_data": node_instance.node_data, - "route_node_state": route_node_state, - "parallel_id": parallel_id, - "parallel_start_node_id": parallel_start_node_id, - "parent_parallel_id": parent_parallel_id, - "parent_parallel_start_node_id": parent_parallel_start_node_id, - } - event = NodeRunSucceededEvent(**event_args) - yield event + + yield NodeRunSucceededEvent( + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) break elif isinstance(item, RunStreamChunkEvent): diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index ec6d067a7165ab..30751fc104fdb1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -33,8 +33,25 @@ def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: return node @staticmethod - def get_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + def get_http_node( + error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False + ): """Helper method to create a http node configuration""" + authorization = ( + { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + } + if authorization_success + else { + "type": "api-key", + # missing config field + } + ) node = { "id": "node", "data": { @@ -42,10 +59,7 @@ def get_http_node(error_strategy: str = "fail-branch", default_value: dict | Non "desc": "", "method": "get", "url": "http://example.com", - "authorization": { - "type": "api-key", - # missing config field - }, + "authorization": authorization, "headers": "X-Header:123", "params": "A:b", "body": None, @@ -214,7 +228,7 @@ def main() -> dict: {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_code_node( - error_code, "default-value", [{"key": "result", "type": "Number", "value": 132123}] + error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] ), ], } @@ -259,38 +273,6 @@ def main() -> dict: ) -def test_code_success_branch_continue_on_error(): - success_code = """ - def main() -> dict: - return { - "result": 1 / 1, - } - """ - - graph_config = { - "edges": FAIL_BRANCH_EDGES, - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - ContinueOnErrorTestHelper.get_code_node(success_code), - { - "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, - "id": "success", 
- }, - { - "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, - "id": "error", - }, - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert any( - isinstance(e, GraphRunSucceededEvent) and e.outputs == {"answer": "node node run successfully"} for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 - - def test_http_node_default_value_continue_on_error(): """Test HTTP node with default value error strategy""" graph_config = { @@ -299,7 +281,7 @@ def test_http_node_default_value_continue_on_error(): {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_http_node( - "default-value", [{"key": "response", "type": "String", "value": "http node got error response"}] + "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] ), ], } @@ -351,7 +333,7 @@ def test_tool_node_default_value_continue_on_error(): {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_tool_node( - "default-value", [{"key": "result", "type": "String", "value": "default tool result"}] + "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] ), ], } @@ -402,7 +384,7 @@ def test_llm_node_default_value_continue_on_error(): {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, ContinueOnErrorTestHelper.get_llm_node( - "default-value", [{"key": "answer", "type": "String", "value": "default LLM response"}] + "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] ), ], } @@ -531,7 +513,7 @@ def main() -> dict: "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, "id": "error", }, - ContinueOnErrorTestHelper.get_code_node(code=success_code), + ContinueOnErrorTestHelper.get_http_node(authorization_success=True), { "id": "code", "data": { From bbf9b11ff595f645f68257abbca8b99b62cd182b Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Sun, 8 Dec 2024 15:39:23 +0800 Subject: [PATCH 15/16] fix: test cases error --- .../workflow/nodes/test_code.py | 1 + .../workflow/nodes/test_continue_on_error.py | 58 ------------------- 2 files changed, 1 insertion(+), 58 deletions(-) diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 9ffd3bc0afc6d9..4de985ae7c9dea 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -16,6 +16,7 @@ from core.workflow.nodes.code.entities import CodeNodeData from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType +from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000")) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 30751fc104fdb1..ba209e4020afad 100644 --- 
a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -480,64 +480,6 @@ def test_variable_pool_error_type_variable(): assert error_type.value == "HTTPResponseCodeError" -def test_continue_on_error_link_fail_branch(): - success_code = """ - def main() -> dict: - return { - "result": 1 / 1, - } - """ - graph_config = { - "edges": [ - *FAIL_BRANCH_EDGES, - { - "id": "start-source-code-target", - "source": "start", - "target": "code", - "sourceHandle": "source", - }, - { - "id": "code-source-error-target", - "source": "code", - "target": "error", - "sourceHandle": "source", - }, - ], - "nodes": [ - {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, - { - "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, - "id": "success", - }, - { - "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, - "id": "error", - }, - ContinueOnErrorTestHelper.get_http_node(authorization_success=True), - { - "id": "code", - "data": { - "outputs": {"result": {"type": "number"}}, - "title": "code", - "variables": [], - "code_language": "python3", - "code": "\n".join([line[4:] for line in success_code.split("\n")]), - "type": "code", - }, - }, - ], - } - - graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) - events = list(graph_engine.run()) - assert any( - isinstance(e, GraphRunSucceededEvent) - and e.outputs == {"answer": "http execute successful\nhttp execute failed"} - for e in events - ) - assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 2 - - def test_no_node_in_fail_branch_continue_on_error(): """Test HTTP node with fail-branch error strategy""" graph_config = { From fe4e3ae895d04937317cc23ef1727c56fffc5a3a Mon Sep 17 00:00:00 2001 From: Novice Lee Date: Mon, 9 Dec 2024 11:52:06 +0800 Subject: [PATCH 16/16] feat: correct the default value error message --- api/core/workflow/nodes/base/entities.py | 144 ++++++++++++----------- api/core/workflow/nodes/base/exc.py | 2 +- 2 files changed, 74 insertions(+), 72 deletions(-) diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 0a3d288aa0315b..9271867afffa6e 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -20,86 +20,88 @@ class DefaultValueType(StrEnum): NumberType = Union[int, float] -ObjectType = dict[str, Any] class DefaultValue(BaseModel): - value: Union[ - str, - NumberType, - ObjectType, - list[NumberType], - list[str], - list[ObjectType], - ] + value: Any type: DefaultValueType key: str + @staticmethod + def _parse_json(value: str) -> Any: + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: DefaultValueType) -> bool: + """Unified array type validation""" + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + @model_validator(mode="after") - def validate_value_type(self) -> Any: - value_type = self.type - value = self.value - if value_type is None: + def validate_value_type(self) -> "DefaultValue": 
+ if self.type is None: raise DefaultValueTypeError("type field is required") - # validate string type - if value_type == DefaultValueType.STRING: - if not isinstance(value, str): - raise DefaultValueTypeError(f"Value must be string type for {value}") - - # validate number type - elif value_type == DefaultValueType.NUMBER: - if not isinstance(value, NumberType): - raise DefaultValueTypeError(f"Value must be number type for {value}") - - # validate object type - elif value_type == DefaultValueType.OBJECT: - if isinstance(value, str): - try: - value = json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Value must be object type for {value}") - if not isinstance(value, ObjectType): - raise DefaultValueTypeError(f"Value must be object type for {value}") - - # validate array[number] type - elif value_type == DefaultValueType.ARRAY_NUMBER: - if isinstance(value, str): - try: - value = json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Value must be object type for {value}") - if not isinstance(value, list): - raise DefaultValueTypeError(f"Value must be array type for {value}") - if not all(isinstance(x, NumberType) for x in value): - raise DefaultValueTypeError(f"All elements must be numbers for {value}") - - # validate array[string] type - elif value_type == DefaultValueType.ARRAY_STRING: - if isinstance(value, str): - try: - value = json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Value must be object type for {value}") - if not isinstance(value, list): - raise DefaultValueTypeError(f"Value must be array type for {value}") - if not all(isinstance(x, str) for x in value): - raise DefaultValueTypeError(f"All elements must be strings for {value}") - - # validate array[object] type - elif value_type == DefaultValueType.ARRAY_OBJECT: - if isinstance(value, str): - try: - value = json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Value must be object type for {value}") - if not isinstance(value, list): - raise DefaultValueTypeError(f"Value must be array type for {value}") - if not all(isinstance(x, ObjectType) for x in value): - raise DefaultValueTypeError(f"All elements must be objects for {value}") - elif value_type == DefaultValueType.ARRAY_FILES: - # handle files type - pass + # Type validation configuration + type_validators = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + DefaultValueType.NUMBER: { + "type": NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": NumberType, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator = type_validators.get(self.type) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array 
element types + if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") return self diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py index aeecf406403e6d..ec134e031cf9d3 100644 --- a/api/core/workflow/nodes/base/exc.py +++ b/api/core/workflow/nodes/base/exc.py @@ -1,4 +1,4 @@ -class BaseNodeError(ValueError): +class BaseNodeError(Exception): """Base class for node errors.""" pass