Skip to content

Commit

Permalink
fix: issue #10596 by making the iteration node outputs right (#11394)
Browse files Browse the repository at this point in the history
Signed-off-by: yihong0618 <[email protected]>
Signed-off-by: -LAN- <[email protected]>
Co-authored-by: -LAN- <[email protected]>
  • Loading branch information
yihong0618 and laipz8200 authored Dec 7, 2024
1 parent 9277156 commit d9d5d35
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 43 deletions.
14 changes: 1 addition & 13 deletions api/core/app/entities/queue_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum, StrEnum
from typing import Any, Optional

from pydantic import BaseModel, field_validator
from pydantic import BaseModel

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.workflow.entities.node_entities import NodeRunMetadataKey
Expand Down Expand Up @@ -113,18 +113,6 @@ class QueueIterationNextEvent(AppQueueEvent):
output: Optional[Any] = None # output for the current iteration
duration: Optional[float] = None

@field_validator("output", mode="before")
@classmethod
def set_output(cls, v):
"""
Set output
"""
if v is None:
return None
if isinstance(v, int | float | str | bool | dict | list):
return v
raise ValueError("output must be a valid type")


class QueueIterationCompletedEvent(AppQueueEvent):
"""
Expand Down
76 changes: 46 additions & 30 deletions api/core/workflow/nodes/iteration/iteration_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flask import Flask, current_app

from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
from core.variables import IntegerVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
Expand Down Expand Up @@ -155,18 +155,19 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
iteration_node_data=self.node_data,
index=0,
pre_iteration_output=None,
duration=None,
)
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
if self.node_data.is_parallel:
futures: list[Future] = []
q = Queue()
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
self._run_single_iter_parallel,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
q,
iterator_list_value,
inputs,
Expand All @@ -181,6 +182,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
future.add_done_callback(thread_pool.task_done_callback)
futures.append(future)
succeeded_count = 0
empty_count = 0
while True:
try:
event = q.get(timeout=1)
Expand Down Expand Up @@ -208,33 +210,38 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
else:
for _ in range(len(iterator_list_value)):
yield from self._run_single_iter(
iterator_list_value,
variable_pool,
inputs,
outputs,
start_at,
graph_engine,
iteration_graph,
iter_run_map,
iterator_list_value=iterator_list_value,
variable_pool=variable_pool,
inputs=inputs,
outputs=outputs,
start_at=start_at,
graph_engine=graph_engine,
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]

# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]

yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
iteration_node_type=self.node_type,
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
)

yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
metadata={NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map},
)
)
Expand All @@ -248,7 +255,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=str(e),
Expand Down Expand Up @@ -280,7 +287,7 @@ def _extract_variable_selector_to_variable_mapping(
:param node_data: node data
:return:
"""
variable_mapping = {
variable_mapping: dict[str, Sequence[str]] = {
f"{node_id}.input_selector": node_data.iterator_selector,
}

Expand Down Expand Up @@ -308,7 +315,7 @@ def _extract_variable_selector_to_variable_mapping(
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_config, config=sub_node_config
)
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping)
except NotImplementedError:
sub_node_variable_mapping = {}

Expand All @@ -329,8 +336,12 @@ def _extract_variable_selector_to_variable_mapping(
return variable_mapping

def _handle_event_metadata(
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
) -> NodeRunStartedEvent | BaseNodeEvent:
self,
*,
event: BaseNodeEvent | InNodeEvent,
iter_run_index: int,
parallel_mode_run_id: str | None,
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
"""
add iteration metadata to event.
"""
Expand All @@ -355,6 +366,7 @@ def _handle_event_metadata(

def _run_single_iter(
self,
*,
iterator_list_value: list[str],
variable_pool: VariablePool,
inputs: dict[str, list],
Expand All @@ -373,12 +385,12 @@ def _run_single_iter(
try:
rst = graph_engine.run()
# get current iteration index
current_index = variable_pool.get([self.node_id, "index"]).value
index_variable = variable_pool.get([self.node_id, "index"])
if not isinstance(index_variable, IntegerVariable):
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
current_index = index_variable.value
iteration_run_id = parallel_mode_run_id if parallel_mode_run_id is not None else f"{current_index}"
next_index = int(current_index) + 1

if current_index is None:
raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found")
for event in rst:
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
event.in_iteration_id = self.node_id
Expand All @@ -391,7 +403,9 @@ def _run_single_iter(
continue

if isinstance(event, NodeRunSucceededEvent):
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
yield self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
Expand All @@ -404,7 +418,7 @@ def _run_single_iter(
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
Expand All @@ -417,7 +431,7 @@ def _run_single_iter(
iteration_node_data=self.node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": jsonable_encoder(outputs)},
outputs={"output": outputs},
steps=len(iterator_list_value),
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
error=event.error,
Expand All @@ -429,9 +443,11 @@ def _run_single_iter(
)
)
return
else:
event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
elif isinstance(event, InNodeEvent):
# event = cast(InNodeEvent, event)
metadata_event = self._handle_event_metadata(
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
Expand Down Expand Up @@ -513,7 +529,7 @@ def _run_single_iter(
iteration_node_data=self.node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
pre_iteration_output=current_iteration_output or None,
duration=duration,
)

Expand Down Expand Up @@ -551,7 +567,7 @@ def _run_single_iter_parallel(
index: int,
item: Any,
iter_run_map: dict[str, float],
) -> Generator[NodeEvent | InNodeEvent, None, None]:
):
"""
run single iteration in parallel mode
"""
Expand Down

0 comments on commit d9d5d35

Please sign in to comment.