Skip to content

Commit

Permalink
feat: Iteration node support parallel mode (langgenius#9493)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nov1c444 authored and JunXu01 committed Nov 9, 2024
1 parent 5c7eafa commit aa1c27c
Show file tree
Hide file tree
Showing 33 changed files with 1,285 additions and 194 deletions.
3 changes: 2 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
QueueIterationStartEvent,
QueueMessageReplaceEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand Down Expand Up @@ -314,7 +315,7 @@ def _process_stream_response(

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
Expand Down
3 changes: 2 additions & 1 deletion api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand Down Expand Up @@ -275,7 +276,7 @@ def _process_stream_response(

if response:
yield response
elif isinstance(event, QueueNodeFailedEvent):
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)

response = self._workflow_node_finish_to_stream_response(
Expand Down
35 changes: 35 additions & 0 deletions api/core/app/apps/workflow_app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand All @@ -30,6 +31,7 @@
IterationRunNextEvent,
IterationRunStartedEvent,
IterationRunSucceededEvent,
NodeInIterationFailedEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
Expand Down Expand Up @@ -193,6 +195,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
node_run_index=event.route_node_state.index,
predecessor_node_id=event.predecessor_node_id,
in_iteration_id=event.in_iteration_id,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, NodeRunSucceededEvent):
Expand Down Expand Up @@ -246,9 +249,40 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
error=event.route_node_state.node_run_result.error
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
else "Unknown error",
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
)
)
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_data=event.node_data,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
start_at=event.route_node_state.start_at,
inputs=event.route_node_state.node_run_result.inputs
if event.route_node_state.node_run_result
else {},
process_data=event.route_node_state.node_run_result.process_data
if event.route_node_state.node_run_result
else {},
outputs=event.route_node_state.node_run_result.outputs
if event.route_node_state.node_run_result
else {},
execution_metadata=event.route_node_state.node_run_result.metadata
if event.route_node_state.node_run_result
else {},
in_iteration_id=event.in_iteration_id,
error=event.error,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
self._publish_event(
QueueTextChunkEvent(
Expand Down Expand Up @@ -326,6 +360,7 @@ def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent)
index=event.index,
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
output=event.pre_iteration_output,
parallel_mode_run_id=event.parallel_mode_run_id,
)
)
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
Expand Down
37 changes: 36 additions & 1 deletion api/core/app/entities/queue_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""

parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""
node_run_index: int
output: Optional[Any] = None # output for the current iteration

Expand Down Expand Up @@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""


class QueueNodeSucceededEvent(AppQueueEvent):
Expand Down Expand Up @@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
error: Optional[str] = None


class QueueNodeInIterationFailedEvent(AppQueueEvent):
"""
QueueNodeInIterationFailedEvent entity
"""

event: QueueEvent = QueueEvent.NODE_FAILED

node_execution_id: str
node_id: str
node_type: NodeType
node_data: BaseNodeData
parallel_id: Optional[str] = None
"""parallel id if node is in parallel"""
parallel_start_node_id: Optional[str] = None
"""parallel start node id if node is in parallel"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
in_iteration_id: Optional[str] = None
"""iteration id if node is in iteration"""
start_at: datetime

inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str


class QueueNodeFailedEvent(AppQueueEvent):
"""
QueueNodeFailedEvent entity
Expand All @@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
inputs: Optional[dict[str, Any]] = None
process_data: Optional[dict[str, Any]] = None
outputs: Optional[dict[str, Any]] = None
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None

error: str

Expand Down
2 changes: 2 additions & 0 deletions api/core/app/entities/task_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ class Data(BaseModel):
parent_parallel_id: Optional[str] = None
parent_parallel_start_node_id: Optional[str] = None
iteration_id: Optional[str] = None
parallel_run_id: Optional[str] = None

event: StreamEvent = StreamEvent.NODE_STARTED
workflow_run_id: str
Expand Down Expand Up @@ -432,6 +433,7 @@ class Data(BaseModel):
extras: dict = {}
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None

event: StreamEvent = StreamEvent.ITERATION_NEXT
workflow_run_id: str
Expand Down
28 changes: 23 additions & 5 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
Expand All @@ -35,6 +36,7 @@
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
Expand Down Expand Up @@ -251,6 +253,12 @@ def _handle_node_execution_start(
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.execution_metadata = json.dumps(
{
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
}
)
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)

session.add(workflow_node_execution)
Expand Down Expand Up @@ -305,7 +313,9 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent

return workflow_node_execution

def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
) -> WorkflowNodeExecution:
"""
Workflow node execution failed
:param event: queue node failed event
Expand All @@ -318,16 +328,19 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) ->
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()

execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
)
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
)

Expand All @@ -342,6 +355,7 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) ->
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata

self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)

Expand Down Expand Up @@ -448,6 +462,7 @@ def _workflow_node_start_to_stream_response(
parent_parallel_id=event.parent_parallel_id,
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
iteration_id=event.in_iteration_id,
parallel_run_id=event.parallel_mode_run_id,
),
)

Expand All @@ -464,7 +479,7 @@ def _workflow_node_start_to_stream_response(

def _workflow_node_finish_to_stream_response(
self,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
Expand Down Expand Up @@ -608,6 +623,7 @@ def _workflow_iteration_next_to_stream_response(
extras={},
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
parallel_mode_run_id=event.parallel_mode_run_id,
),
)

Expand All @@ -633,7 +649,9 @@ def _workflow_iteration_completed_to_stream_response(
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
status=WorkflowNodeExecutionStatus.SUCCEEDED,
status=WorkflowNodeExecutionStatus.SUCCEEDED
if event.error is None
else WorkflowNodeExecutionStatus.FAILED,
error=None,
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
Expand Down
1 change: 1 addition & 0 deletions api/core/workflow/entities/node_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"


class NodeRunResult(BaseModel):
Expand Down
7 changes: 7 additions & 0 deletions api/core/workflow/graph_engine/entities/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):

class NodeRunStartedEvent(BaseNodeEvent):
predecessor_node_id: Optional[str] = None
parallel_mode_run_id: Optional[str] = None
"""predecessor node id"""


Expand All @@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")


class NodeInIterationFailedEvent(BaseNodeEvent):
error: str = Field(..., description="error")


###########################################
# Parallel Branch Events
###########################################
Expand Down Expand Up @@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
"""parent parallel id if node is in parallel"""
parent_parallel_start_node_id: Optional[str] = None
"""parent parallel start node id if node is in parallel"""
parallel_mode_run_id: Optional[str] = None
"""iteratoin run in parallel mode run id"""


class IterationRunStartedEvent(BaseIterationEvent):
Expand Down
11 changes: 11 additions & 0 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
from typing import Any, Optional

from flask import Flask, current_app
Expand Down Expand Up @@ -724,6 +725,16 @@ def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
return time.perf_counter() - start_at > max_execution_time

def create_copy(self):
"""
create a graph engine copy
:return: with a new variable pool instance of graph engine
"""
new_instance = copy(self)
new_instance.graph_runtime_state = copy(self.graph_runtime_state)
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
return new_instance


class GraphRunFailedError(Exception):
def __init__(self, error: str):
Expand Down
10 changes: 10 additions & 0 deletions api/core/workflow/nodes/iteration/entities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from enum import Enum
from typing import Any, Optional

from pydantic import Field

from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData


class ErrorHandleMode(str, Enum):
TERMINATED = "terminated"
CONTINUE_ON_ERROR = "continue-on-error"
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"


class IterationNodeData(BaseIterationNodeData):
"""
Iteration Node Data.
Expand All @@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
parent_loop_id: Optional[str] = None # redundant field, not used currently
iterator_selector: list[str] # variable selector
output_selector: list[str] # output selector
is_parallel: bool = False # open the parallel mode or not
parallel_nums: int = 10 # the numbers of parallel
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error


class IterationStartNodeData(BaseNodeData):
Expand Down
Loading

0 comments on commit aa1c27c

Please sign in to comment.