From 79a710ce98c8ae5b346fb856a9d33b7d657542e1 Mon Sep 17 00:00:00 2001 From: Novice <857526207@qq.com> Date: Wed, 11 Dec 2024 14:22:42 +0800 Subject: [PATCH] Feat: continue on error (#11458) Co-authored-by: Novice Lee Co-authored-by: Novice Lee --- .../advanced_chat/generate_task_pipeline.py | 28 +- .../app/apps/workflow/app_queue_manager.py | 4 +- .../apps/workflow/generate_task_pipeline.py | 68 ++- api/core/app/apps/workflow_app_runner.py | 40 +- api/core/app/entities/queue_entities.py | 44 ++ api/core/app/entities/task_entities.py | 1 + .../task_pipeline/workflow_cycle_manage.py | 74 ++- api/core/helper/ssrf_proxy.py | 8 +- .../callbacks/workflow_logging_callback.py | 3 + api/core/workflow/entities/node_entities.py | 2 + .../workflow/graph_engine/entities/event.py | 10 + .../workflow/graph_engine/entities/graph.py | 27 +- .../entities/runtime_route_state.py | 10 +- .../workflow/graph_engine/graph_engine.py | 172 +++++- .../answer/answer_stream_generate_router.py | 21 +- .../nodes/answer/answer_stream_processor.py | 3 +- api/core/workflow/nodes/base/entities.py | 114 +++- api/core/workflow/nodes/base/exc.py | 10 + api/core/workflow/nodes/base/node.py | 16 +- api/core/workflow/nodes/code/code_node.py | 4 +- api/core/workflow/nodes/enums.py | 13 + .../workflow/nodes/http_request/executor.py | 7 +- api/core/workflow/nodes/http_request/node.py | 16 + api/core/workflow/nodes/llm/node.py | 1 + .../question_classifier_node.py | 2 +- api/core/workflow/nodes/tool/tool_node.py | 2 + api/fields/workflow_run_fields.py | 4 + ...4fc45278_add_exceptions_count_field_to_.py | 33 ++ api/models/workflow.py | 7 +- api/services/workflow_service.py | 45 +- .../workflow/nodes/test_continue_on_error.py | 502 ++++++++++++++++++ 31 files changed, 1211 insertions(+), 80 deletions(-) create mode 100644 api/core/workflow/nodes/base/exc.py create mode 100644 api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 0c8cd384d5a930..32a23a7fdb8690 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -19,6 +19,7 @@ QueueIterationNextEvent, QueueIterationStartEvent, QueueMessageReplaceEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -31,6 +32,7 @@ QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -317,7 +319,7 @@ def _process_stream_response( if response: yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( @@ -384,6 +386,29 @@ def _process_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + + workflow_run = self._handle_workflow_run_partial_success( + workflow_run=workflow_run, + start_at=graph_runtime_state.start_at, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) elif isinstance(event, QueueWorkflowFailedEvent): if not workflow_run: @@ -401,6 +426,7 @@ def _process_stream_response( error=event.error, conversation_id=self._conversation.id, trace_manager=trace_manager, + exceptions_count=event.exceptions_count, ) yield self._workflow_finish_to_stream_response( diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 76371f800ba1e5..349b8eb51b1546 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -6,6 +6,7 @@ QueueMessageEndEvent, QueueStopEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, WorkflowQueueMessage, ) @@ -34,7 +35,8 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None: | QueueErrorEvent | QueueMessageEndEvent | QueueWorkflowSucceededEvent - | QueueWorkflowFailedEvent, + | QueueWorkflowFailedEvent + | QueueWorkflowPartialSuccessEvent, ): self.stop_listen() diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9e4921d6a22c5a..8483fa91f80a02 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -15,6 +15,7 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -26,6 +27,7 @@ QueueStopEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -276,7 +278,7 @@ def _process_stream_response( if response: yield response - elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( @@ -345,22 +347,20 @@ def _process_stream_response( yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) - elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + elif isinstance(event, QueueWorkflowPartialSuccessEvent): if not workflow_run: raise Exception("Workflow run not initialized.") if not graph_runtime_state: raise Exception("Graph runtime state not initialized.") - workflow_run = self._handle_workflow_run_failed( + workflow_run = self._handle_workflow_run_partial_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + outputs=event.outputs, + exceptions_count=event.exceptions_count, conversation_id=None, trace_manager=trace_manager, ) @@ -368,6 +368,60 @@ def _process_stream_response( # save workflow app log self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + handle_args = { + "workflow_run": workflow_run, + "start_at": graph_runtime_state.start_at, + "total_tokens": graph_runtime_state.total_tokens, + "total_steps": graph_runtime_state.node_run_steps, + "status": WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + "conversation_id": None, + "trace_manager": trace_manager, + "exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + } + workflow_run = self._handle_workflow_run_failed(**handle_args) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + + yield self._workflow_finish_to_stream_response( + task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + ) + elif isinstance(event, QueueWorkflowPartialSuccessEvent): + if not workflow_run: + raise Exception("Workflow run not initialized.") + + if not graph_runtime_state: + raise Exception("Graph runtime state not initialized.") + handle_args = { + "workflow_run": workflow_run, + "start_at": graph_runtime_state.start_at, + "total_tokens": graph_runtime_state.total_tokens, + "total_steps": graph_runtime_state.node_run_steps, + "status": WorkflowRunStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowRunStatus.STOPPED, + "error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + "conversation_id": None, + "trace_manager": trace_manager, + "exceptions_count": event.exceptions_count, + } + workflow_run = self._handle_workflow_run_partial_success(**handle_args) + + # save workflow app log + self._save_workflow_app_log(workflow_run) + yield self._workflow_finish_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run ) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 3d46b8bab03e17..97c2cc5bb9bfd9 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -8,6 +8,7 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -18,6 +19,7 @@ QueueRetrieverResourcesEvent, QueueTextChunkEvent, QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) @@ -25,6 +27,7 @@ from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, IterationRunFailedEvent, @@ -32,6 +35,7 @@ IterationRunStartedEvent, IterationRunSucceededEvent, NodeInIterationFailedEvent, + NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -176,8 +180,12 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) ) elif isinstance(event, GraphRunSucceededEvent): self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs)) + elif isinstance(event, GraphRunPartialSucceededEvent): + self._publish_event( + QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count) + ) elif isinstance(event, GraphRunFailedEvent): - self._publish_event(QueueWorkflowFailedEvent(error=event.error)) + self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count)) elif isinstance(event, NodeRunStartedEvent): self._publish_event( QueueNodeStartedEvent( @@ -253,6 +261,36 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) in_iteration_id=event.in_iteration_id, ) ) + elif isinstance(event, NodeRunExceptionEvent): + self._publish_event( + QueueNodeExceptionEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + error=event.route_node_state.node_run_result.error + if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error + else "Unknown error", + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + ) + ) elif isinstance(event, NodeInIterationFailedEvent): self._publish_event( QueueNodeInIterationFailedEvent( diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5e9b6517bae01f..5b2036c7f9ba6a 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -25,12 +25,14 @@ class QueueEvent(StrEnum): WORKFLOW_STARTED = "workflow_started" WORKFLOW_SUCCEEDED = "workflow_succeeded" WORKFLOW_FAILED = "workflow_failed" + WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded" ITERATION_START = "iteration_start" ITERATION_NEXT = "iteration_next" ITERATION_COMPLETED = "iteration_completed" NODE_STARTED = "node_started" NODE_SUCCEEDED = "node_succeeded" NODE_FAILED = "node_failed" + NODE_EXCEPTION = "node_exception" RETRIEVER_RESOURCES = "retriever_resources" ANNOTATION_REPLY = "annotation_reply" AGENT_THOUGHT = "agent_thought" @@ -237,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent): event: QueueEvent = QueueEvent.WORKFLOW_FAILED error: str + exceptions_count: int + + +class QueueWorkflowPartialSuccessEvent(AppQueueEvent): + """ + QueueWorkflowFailedEvent entity + """ + + event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED + exceptions_count: int + outputs: Optional[dict[str, Any]] = None class QueueNodeStartedEvent(AppQueueEvent): @@ -331,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): error: str +class QueueNodeExceptionEvent(AppQueueEvent): + """ + QueueNodeExceptionEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_EXCEPTION + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + """iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + + class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 03cc6941a84623..7fe06b3af8bbb4 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -213,6 +213,7 @@ class Data(BaseModel): created_by: Optional[dict] = None created_at: int finished_at: int + exceptions_count: Optional[int] = 0 files: Optional[Sequence[Mapping[str, Any]]] = [] event: StreamEvent = StreamEvent.WORKFLOW_FINISHED diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 57a02f8bc85eac..d78f124e3a2690 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -12,6 +12,7 @@ QueueIterationCompletedEvent, QueueIterationNextEvent, QueueIterationStartEvent, + QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, @@ -164,6 +165,55 @@ def _handle_workflow_run_success( return workflow_run + def _handle_workflow_run_partial_success( + self, + workflow_run: WorkflowRun, + start_at: float, + total_tokens: int, + total_steps: int, + outputs: Mapping[str, Any] | None = None, + exceptions_count: int = 0, + conversation_id: Optional[str] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> WorkflowRun: + """ + Workflow run success + :param workflow_run: workflow run + :param start_at: start time + :param total_tokens: total tokens + :param total_steps: total steps + :param outputs: outputs + :param conversation_id: conversation id + :return: + """ + workflow_run = self._refetch_workflow_run(workflow_run.id) + + outputs = WorkflowEntry.handle_special_values(outputs) + + workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value + workflow_run.outputs = json.dumps(outputs or {}) + workflow_run.elapsed_time = time.perf_counter() - start_at + workflow_run.total_tokens = total_tokens + workflow_run.total_steps = total_steps + workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_run.exceptions_count = exceptions_count + db.session.commit() + db.session.refresh(workflow_run) + + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_run=workflow_run, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + ) + ) + + db.session.close() + + return workflow_run + def _handle_workflow_run_failed( self, workflow_run: WorkflowRun, @@ -174,6 +224,7 @@ def _handle_workflow_run_failed( error: str, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, + exceptions_count: int = 0, ) -> WorkflowRun: """ Workflow run failed @@ -193,7 +244,7 @@ def _handle_workflow_run_failed( workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - + workflow_run.exceptions_count = exceptions_count db.session.commit() running_workflow_node_executions = ( @@ -318,7 +369,7 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent return workflow_node_execution def _handle_workflow_node_execution_failed( - self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent ) -> WorkflowNodeExecution: """ Workflow node execution failed @@ -337,7 +388,11 @@ def _handle_workflow_node_execution_failed( ) db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( { - WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, + WorkflowNodeExecution.status: ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.EXCEPTION.value + ), WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, @@ -351,8 +406,11 @@ def _handle_workflow_node_execution_failed( db.session.commit() db.session.close() process_data = WorkflowEntry.handle_special_values(event.process_data) - - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value + workflow_node_execution.status = ( + WorkflowNodeExecutionStatus.FAILED.value + if not isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.EXCEPTION.value + ) workflow_node_execution.error = event.error workflow_node_execution.inputs = json.dumps(inputs) if inputs else None workflow_node_execution.process_data = json.dumps(process_data) if process_data else None @@ -433,6 +491,7 @@ def _workflow_finish_to_stream_response( created_at=int(workflow_run.created_at.timestamp()), finished_at=int(workflow_run.finished_at.timestamp()), files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict), + exceptions_count=workflow_run.exceptions_count, ), ) @@ -483,7 +542,10 @@ def _workflow_node_start_to_stream_response( def _workflow_node_finish_to_stream_response( self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, + event: QueueNodeSucceededEvent + | QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeExceptionEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 566293d1250402..ef4516b404af13 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,6 +24,12 @@ STATUS_FORCELIST = [429, 500, 502, 503, 504] +class MaxRetriesExceededError(Exception): + """Raised when the maximum number of retries is exceeded.""" + + pass + + def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if "allow_redirects" in kwargs: allow_redirects = kwargs.pop("allow_redirects") @@ -64,7 +70,7 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): if retries <= max_retries: time.sleep(BACKOFF_FACTOR * (2 ** (retries - 1))) - raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}") + raise MaxRetriesExceededError(f"Reached maximum retries ({max_retries}) for URL {url}") def get(url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py index 17913de7b0d2ce..ed737e7316973c 100644 --- a/api/core/workflow/callbacks/workflow_logging_callback.py +++ b/api/core/workflow/callbacks/workflow_logging_callback.py @@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, IterationRunFailedEvent, @@ -39,6 +40,8 @@ def on_event(self, event: GraphEngineEvent) -> None: self.print_text("\n[GraphRunStartedEvent]", color="pink") elif isinstance(event, GraphRunSucceededEvent): self.print_text("\n[GraphRunSucceededEvent]", color="green") + elif isinstance(event, GraphRunPartialSucceededEvent): + self.print_text("\n[GraphRunPartialSucceededEvent]", color="pink") elif isinstance(event, GraphRunFailedEvent): self.print_text(f"\n[GraphRunFailedEvent] reason: {event.error}", color="red") elif isinstance(event, NodeRunStartedEvent): diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index e174d3baa0c736..976a5ef74e320d 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -25,6 +25,7 @@ class NodeRunMetadataKey(StrEnum): PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs + ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field class NodeRunResult(BaseModel): @@ -43,3 +44,4 @@ class NodeRunResult(BaseModel): edge_source_handle: Optional[str] = None # source handle id of node with multiple branches error: Optional[str] = None # error message if status is failed + error_type: Optional[str] = None # error type if status is failed diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index cb73da3cd613f6..73450349ded634 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -33,6 +33,12 @@ class GraphRunSucceededEvent(BaseGraphEvent): class GraphRunFailedEvent(BaseGraphEvent): error: str = Field(..., description="failed reason") + exceptions_count: Optional[int] = Field(description="exception count", default=0) + + +class GraphRunPartialSucceededEvent(BaseGraphEvent): + exceptions_count: int = Field(..., description="exception count") + outputs: Optional[dict[str, Any]] = None ########################################### @@ -83,6 +89,10 @@ class NodeRunFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeRunExceptionEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + class NodeInIterationFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index d87c039409d62e..4f7bc60e26b5e2 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -64,13 +64,21 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non edge_configs = graph_config.get("edges") if edge_configs is None: edge_configs = [] + # node configs + node_configs = graph_config.get("nodes") + if not node_configs: + raise ValueError("Graph must have at least one node") edge_configs = cast(list, edge_configs) + node_configs = cast(list, node_configs) # reorganize edges mapping edge_mapping: dict[str, list[GraphEdge]] = {} reverse_edge_mapping: dict[str, list[GraphEdge]] = {} target_edge_ids = set() + fail_branch_source_node_id = [ + node["id"] for node in node_configs if node["data"].get("error_strategy") == "fail-branch" + ] for edge_config in edge_configs: source_node_id = edge_config.get("source") if not source_node_id: @@ -90,8 +98,16 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non # parse run condition run_condition = None - if edge_config.get("sourceHandle") and edge_config.get("sourceHandle") != "source": - run_condition = RunCondition(type="branch_identify", branch_identify=edge_config.get("sourceHandle")) + if edge_config.get("sourceHandle"): + if ( + edge_config.get("source") in fail_branch_source_node_id + and edge_config.get("sourceHandle") != "fail-branch" + ): + run_condition = RunCondition(type="branch_identify", branch_identify="success-branch") + elif edge_config.get("sourceHandle") != "source": + run_condition = RunCondition( + type="branch_identify", branch_identify=edge_config.get("sourceHandle") + ) graph_edge = GraphEdge( source_node_id=source_node_id, target_node_id=target_node_id, run_condition=run_condition @@ -100,13 +116,6 @@ def init(cls, graph_config: Mapping[str, Any], root_node_id: Optional[str] = Non edge_mapping[source_node_id].append(graph_edge) reverse_edge_mapping[target_node_id].append(graph_edge) - # node configs - node_configs = graph_config.get("nodes") - if not node_configs: - raise ValueError("Graph must have at least one node") - - node_configs = cast(list, node_configs) - # fetch nodes that have no predecessor node root_node_configs = [] all_node_id_config_mapping: dict[str, dict] = {} diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index baeec9bf0160d7..7683dcc9dcd3c0 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -15,6 +15,7 @@ class Status(Enum): SUCCESS = "success" FAILED = "failed" PAUSED = "paused" + EXCEPTION = "exception" id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" @@ -51,7 +52,11 @@ def set_finished(self, run_result: NodeRunResult) -> None: :param run_result: run result """ - if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: + if self.status in { + RouteNodeState.Status.SUCCESS, + RouteNodeState.Status.FAILED, + RouteNodeState.Status.EXCEPTION, + }: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: @@ -59,6 +64,9 @@ def set_finished(self, run_result: NodeRunResult) -> None: elif run_result.status == WorkflowNodeExecutionStatus.FAILED: self.status = RouteNodeState.Status.FAILED self.failed_reason = run_result.error + elif run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + self.status = RouteNodeState.Status.EXCEPTION + self.failed_reason = run_result.error else: raise Exception(f"Invalid route status {run_result.status}") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 7cffd7bc8e1659..e03d4a7194a11e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -5,21 +5,23 @@ from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from copy import copy, deepcopy -from typing import Any, Optional +from typing import Any, Optional, cast from flask import Flask, current_app from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( BaseIterationEvent, GraphEngineEvent, GraphRunFailedEvent, + GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, + NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -36,7 +38,9 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor +from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from extensions.ext_database import db @@ -128,6 +132,7 @@ def __init__( def run(self) -> Generator[GraphEngineEvent, None, None]: # trigger graph run start event yield GraphRunStartedEvent() + handle_exceptions = [] try: if self.init_params.workflow_type == WorkflowType.CHAT: @@ -140,13 +145,17 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ) # run graph - generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id)) - + generator = stream_processor.process( + self._run(start_node_id=self.graph.root_node_id, handle_exceptions=handle_exceptions) + ) for item in generator: try: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or "Unknown error.") + yield GraphRunFailedEvent( + error=item.route_node_state.failed_reason or "Unknown error.", + exceptions_count=len(handle_exceptions), + ) return elif isinstance(item, NodeRunSucceededEvent): if item.node_type == NodeType.END: @@ -172,19 +181,24 @@ def run(self) -> Generator[GraphEngineEvent, None, None]: ].strip() except Exception as e: logger.exception("Graph run failed") - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) return - - # trigger graph run success event - yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) + # count exceptions to determine partial success + if len(handle_exceptions) > 0: + yield GraphRunPartialSucceededEvent( + exceptions_count=len(handle_exceptions), outputs=self.graph_runtime_state.outputs + ) + else: + # trigger graph run success event + yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs) self._release_thread() except GraphRunFailedError as e: - yield GraphRunFailedEvent(error=e.error) + yield GraphRunFailedEvent(error=e.error, exceptions_count=len(handle_exceptions)) self._release_thread() return except Exception as e: logger.exception("Unknown Error when graph running") - yield GraphRunFailedEvent(error=str(e)) + yield GraphRunFailedEvent(error=str(e), exceptions_count=len(handle_exceptions)) self._release_thread() raise e @@ -198,6 +212,7 @@ def _run( in_parallel_id: Optional[str] = None, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: parallel_start_node_id = None if in_parallel_id: @@ -242,7 +257,7 @@ def _run( previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - + node_instance = cast(BaseNode[BaseNodeData], node_instance) try: # run node generator = self._run_node( @@ -252,6 +267,7 @@ def _run( parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in generator: @@ -301,7 +317,12 @@ def _run( if len(edge_mappings) == 1: edge = edge_mappings[0] - + if ( + previous_route_node_state.status == RouteNodeState.Status.EXCEPTION + and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and edge.run_condition is None + ): + break if edge.run_condition: result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -334,7 +355,7 @@ def _run( if len(sub_edge_mappings) == 0: continue - edge = sub_edge_mappings[0] + edge = cast(GraphEdge, sub_edge_mappings[0]) result = ConditionManager.get_condition_handler( init_params=self.init_params, @@ -355,6 +376,7 @@ def _run( edge_mappings=sub_edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in parallel_generator: @@ -369,11 +391,18 @@ def _run( break next_node_id = final_node_id + elif ( + node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node_instance.should_continue_on_error + and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION + ): + break else: parallel_generator = self._run_parallel_branches( edge_mappings=edge_mappings, in_parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in parallel_generator: @@ -395,6 +424,7 @@ def _run_parallel_branches( edge_mappings: list[GraphEdge], in_parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent | str, None, None]: # if nodes has no run conditions, parallel run all nodes parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) @@ -438,6 +468,7 @@ def _run_parallel_branches( "parallel_start_node_id": edge.target_node_id, "parent_parallel_id": in_parallel_id, "parent_parallel_start_node_id": parallel_start_node_id, + "handle_exceptions": handle_exceptions, }, ) @@ -481,6 +512,7 @@ def _run_parallel_node( parallel_start_node_id: str, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> None: """ Run parallel nodes @@ -502,6 +534,7 @@ def _run_parallel_node( in_parallel_id=parallel_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, + handle_exceptions=handle_exceptions, ) for item in generator: @@ -548,6 +581,7 @@ def _run_node( parallel_start_node_id: Optional[str] = None, parent_parallel_id: Optional[str] = None, parent_parallel_start_node_id: Optional[str] = None, + handle_exceptions: list[str] = [], ) -> Generator[GraphEngineEvent, None, None]: """ Run node @@ -587,19 +621,55 @@ def _run_node( route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: - yield NodeRunFailedEvent( - error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, - route_node_state=route_node_state, - parallel_id=parallel_id, - parallel_start_node_id=parallel_start_node_id, - parent_parallel_id=parent_parallel_id, - parent_parallel_start_node_id=parent_parallel_start_node_id, - ) + if node_instance.should_continue_on_error: + # if run failed, handle error + run_result = self._handle_continue_on_error( + node_instance, + item.run_result, + self.graph_runtime_state.variable_pool, + handle_exceptions=handle_exceptions, + ) + route_node_state.node_run_result = run_result + route_node_state.status = RouteNodeState.Status.EXCEPTION + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_instance.node_id, + variable_key_list=[variable_key], + variable_value=variable_value, + ) + yield NodeRunExceptionEvent( + error=run_result.error or "System Error", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + else: + yield NodeRunFailedEvent( + error=route_node_state.failed_reason or "Unknown error.", + id=node_instance.id, + node_id=node_instance.node_id, + node_type=node_instance.node_type, + node_data=node_instance.node_data, + route_node_state=route_node_state, + parallel_id=parallel_id, + parallel_start_node_id=parallel_start_node_id, + parent_parallel_id=parent_parallel_id, + parent_parallel_start_node_id=parent_parallel_start_node_id, + ) + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if node_instance.should_continue_on_error and self.graph.edge_mapping.get( + node_instance.node_id + ): + run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): # plus state total_tokens self.graph_runtime_state.total_tokens += int( @@ -735,6 +805,56 @@ def create_copy(self): new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) return new_instance + def _handle_continue_on_error( + self, + node_instance: BaseNode[BaseNodeData], + error_result: NodeRunResult, + variable_pool: VariablePool, + handle_exceptions: list[str] = [], + ) -> NodeRunResult: + """ + handle continue on error when self._should_continue_on_error is True + + + :param error_result (NodeRunResult): error run result + :param variable_pool (VariablePool): variable pool + :return: excption run result + """ + # add error message and error type to variable pool + variable_pool.add([node_instance.node_id, "error_message"], error_result.error) + variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + # add error message to handle_exceptions + handle_exceptions.append(error_result.error) + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": error_result.error, + "inputs": error_result.inputs, + "metadata": { + NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + }, + } + + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + return NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": error_result.error, + "error_type": error_result.error_type, + }, + ) + elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node_instance.node_id): + node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED + return NodeRunResult( + **node_error_args, + outputs={ + "error_message": error_result.error, + "error_type": error_result.error_type, + }, + ) + return error_result + class GraphRunFailedError(Exception): def __init__(self, error: str): diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 8c78016f09a334..1b948bf59203b7 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -6,7 +6,7 @@ TextGenerateRouteChunk, VarGenerateRouteChunk, ) -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -148,13 +148,18 @@ def _recursive_fetch_answer_dependencies( for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in { - NodeType.ANSWER, - NodeType.IF_ELSE, - NodeType.QUESTION_CLASSIFIER, - NodeType.ITERATION, - NodeType.VARIABLE_ASSIGNER, - }: + source_node_data = node_id_config_mapping[source_node_id].get("data", {}) + if ( + source_node_type + in { + NodeType.ANSWER, + NodeType.IF_ELSE, + NodeType.QUESTION_CLASSIFIER, + NodeType.ITERATION, + NodeType.VARIABLE_ASSIGNER, + } + or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH + ): answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py index 8a768088da660e..d94f0590584842 100644 --- a/api/core/workflow/nodes/answer/answer_stream_processor.py +++ b/api/core/workflow/nodes/answer/answer_stream_processor.py @@ -6,6 +6,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, + NodeRunExceptionEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, @@ -50,7 +51,7 @@ def process(self, generator: Generator[GraphEngineEvent, None, None]) -> Generat for _ in stream_out_answer_node_ids: yield event - elif isinstance(event, NodeRunSucceededEvent): + elif isinstance(event, NodeRunSucceededEvent | NodeRunExceptionEvent): yield event if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids: # update self.route_position after all stream event finished diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index fb50fbd6e863fa..9271867afffa6e 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -1,14 +1,124 @@ +import json from abc import ABC -from typing import Optional +from enum import StrEnum +from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, model_validator + +from core.workflow.nodes.base.exc import DefaultValueTypeError +from core.workflow.nodes.enums import ErrorStrategy + + +class DefaultValueType(StrEnum): + STRING = "string" + NUMBER = "number" + OBJECT = "object" + ARRAY_NUMBER = "array[number]" + ARRAY_STRING = "array[string]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILES = "array[file]" + + +NumberType = Union[int, float] + + +class DefaultValue(BaseModel): + value: Any + type: DefaultValueType + key: str + + @staticmethod + def _parse_json(value: str) -> Any: + """Unified JSON parsing handler""" + try: + return json.loads(value) + except json.JSONDecodeError: + raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") + + @staticmethod + def _validate_array(value: Any, element_type: DefaultValueType) -> bool: + """Unified array type validation""" + return isinstance(value, list) and all(isinstance(x, element_type) for x in value) + + @staticmethod + def _convert_number(value: str) -> float: + """Unified number conversion handler""" + try: + return float(value) + except ValueError: + raise DefaultValueTypeError(f"Cannot convert to number: {value}") + + @model_validator(mode="after") + def validate_value_type(self) -> "DefaultValue": + if self.type is None: + raise DefaultValueTypeError("type field is required") + + # Type validation configuration + type_validators = { + DefaultValueType.STRING: { + "type": str, + "converter": lambda x: x, + }, + DefaultValueType.NUMBER: { + "type": NumberType, + "converter": self._convert_number, + }, + DefaultValueType.OBJECT: { + "type": dict, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_NUMBER: { + "type": list, + "element_type": NumberType, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_STRING: { + "type": list, + "element_type": str, + "converter": self._parse_json, + }, + DefaultValueType.ARRAY_OBJECT: { + "type": list, + "element_type": dict, + "converter": self._parse_json, + }, + } + + validator = type_validators.get(self.type) + if not validator: + if self.type == DefaultValueType.ARRAY_FILES: + # Handle files type + return self + raise DefaultValueTypeError(f"Unsupported type: {self.type}") + + # Handle string input cases + if isinstance(self.value, str) and self.type != DefaultValueType.STRING: + self.value = validator["converter"](self.value) + + # Validate base type + if not isinstance(self.value, validator["type"]): + raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") + + # Validate array element types + if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): + raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") + + return self class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + error_strategy: Optional[ErrorStrategy] = None + default_value: Optional[list[DefaultValue]] = None version: str = "1" + @property + def default_value_dict(self): + if self.default_value: + return {item.key: item.value for item in self.default_value} + return {} + class BaseIterationNodeData(BaseNodeData): start_node_id: Optional[str] = None diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py new file mode 100644 index 00000000000000..ec134e031cf9d3 --- /dev/null +++ b/api/core/workflow/nodes/base/exc.py @@ -0,0 +1,10 @@ +class BaseNodeError(Exception): + """Base class for node errors.""" + + pass + + +class DefaultValueTypeError(BaseNodeError): + """Raised when the default value type is invalid.""" + + pass diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index d0fbed31cd1e20..e1e28af60b4c3f 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from models.workflow import WorkflowNodeExecutionStatus @@ -72,10 +72,7 @@ def run(self) -> Generator[Union[NodeEvent, "InNodeEvent"], None, None]: result = self._run() except Exception as e: logger.exception(f"Node {self.node_id} failed to run") - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - ) + result = NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=str(e), error_type="SystemError") if isinstance(result, NodeRunResult): yield RunCompletedEvent(run_result=result) @@ -137,3 +134,12 @@ def node_type(self) -> NodeType: :return: """ return self._node_type + + @property + def should_continue_on_error(self) -> bool: + """judge if should continue on error + + Returns: + bool: if should continue on error + """ + return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index ce283e38ec9b12..19b9078a5ce4a4 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -61,7 +61,9 @@ def _run(self) -> NodeRunResult: # Transform result result = self._transform_result(result, self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ + ) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 44be403ee6ab91..6d8ca6f7018cb8 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -22,3 +22,16 @@ class NodeType(StrEnum): VARIABLE_ASSIGNER = "assigner" DOCUMENT_EXTRACTOR = "document-extractor" LIST_OPERATOR = "list-operator" + + +class ErrorStrategy(StrEnum): + FAIL_BRANCH = "fail-branch" + DEFAULT_VALUE = "default-value" + + +class FailBranchSourceHandle(StrEnum): + FAILED = "fail-branch" + SUCCESS = "success-branch" + + +CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 22ad2a39f62fa4..0ac095d26cc248 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -21,6 +21,7 @@ from .exc import ( AuthorizationConfigError, FileFetchError, + HttpRequestNodeError, InvalidHttpMethodError, ResponseSizeError, ) @@ -208,8 +209,10 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: "follow_redirects": True, } # request_args = {k: v for k, v in request_args.items() if v is not None} - - response = getattr(ssrf_proxy, self.method)(**request_args) + try: + response = getattr(ssrf_proxy, self.method)(**request_args) + except ssrf_proxy.MaxRetriesExceededError as e: + raise HttpRequestNodeError(str(e)) return response def invoke(self) -> Response: diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 2a92a16ede84e0..d040cc9f559ded 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -65,6 +65,21 @@ def _run(self) -> NodeRunResult: response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) + if not response.response.is_success and self.should_continue_on_error: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + outputs={ + "status_code": response.status_code, + "body": response.text if not files else "", + "headers": response.headers, + "files": files, + }, + process_data={ + "request": http_executor.to_log(), + }, + error=f"Request failed with status code {response.status_code}", + error_type="HTTPResponseCodeError", + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={ @@ -83,6 +98,7 @@ def _run(self) -> NodeRunResult: status=WorkflowNodeExecutionStatus.FAILED, error=str(e), process_data=process_data, + error_type=type(e).__name__, ) @staticmethod diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b788191adcbbd7..67e62cb8750430 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -193,6 +193,7 @@ def _run(self) -> NodeRunResult | Generator[NodeEvent | InNodeEvent, None, None] error=str(e), inputs=node_inputs, process_data=process_data, + error_type=type(e).__name__, ) ) return diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index e855ab2d2b0659..7594036b50abb4 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -139,7 +139,7 @@ def _run(self): "usage": jsonable_encoder(usage), "finish_reason": finish_reason, } - outputs = {"class_name": category_name} + outputs = {"class_name": category_name, "class_id": category_id} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 951e5330a324e8..9b901c026e8bcc 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -56,6 +56,7 @@ def _run(self) -> NodeRunResult: NodeRunMetadataKey.TOOL_INFO: tool_info, }, error=f"Failed to get tool runtime: {str(e)}", + error_type=type(e).__name__, ) # get parameters @@ -89,6 +90,7 @@ def _run(self) -> NodeRunResult: NodeRunMetadataKey.TOOL_INFO: tool_info, }, error=f"Failed to invoke tool: {str(e)}", + error_type=type(e).__name__, ) # convert tool messages diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 1413adf7196879..8390c665561841 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -14,6 +14,7 @@ "total_steps": fields.Integer, "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } workflow_run_for_list_fields = { @@ -27,6 +28,7 @@ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } advanced_chat_workflow_run_for_list_fields = { @@ -42,6 +44,7 @@ "created_by_account": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } advanced_chat_workflow_run_pagination_fields = { @@ -73,6 +76,7 @@ "created_by_end_user": fields.Nested(simple_end_user_fields, attribute="created_by_end_user", allow_null=True), "created_at": TimestampField, "finished_at": TimestampField, + "exceptions_count": fields.Integer, } workflow_run_node_execution_fields = { diff --git a/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py new file mode 100644 index 00000000000000..8c576339bae8cf --- /dev/null +++ b/api/migrations/versions/2024_11_28_0553-cf8f4fc45278_add_exceptions_count_field_to_.py @@ -0,0 +1,33 @@ +"""add exceptions_count field to WorkflowRun model + +Revision ID: cf8f4fc45278 +Revises: 01d6889832f7 +Create Date: 2024-11-28 05:53:21.576178 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'cf8f4fc45278' +down_revision = '01d6889832f7' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('exceptions_count', sa.Integer(), server_default=sa.text('0'), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_column('exceptions_count') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index c0e70889a88875..1b0af85f085e91 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -325,6 +325,7 @@ class WorkflowRunStatus(StrEnum): SUCCEEDED = "succeeded" FAILED = "failed" STOPPED = "stopped" + PARTIAL_SUCCESSED = "partial-succeeded" @classmethod def value_of(cls, value: str) -> "WorkflowRunStatus": @@ -395,7 +396,7 @@ class WorkflowRun(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped + status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[str] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) @@ -405,6 +406,7 @@ class WorkflowRun(db.Model): created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) + exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @property def created_by_account(self): @@ -464,6 +466,7 @@ def to_dict(self): "created_by": self.created_by, "created_at": self.created_at, "finished_at": self.finished_at, + "exceptions_count": self.exceptions_count, } @classmethod @@ -489,6 +492,7 @@ def from_dict(cls, data: dict) -> "WorkflowRun": created_by=data.get("created_by"), created_at=data.get("created_at"), finished_at=data.get("finished_at"), + exceptions_count=data.get("exceptions_count"), ) @@ -522,6 +526,7 @@ class WorkflowNodeExecutionStatus(Enum): RUNNING = "running" SUCCEEDED = "succeeded" FAILED = "failed" + EXCEPTION = "exception" @classmethod def value_of(cls, value: str) -> "WorkflowNodeExecutionStatus": diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 37d7d0937cd492..84768d5af053e4 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,7 +2,7 @@ import time from collections.abc import Sequence from datetime import UTC, datetime -from typing import Optional +from typing import Optional, cast from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -11,6 +11,9 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes import NodeType +from core.workflow.nodes.base.entities import BaseNodeData +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry @@ -225,7 +228,7 @@ def run_draft_workflow_node( user_inputs=user_inputs, user_id=account.id, ) - + node_instance = cast(BaseNode[BaseNodeData], node_instance) node_run_result: NodeRunResult | None = None for event in generator: if isinstance(event, RunCompletedEvent): @@ -237,8 +240,35 @@ def run_draft_workflow_node( if not node_run_result: raise ValueError("Node run failed with no run result") - - run_succeeded = True if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED else False + # single step debug mode error handling return + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + node_error_args = { + "status": WorkflowNodeExecutionStatus.EXCEPTION, + "error": node_run_result.error, + "inputs": node_run_result.inputs, + "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + } + if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + **node_instance.node_data.default_value_dict, + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + else: + node_run_result = NodeRunResult( + **node_error_args, + outputs={ + "error_message": node_run_result.error, + "error_type": node_run_result.error_type, + }, + ) + run_succeeded = node_run_result.status in ( + WorkflowNodeExecutionStatus.SUCCEEDED, + WorkflowNodeExecutionStatus.EXCEPTION, + ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: node_instance = e.node_instance @@ -260,7 +290,6 @@ def run_draft_workflow_node( workflow_node_execution.created_by = account.id workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - if run_succeeded and node_run_result: # create workflow node execution inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None @@ -277,7 +306,11 @@ def run_draft_workflow_node( workflow_node_execution.execution_metadata = ( json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None ) - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: + workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value + workflow_node_execution.error = node_run_result.error else: # create workflow node execution workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py new file mode 100644 index 00000000000000..ba209e4020afad --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -0,0 +1,502 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import ( + GraphRunPartialSucceededEvent, + GraphRunSucceededEvent, + NodeRunExceptionEvent, + NodeRunStreamChunkEvent, +) +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.graph_engine import GraphEngine +from models.enums import UserFrom +from models.workflow import WorkflowType + + +class ContinueOnErrorTestHelper: + @staticmethod + def get_code_node(code: str, error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a code node configuration""" + node = { + "id": "node", + "data": { + "outputs": {"result": {"type": "number"}}, + "error_strategy": error_strategy, + "title": "code", + "variables": [], + "code_language": "python3", + "code": "\n".join([line[4:] for line in code.split("\n")]), + "type": "code", + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_http_node( + error_strategy: str = "fail-branch", default_value: dict | None = None, authorization_success: bool = False + ): + """Helper method to create a http node configuration""" + authorization = ( + { + "type": "api-key", + "config": { + "type": "basic", + "api_key": "ak-xxx", + "header": "api-key", + }, + } + if authorization_success + else { + "type": "api-key", + # missing config field + } + ) + node = { + "id": "node", + "data": { + "title": "http", + "desc": "", + "method": "get", + "url": "http://example.com", + "authorization": authorization, + "headers": "X-Header:123", + "params": "A:b", + "body": None, + "type": "http-request", + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_error_status_code_http_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a http node configuration""" + node = { + "id": "node", + "data": { + "type": "http-request", + "title": "HTTP Request", + "desc": "", + "variables": [], + "method": "get", + "url": "https://api.github.com/issues", + "authorization": {"type": "no-auth", "config": None}, + "headers": "", + "params": "", + "body": {"type": "none", "data": []}, + "timeout": {"max_connect_timeout": 0, "max_read_timeout": 0, "max_write_timeout": 0}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_tool_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a tool node configuration""" + node = { + "id": "node", + "data": { + "title": "a", + "desc": "a", + "provider_id": "maths", + "provider_type": "builtin", + "provider_name": "maths", + "tool_name": "eval_expression", + "tool_label": "eval_expression", + "tool_configurations": {}, + "tool_parameters": { + "expression": { + "type": "variable", + "value": ["1", "123", "args1"], + } + }, + "type": "tool", + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def get_llm_node(error_strategy: str = "fail-branch", default_value: dict | None = None): + """Helper method to create a llm node configuration""" + node = { + "id": "node", + "data": { + "title": "123", + "type": "llm", + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "chat", "completion_params": {}}, + "prompt_template": [ + {"role": "system", "text": "you are a helpful assistant.\ntoday's weather is {{#abc.output#}}."}, + {"role": "user", "text": "{{#sys.query#}}"}, + ], + "memory": None, + "context": {"enabled": False}, + "vision": {"enabled": False}, + "error_strategy": error_strategy, + }, + } + if default_value: + node["data"]["default_value"] = default_value + return node + + @staticmethod + def create_test_graph_engine(graph_config: dict, user_inputs: dict | None = None): + """Helper method to create a graph engine instance for testing""" + graph = Graph.init(graph_config=graph_config) + variable_pool = { + "system_variables": { + SystemVariableKey.QUERY: "clear", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "aaa", + }, + "user_inputs": user_inputs or {"uid": "takato"}, + } + + return GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + graph_config=graph_config, + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200, + ) + + +DEFAULT_VALUE_EDGE = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-source-answer-target", + "source": "node", + "target": "answer", + "sourceHandle": "source", + }, +] + +FAIL_BRANCH_EDGES = [ + { + "id": "start-source-node-target", + "source": "start", + "target": "node", + "sourceHandle": "source", + }, + { + "id": "node-true-success-target", + "source": "node", + "target": "success", + "sourceHandle": "source", + }, + { + "id": "node-false-error-target", + "source": "node", + "target": "error", + "sourceHandle": "fail-branch", + }, +] + + +def test_code_default_value_continue_on_error(): + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_code_node( + error_code, "default-value", [{"key": "result", "type": "number", "value": 132123}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "132123"} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_code_fail_branch_continue_on_error(): + error_code = """ + def main() -> dict: + return { + "result": 1 / 0, + } + """ + + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "node node run successfully"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "node node run failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_code_node(error_code), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "node node run failed"} for e in events + ) + + +def test_http_node_default_value_continue_on_error(): + """Test HTTP node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.response#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_http_node( + "default-value", [{"key": "response", "type": "string", "value": "http node got error response"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http node got error response"} + for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "HTTP request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "HTTP request failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_tool_node_default_value_continue_on_error(): + """Test tool node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.result#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_tool_node( + "default-value", [{"key": "result", "type": "string", "value": "default tool result"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default tool result"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_tool_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "tool execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "tool execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_tool_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "tool execute failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_llm_node_default_value_continue_on_error(): + """Test LLM node with default value error strategy""" + graph_config = { + "edges": DEFAULT_VALUE_EDGE, + "nodes": [ + {"data": {"title": "start", "type": "start", "variables": []}, "id": "start"}, + {"data": {"title": "answer", "type": "answer", "answer": "{{#node.answer#}}"}, "id": "answer"}, + ContinueOnErrorTestHelper.get_llm_node( + "default-value", [{"key": "answer", "type": "string", "value": "default LLM response"}] + ), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "default LLM response"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_llm_node_fail_branch_continue_on_error(): + """Test LLM node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "LLM request successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "LLM request failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_llm_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "LLM request failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_status_code_error_http_node_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any( + isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {"answer": "http execute failed"} for e in events + ) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 1 + + +def test_variable_pool_error_type_variable(): + graph_config = { + "edges": FAIL_BRANCH_EDGES, + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "http execute successful"}, + "id": "success", + }, + { + "data": {"title": "error", "type": "answer", "answer": "http execute failed"}, + "id": "error", + }, + ContinueOnErrorTestHelper.get_error_status_code_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + list(graph_engine.run()) + error_message = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_message"]) + error_type = graph_engine.graph_runtime_state.variable_pool.get(["node", "error_type"]) + assert error_message != None + assert error_type.value == "HTTPResponseCodeError" + + +def test_no_node_in_fail_branch_continue_on_error(): + """Test HTTP node with fail-branch error strategy""" + graph_config = { + "edges": FAIL_BRANCH_EDGES[:-1], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, + "id": "success", + }, + ContinueOnErrorTestHelper.get_http_node(), + ], + } + + graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config) + events = list(graph_engine.run()) + + assert any(isinstance(e, NodeRunExceptionEvent) for e in events) + assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events) + assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0