Skip to content

Commit

Permalink
Feat: continue on error (#11458)
Browse files Browse the repository at this point in the history
Co-authored-by: Novice Lee <[email protected]>
Co-authored-by: Novice Lee <[email protected]>
  • Loading branch information
3 people authored Dec 11, 2024
1 parent bec5451 commit 79a710c
Show file tree
Hide file tree
Showing 31 changed files with 1,211 additions and 80 deletions.
28 changes: 27 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
Expand All @@ -31,6 +32,7 @@
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
Expand Down Expand Up @@ -317,7 +319,7 @@ def _process_stream_response(

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
Expand Down Expand Up @@ -384,6 +386,29 @@ def _process_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)

self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run:
Expand All @@ -401,6 +426,7 @@ def _process_stream_response(
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)

yield self._workflow_finish_to_stream_response(
Expand Down
4 changes: 3 additions & 1 deletion api/core/app/apps/workflow/app_queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
)
Expand Down Expand Up @@ -34,7 +35,8 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent,
| QueueWorkflowFailedEvent
| QueueWorkflowPartialSuccessEvent,
):
self.stop_listen()

Expand Down
68 changes: 61 additions & 7 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
Expand All @@ -26,6 +27,7 @@
QueueStopEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
Expand Down Expand Up @@ -276,7 +278,7 @@ def _process_stream_response(

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
Expand Down Expand Up @@ -345,29 +347,81 @@ def _process_stream_response(
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")

workflow_run = self._handle_workflow_run_failed(
workflow_run = self._handle_workflow_run_partial_success(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
outputs=event.outputs,
exceptions_count=event.exceptions_count,
conversation_id=None,
trace_manager=trace_manager,
)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
handle_args = {
"workflow_run": workflow_run,
"start_at": graph_runtime_state.start_at,
"total_tokens": graph_runtime_state.total_tokens,
"total_steps": graph_runtime_state.node_run_steps,
"status": WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
}
workflow_run = self._handle_workflow_run_failed(**handle_args)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run:
raise Exception("Workflow run not initialized.")

if not graph_runtime_state:
raise Exception("Graph runtime state not initialized.")
handle_args = {
"workflow_run": workflow_run,
"start_at": graph_runtime_state.start_at,
"total_tokens": graph_runtime_state.total_tokens,
"total_steps": graph_runtime_state.node_run_steps,
"status": WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
"error": event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
"conversation_id": None,
"trace_manager": trace_manager,
"exceptions_count": event.exceptions_count,
}
workflow_run = self._handle_workflow_run_partial_success(**handle_args)

# save workflow app log
self._save_workflow_app_log(workflow_run)

yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
Expand Down
40 changes: 39 additions & 1 deletion api/core/app/apps/workflow_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
Expand All @@ -18,20 +19,23 @@
QueueRetrieverResourcesEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
IterationRunFailedEvent,
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
Expand Down Expand Up @@ -176,8 +180,12 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
)
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
self._publish_event(
QueueWorkflowPartialSuccessEvent(outputs=event.outputs, exceptions_count=event.exceptions_count)
)
elif isinstance(event, GraphRunFailedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.error))
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, NodeRunStartedEvent):
self._publish_event(
QueueNodeStartedEvent(
Expand Down Expand Up @@ -253,6 +261,36 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
self._publish_event(
QueueNodeExceptionEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
Expand Down
44 changes: 44 additions & 0 deletions api/core/app/entities/queue_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ class QueueEvent(StrEnum):
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
WORKFLOW_PARTIAL_SUCCEEDED = "workflow_partial_succeeded"
ITERATION_START = "iteration_start"
ITERATION_NEXT = "iteration_next"
ITERATION_COMPLETED = "iteration_completed"
NODE_STARTED = "node_started"
NODE_SUCCEEDED = "node_succeeded"
NODE_FAILED = "node_failed"
NODE_EXCEPTION = "node_exception"
RETRIEVER_RESOURCES = "retriever_resources"
ANNOTATION_REPLY = "annotation_reply"
AGENT_THOUGHT = "agent_thought"
Expand Down Expand Up @@ -237,6 +239,17 @@ class QueueWorkflowFailedEvent(AppQueueEvent):

event: QueueEvent = QueueEvent.WORKFLOW_FAILED
error: str
exceptions_count: int


class QueueWorkflowPartialSuccessEvent(AppQueueEvent):
    """
    QueueWorkflowPartialSuccessEvent entity

    Queue event of type ``WORKFLOW_PARTIAL_SUCCEEDED``. It carries the
    workflow outputs together with an exception count.
    """

    event: QueueEvent = QueueEvent.WORKFLOW_PARTIAL_SUCCEEDED
    # Required. Presumably the number of node exceptions that occurred while
    # the run continued -- TODO confirm against the publisher.
    exceptions_count: int
    # Optional workflow outputs; defaults to None when not supplied.
    outputs: Optional[dict[str, Any]] = None


class QueueNodeStartedEvent(AppQueueEvent):
Expand Down Expand Up @@ -331,6 +344,37 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent):
error: str


class QueueNodeExceptionEvent(AppQueueEvent):
    """
    QueueNodeExceptionEvent entity

    Queue event of type ``NODE_EXCEPTION``. It carries the node's identity,
    its parallel/iteration context, the start time, the run's inputs,
    process data, outputs and execution metadata, and an error message.
    """

    event: QueueEvent = QueueEvent.NODE_EXCEPTION

    # Identity of the node execution this event refers to (all required).
    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    # Required timestamp; presumably when the node run started -- TODO confirm.
    start_at: datetime

    # Optional payloads from the node run; each defaults to None.
    inputs: Optional[dict[str, Any]] = None
    process_data: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

    # Required error message describing the exception.
    error: str


class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
Expand Down
1 change: 1 addition & 0 deletions api/core/app/entities/task_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class Data(BaseModel):
created_by: Optional[dict] = None
created_at: int
finished_at: int
exceptions_count: Optional[int] = 0
files: Optional[Sequence[Mapping[str, Any]]] = []

event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
Expand Down
Loading

0 comments on commit 79a710c

Please sign in to comment.