Skip to content

Commit

Permalink
refactor(api/workflow): Move SystemVariable to workflow/enums.
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Aug 14, 2024
1 parent e19ca0b commit daf3b8b
Show file tree
Hide file tree
Showing 22 changed files with 60 additions and 175 deletions.
2 changes: 1 addition & 1 deletion api/core/app/apps/advanced_chat/app_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from core.file.message_file_parser import MessageFileParser
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from extensions.ext_database import db
from models.account import Account
from models.model import App, Conversation, EndUser, Message
Expand Down
3 changes: 2 additions & 1 deletion api/core/app/apps/advanced_chat/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.answer.entities import TextGenerateRouteChunk, VarGenerateRouteChunk
from events.message_event import message_was_created
Expand Down
2 changes: 1 addition & 1 deletion api/core/app/apps/workflow/app_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
WorkflowAppGenerateEntity,
)
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import UserFrom
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from extensions.ext_database import db
Expand Down
6 changes: 3 additions & 3 deletions api/core/app/apps/workflow/generate_task_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeType
from core.workflow.enums import SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
Expand Down Expand Up @@ -519,7 +520,7 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
"""
nodes = graph.get('nodes')

iteration_ids = [node.get('id') for node in nodes
iteration_ids = [node.get('id') for node in nodes
if node.get('data', {}).get('type') in [
NodeType.ITERATION.value,
NodeType.LOOP.value,
Expand All @@ -530,4 +531,3 @@ def _get_iteration_nested_relations(self, graph: dict) -> dict[str, list[str]]:
node.get('id') for node in nodes if node.get('data', {}).get('iteration_id') == iteration_id
] for iteration_id in iteration_ids
}

6 changes: 0 additions & 6 deletions api/core/app/segments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .segments import (
ArrayAnySegment,
ArraySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
Expand All @@ -13,11 +12,9 @@
from .types import SegmentType
from .variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
Expand All @@ -32,7 +29,6 @@
'FloatVariable',
'ObjectVariable',
'SecretVariable',
'FileVariable',
'StringVariable',
'ArrayAnyVariable',
'Variable',
Expand All @@ -45,11 +41,9 @@
'FloatSegment',
'ObjectSegment',
'ArrayAnySegment',
'FileSegment',
'StringSegment',
'ArrayStringVariable',
'ArrayNumberVariable',
'ArrayObjectVariable',
'ArrayFileVariable',
'ArraySegment',
]
12 changes: 0 additions & 12 deletions api/core/app/segments/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from typing import Any

from configs import dify_config
from core.file.file_obj import FileVar

from .exc import VariableError
from .segments import (
ArrayAnySegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
Expand All @@ -17,11 +15,9 @@
)
from .types import SegmentType
from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
Expand Down Expand Up @@ -49,8 +45,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f'invalid number value {value}')
case SegmentType.FILE:
result = FileVariable.model_validate(mapping)
case SegmentType.OBJECT if isinstance(value, dict):
result = ObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_STRING if isinstance(value, list):
Expand All @@ -59,10 +53,6 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
result = ArrayNumberVariable.model_validate(mapping)
case SegmentType.ARRAY_OBJECT if isinstance(value, list):
result = ArrayObjectVariable.model_validate(mapping)
case SegmentType.ARRAY_FILE if isinstance(value, list):
mapping = dict(mapping)
mapping['value'] = [{'value': v} for v in value]
result = ArrayFileVariable.model_validate(mapping)
case _:
raise VariableError(f'not supported value type {value_type}')
if result.size > dify_config.MAX_VARIABLE_SIZE:
Expand All @@ -83,6 +73,4 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, list):
return ArrayAnySegment(value=value)
if isinstance(value, FileVar):
return FileSegment(value=value)
raise ValueError(f'not supported value {value}')
13 changes: 0 additions & 13 deletions api/core/app/segments/segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from pydantic import BaseModel, ConfigDict, field_validator

from core.file.file_obj import FileVar

from .types import SegmentType


Expand Down Expand Up @@ -78,14 +76,7 @@ class IntegerSegment(Segment):
value: int


class FileSegment(Segment):
value_type: SegmentType = SegmentType.FILE
# TODO: embed FileVar in this model.
value: FileVar

@property
def markdown(self) -> str:
return self.value.to_markdown()


class ObjectSegment(Segment):
Expand Down Expand Up @@ -130,7 +121,3 @@ class ArrayObjectSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_OBJECT
value: Sequence[Mapping[str, Any]]


class ArrayFileSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_FILE
value: Sequence[FileSegment]
2 changes: 0 additions & 2 deletions api/core/app/segments/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ class SegmentType(str, Enum):
ARRAY_STRING = 'array[string]'
ARRAY_NUMBER = 'array[number]'
ARRAY_OBJECT = 'array[object]'
ARRAY_FILE = 'array[file]'
OBJECT = 'object'
FILE = 'file'

GROUP = 'group'
9 changes: 0 additions & 9 deletions api/core/app/segments/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@

from .segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
Expand Down Expand Up @@ -44,10 +42,6 @@ class IntegerVariable(IntegerSegment, Variable):
pass


class FileVariable(FileSegment, Variable):
pass


class ObjectVariable(ObjectSegment, Variable):
pass

Expand All @@ -68,9 +62,6 @@ class ArrayObjectVariable(ArrayObjectSegment, Variable):
pass


class ArrayFileVariable(ArrayFileSegment, Variable):
pass


class SecretVariable(StringVariable):
value_type: SegmentType = SegmentType.SECRET
Expand Down
4 changes: 2 additions & 2 deletions api/core/app/task_pipeline/workflow_cycle_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable
from models.account import Account
from models.model import EndUser
from models.workflow import Workflow
Expand All @@ -13,4 +13,4 @@ class WorkflowCycleStateManager:
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
_workflow_system_variables: dict[SystemVariable, Any]
_workflow_system_variables: dict[SystemVariable, Any]
29 changes: 4 additions & 25 deletions api/core/workflow/entities/node_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

from pydantic import BaseModel

from models.workflow import WorkflowNodeExecutionStatus
from models import WorkflowNodeExecutionStatus


class NodeType(Enum):
"""
Node Types.
"""

START = 'start'
END = 'end'
ANSWER = 'answer'
Expand Down Expand Up @@ -44,34 +45,11 @@ def value_of(cls, value: str) -> 'NodeType':
raise ValueError(f'invalid node type value {value}')


class SystemVariable(Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'

@classmethod
def value_of(cls, value: str) -> 'SystemVariable':
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')


class NodeRunMetadataKey(Enum):
"""
Node Run Metadata Key.
"""

TOTAL_TOKENS = 'total_tokens'
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
Expand All @@ -84,6 +62,7 @@ class NodeRunResult(BaseModel):
"""
Node Run Result.
"""

status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING

inputs: Optional[Mapping[str, Any]] = None # node inputs
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/entities/variable_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.enums import SystemVariable

VariableValue = Union[str, int, float, dict, list, FileVar]

Expand Down
25 changes: 25 additions & 0 deletions api/core/workflow/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from enum import Enum


class SystemVariable(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'

@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
11 changes: 6 additions & 5 deletions api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
Expand Down Expand Up @@ -201,8 +202,8 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
usage = LLMUsage.empty_usage()

return full_text, usage
def _transform_chat_messages(self,

def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
"""
Expand Down Expand Up @@ -249,13 +250,13 @@ def parse_dict(d: dict) -> str:
# check if it's a context structure
if 'metadata' in d and '_source' in d['metadata'] and 'content' in d:
return d['content']

# else, parse the dict
try:
return json.dumps(d, ensure_ascii=False)
except Exception:
return str(d)

if isinstance(value, str):
value = value
elif isinstance(value, list):
Expand Down
Loading

0 comments on commit daf3b8b

Please sign in to comment.