Skip to content

Commit

Permalink
refactor(workflow): enhance node genericity and streamline node data …
Browse files Browse the repository at this point in the history
…usage

- Implement generic typing for BaseNode to improve type safety.
- Remove unnecessary type casting, simplifying data handling.
- Add variable selector extraction refinement to enhance code readability.
- Standardize method parameters and variable access patterns across node classes.
  • Loading branch information
laipz8200 committed Oct 9, 2024
1 parent 3e5d0cf commit 8997695
Show file tree
Hide file tree
Showing 20 changed files with 168 additions and 180 deletions.
16 changes: 7 additions & 9 deletions api/core/workflow/nodes/answer/answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from models.workflow import WorkflowNodeExecutionStatus


class AnswerNode(BaseNode):
class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER

Expand All @@ -25,11 +25,8 @@ def _run(self) -> NodeRunResult:
Run node
:return:
"""
node_data = self.node_data
node_data = cast(AnswerNodeData, node_data)

# generate routes
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(node_data)
generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)

answer = ""
files = []
Expand All @@ -52,7 +49,11 @@ def _run(self) -> NodeRunResult:

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: AnswerNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: AnswerNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand All @@ -61,9 +62,6 @@ def _extract_variable_selector_to_variable_mapping(
:param node_data: node data
:return:
"""
node_data = node_data
node_data = cast(AnswerNodeData, node_data)

variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()

Expand Down
23 changes: 16 additions & 7 deletions api/core/workflow/nodes/base_node.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from typing import Any, Generic, Optional, TypeVar, cast

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult
Expand All @@ -15,8 +15,10 @@

logger = logging.getLogger(__name__)

GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)

class BaseNode(ABC):

class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType

Expand Down Expand Up @@ -50,7 +52,7 @@ def __init__(
raise ValueError("Node ID is required.")

self.node_id = node_id
self.node_data = self._node_data_cls(**config.get("data", {}))
self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {})))

@abstractmethod
def _run(self) -> NodeRunResult | Generator[RunEvent | InNodeEvent, None, None]:
Expand All @@ -77,7 +79,10 @@ def run(self) -> Generator[RunEvent | InNodeEvent, None, None]:

@classmethod
def extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], config: dict
cls,
*,
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand All @@ -91,12 +96,16 @@ def extract_variable_selector_to_variable_mapping(

node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=node_data
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
)

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: BaseNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: GenericNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand Down
25 changes: 11 additions & 14 deletions api/core/workflow/nodes/code/code_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast
from typing import Any, Optional, Union

from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
Expand All @@ -13,7 +13,7 @@
from models.workflow import WorkflowNodeExecutionStatus


class CodeNode(BaseNode):
class CodeNode(BaseNode[CodeNodeData]):
_node_data_cls = CodeNodeData
_node_type = NodeType.CODE

Expand All @@ -34,20 +34,13 @@ def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return code_provider.get_default_config()

def _run(self) -> NodeRunResult:
"""
Run code
:return:
"""
node_data = self.node_data
node_data = cast(CodeNodeData, node_data)

# Get code language
code_language = node_data.code_language
code = node_data.code
code_language = self.node_data.code_language
code = self.node_data.code

# Get variables
variables = {}
for variable_selector in node_data.variables:
for variable_selector in self.node_data.variables:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)

Expand All @@ -61,7 +54,7 @@ def _run(self) -> NodeRunResult:
)

# Transform result
result = self._transform_result(result, node_data.outputs)
result = self._transform_result(result, self.node_data.outputs)
except (CodeExecutionError, ValueError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))

Expand Down Expand Up @@ -317,7 +310,11 @@ def _transform_result(

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: CodeNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: CodeNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import csv
import io
from typing import cast

import docx
import pandas as pd
Expand All @@ -24,7 +23,7 @@
from .models import DocumentExtractorNodeData


class DocumentExtractorNode(BaseNode):
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
Expand All @@ -34,8 +33,7 @@ class DocumentExtractorNode(BaseNode):
_node_type = NodeType.DOCUMENT_EXTRACTOR

def _run(self):
node_data = cast(DocumentExtractorNodeData, self.node_data)
variable_selector = node_data.variable_selector
variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)

if variable is None:
Expand Down
14 changes: 8 additions & 6 deletions api/core/workflow/nodes/end/end_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Any

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base_node import BaseNode
Expand All @@ -8,7 +8,7 @@
from models.workflow import WorkflowNodeExecutionStatus


class EndNode(BaseNode):
class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData
_node_type = NodeType.END

Expand All @@ -17,9 +17,7 @@ def _run(self) -> NodeRunResult:
Run node
:return:
"""
node_data = self.node_data
node_data = cast(EndNodeData, node_data)
output_variables = node_data.outputs
output_variables = self.node_data.outputs

outputs = {}
for variable_selector in output_variables:
Expand All @@ -35,7 +33,11 @@ def _run(self) -> NodeRunResult:

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: EndNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: EndNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand Down
11 changes: 5 additions & 6 deletions api/core/workflow/nodes/http_request/http_request_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Mapping, Sequence
from mimetypes import guess_extension
from os import path
from typing import Any, cast
from typing import Any

from configs import dify_config
from core.file import File, FileTransferMethod, FileType
Expand All @@ -28,7 +28,7 @@
logger = logging.getLogger(__name__)


class HttpRequestNode(BaseNode):
class HttpRequestNode(BaseNode[HttpRequestNodeData]):
_node_data_cls = HttpRequestNodeData
_node_type = NodeType.HTTP_REQUEST

Expand All @@ -52,13 +52,11 @@ def get_default_config(cls, filters: dict | None = None) -> dict:
}

def _run(self) -> NodeRunResult:
node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data)

process_data = {}
try:
http_executor = HttpExecutor(
node_data=node_data,
timeout=self._get_request_timeout(node_data),
node_data=self.node_data,
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
)
process_data["request"] = http_executor.to_log()
Expand Down Expand Up @@ -99,6 +97,7 @@ def _get_request_timeout(node_data: HttpRequestNodeData) -> HttpRequestNodeTimeo
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: HttpRequestNodeData,
Expand Down
21 changes: 11 additions & 10 deletions api/core/workflow/nodes/if_else/if_else_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal, cast
from typing import Any, Literal

from typing_extensions import deprecated

Expand All @@ -13,7 +13,7 @@
from models.workflow import WorkflowNodeExecutionStatus


class IfElseNode(BaseNode):
class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE

Expand All @@ -22,9 +22,6 @@ def _run(self) -> NodeRunResult:
Run node
:return:
"""
node_data = self.node_data
node_data = cast(IfElseNodeData, node_data)

node_inputs: dict[str, list] = {"conditions": []}

process_datas: dict[str, list] = {"condition_results": []}
Expand All @@ -35,8 +32,8 @@ def _run(self) -> NodeRunResult:
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
if node_data.cases:
for case in node_data.cases:
if self.node_data.cases:
for case in self.node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
Expand All @@ -62,8 +59,8 @@ def _run(self) -> NodeRunResult:
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
conditions=node_data.conditions or [],
operator=node_data.logical_operator or "and",
conditions=self.node_data.conditions or [],
operator=self.node_data.logical_operator or "and",
)

selected_case_id = "true" if final_result else "false"
Expand Down Expand Up @@ -93,7 +90,11 @@ def _run(self) -> NodeRunResult:

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IfElseNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand Down
9 changes: 6 additions & 3 deletions api/core/workflow/nodes/iteration/iteration_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
logger = logging.getLogger(__name__)


class IterationNode(BaseNode):
class IterationNode(BaseNode[IterationNodeData]):
"""
Iteration Node.
"""
Expand All @@ -41,7 +41,6 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
self.node_data = cast(IterationNodeData, self.node_data)
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)

if not iterator_list_segment:
Expand Down Expand Up @@ -248,7 +247,11 @@ def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:

@classmethod
def _extract_variable_selector_to_variable_mapping(
cls, graph_config: Mapping[str, Any], node_id: str, node_data: IterationNodeData
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IterationNodeData,
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
Expand Down
Loading

0 comments on commit 8997695

Please sign in to comment.