Skip to content

Commit

Permalink
fix: Introduce ArrayVariable and update iteration node to handle it (#…
Browse files Browse the repository at this point in the history
…12001)

Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 authored Dec 23, 2024
1 parent 8978a6a commit 9cfd1c6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
2 changes: 2 additions & 0 deletions api/core/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
ArrayVariable,
FileVariable,
FloatVariable,
IntegerVariable,
Expand All @@ -43,6 +44,7 @@
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",
"ArrayVariable",
"FileSegment",
"FileVariable",
"FloatSegment",
Expand Down
13 changes: 9 additions & 4 deletions api/core/variables/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
Expand Down Expand Up @@ -52,19 +53,23 @@ class ObjectVariable(ObjectSegment, Variable):
pass


class ArrayAnyVariable(ArrayAnySegment, Variable):
class ArrayVariable(ArraySegment, Variable):
pass


class ArrayStringVariable(ArrayStringSegment, Variable):
class ArrayAnyVariable(ArrayAnySegment, ArrayVariable):
pass


class ArrayNumberVariable(ArrayNumberSegment, Variable):
class ArrayStringVariable(ArrayStringSegment, ArrayVariable):
pass


class ArrayObjectVariable(ArrayObjectSegment, Variable):
class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable):
pass


class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable):
pass


Expand Down
15 changes: 9 additions & 6 deletions api/core/workflow/nodes/iteration/iteration_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flask import Flask, current_app

from configs import dify_config
from core.variables import IntegerVariable
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.workflow.entities.node_entities import (
NodeRunMetadataKey,
NodeRunResult,
Expand Down Expand Up @@ -75,12 +75,15 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)

if not iterator_list_segment:
raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found")
if not variable:
raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")

if len(iterator_list_segment.value) == 0:
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")

if isinstance(variable, NoneVariable) or len(variable.value) == 0:
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
Expand All @@ -89,7 +92,7 @@ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
)
return

iterator_list_value = iterator_list_segment.to_object()
iterator_list_value = variable.to_object()

if not isinstance(iterator_list_value, list):
raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
Expand Down

0 comments on commit 9cfd1c6

Please sign in to comment.