Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Iteration node support parallel mode #9493

Merged
merged 22 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7d76101
feat: Iteration node support parallel mode
Nov1c444 Oct 18, 2024
00127b6
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 18, 2024
d33907d
fix: fix config unit test error
Nov1c444 Oct 18, 2024
377aa0b
fix: graph error can't be raised
Nov1c444 Oct 22, 2024
6c13caa
merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 22, 2024
2303a10
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 22, 2024
f13b9d9
fix: variable pool get any method
Nov1c444 Oct 22, 2024
5005fb8
fix: frontend display of iteration log panel
Nov1c444 Oct 24, 2024
c58e644
fix: chatflow teminated log show error
Nov1c444 Oct 24, 2024
1e31e97
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 24, 2024
96771f5
fix: workflow log show error
Nov1c444 Oct 25, 2024
52a203f
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 25, 2024
f519f9d
fix: parallel nums show error
Nov1c444 Oct 29, 2024
9221a5d
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 29, 2024
4f6d402
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 29, 2024
ac07082
fix: unify frontend styles
Nov1c444 Oct 29, 2024
f199c50
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Oct 30, 2024
d15c9d3
fix: add missing type annotations
Nov1c444 Oct 31, 2024
b5413f4
feat(workflow): add handling for QueueNodeInIterationFailedEvent
laipz8200 Nov 4, 2024
92ec7ae
chore: change the parallel warning show logic
Nov1c444 Nov 1, 2024
cbf6c77
fix: correct enum value naming
Nov1c444 Nov 5, 2024
a4cb049
Merge branch 'main' into feat/iteration-node-parallel
Nov1c444 Nov 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand Down Expand Up @@ -314,7 +315,7 @@ def _process_stream_response(

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
Expand Down
3 changes: 2 additions & 1 deletion api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand Down Expand Up @@ -275,7 +276,7 @@ def _process_stream_response(

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
Expand Down
35 changes: 35 additions & 0 deletions api/core/app/apps/workflow_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand All @@ -30,6 +31,7 @@
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
Expand Down Expand Up @@ -193,6 +195,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
Expand Down Expand Up @@ -246,9 +249,40 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
Expand Down Expand Up @@ -326,6 +360,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
Expand Down
37 changes: 36 additions & 1 deletion api/core/app/entities/queue_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""

parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration

Expand Down Expand Up @@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""


class QueueNodeSucceededEvent(AppQueueEvent):
Expand Down Expand Up @@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
error: Optional[str] = None


class QueueNodeInIterationFailedEvent(AppQueueEvent):
    """
    QueueNodeInIterationFailedEvent entity

    Queue-side record of a failed node run. It is published with the
    ``QueueEvent.NODE_FAILED`` event type and carries the node identity,
    parallel/iteration context, the captured run data and the error text.
    """

    event: QueueEvent = QueueEvent.NODE_FAILED

    node_execution_id: str
    node_id: str
    node_type: NodeType
    node_data: BaseNodeData
    parallel_id: Optional[str] = None
    """parallel id if node is in parallel"""
    parallel_start_node_id: Optional[str] = None
    """parallel start node id if node is in parallel"""
    parent_parallel_id: Optional[str] = None
    """parent parallel id if node is in parallel"""
    parent_parallel_start_node_id: Optional[str] = None
    """parent parallel start node id if node is in parallel"""
    in_iteration_id: Optional[str] = None
    """iteration id if node is in iteration"""
    start_at: datetime

    # Data captured from the node run; each may be absent.
    inputs: Optional[dict[str, Any]] = None
    process_data: Optional[dict[str, Any]] = None
    outputs: Optional[dict[str, Any]] = None
    execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

    # Required: a failure event always carries an error message.
    error: str


class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
Expand All @@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str

Expand Down
2 changes: 2 additions & 0 deletions api/core/app/entities/task_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class Data(BaseModel):
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None

event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
Expand Down Expand Up @@ -432,6 +433,7 @@ class Data(BaseModel):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None

event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
Expand Down
28 changes: 23 additions & 5 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand All @@ -35,6 +36,7 @@
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
Expand Down Expand Up @@ -251,6 +253,12 @@ def _handle_node_execution_start(
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)

session.add(workflow_node_execution)
Expand Down Expand Up @@ -305,7 +313,9 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent

return workflow_node_execution

def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
Expand All @@ -318,16 +328,19 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) ->
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()

execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)

Expand All @@ -342,6 +355,7 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) ->
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata

self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

Expand Down Expand Up @@ -448,6 +462,7 @@ def _workflow_node_start_to_stream_response(
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)

Expand All @@ -464,7 +479,7 @@ def _workflow_node_start_to_stream_response(

def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
Expand Down Expand Up @@ -608,6 +623,7 @@ def _workflow_iteration_next_to_stream_response(
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)

Expand All @@ -633,7 +649,9 @@ def _workflow_iteration_completed_to_stream_response(
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
Expand Down
1 change: 1 addition & 0 deletions api/core/workflow/entities/node_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"


class NodeRunResult(BaseModel):
Expand Down
7 changes: 7 additions & 0 deletions api/core/workflow/graph_engine/entities/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):

class NodeRunStartedEvent(BaseNodeEvent):
    """Node run started event, extending the base node event with run-start fields."""

    predecessor_node_id: Optional[str] = None
    """predecessor node id"""
    parallel_mode_run_id: Optional[str] = None
    """iteration run in parallel mode run id"""


Expand All @@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")


class NodeInIterationFailedEvent(BaseNodeEvent):
    """Node-in-iteration failed event; adds a required error message to the base node event."""

    error: str = Field(..., description="error")


###########################################
# Parallel Branch Events
###########################################
Expand Down Expand Up @@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteration run in parallel mode run id"""


class IterationRunStartedEvent(BaseIterationEvent):
Expand Down
11 changes: 11 additions & 0 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional

from flask import Flask, current_app
Expand Down Expand Up @@ -724,6 +725,16 @@ def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
return time.perf_counter() - start_at > max_execution_time

def create_copy(self):
    """
    Create a copy of this graph engine that owns its own variable pool.

    The engine and its ``graph_runtime_state`` are shallow-copied; only the
    variable pool is deep-copied. Every other attribute of the engine and of
    the runtime state is therefore still shared with the original.

    :return: with a new variable pool instance of graph engine
    """
    new_instance = copy(self)
    # Copy the runtime state as well, so that rebinding its variable pool
    # below does not replace the pool on the original engine's state object.
    new_instance.graph_runtime_state = copy(self.graph_runtime_state)
    new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
    return new_instance


class GraphRunFailedError(Exception):
def __init__(self, error: str):
Expand Down
10 changes: 10 additions & 0 deletions api/core/workflow/nodes/iteration/entities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from enum import Enum
from typing import Any, Optional

from pydantic import Field

from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData


class ErrorHandleMode(str, Enum):
    """
    Error handling mode selectable on an iteration node.

    Members are ``str`` subclasses, so they compare equal to their raw string values.
    """

    # NOTE(review): the meanings below are inferred from the member names;
    # confirm against the iteration node implementation.
    TERMINATED = "terminated"  # presumably: stop the iteration on error
    CONTINUE_ON_ERROR = "continue-on-error"  # presumably: keep iterating after an error
    REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"  # presumably: drop failed items from the output


class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
Expand All @@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
parallel_nums: int = 10 # the numbers of parallel
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error


class IterationStartNodeData(BaseNodeData):
Expand Down
Loading
Loading