diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index e8463e59d3b5d5..ca4b143027c512 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -2,7 +2,7 @@
 import logging
 import time
 from collections.abc import Generator
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from pydantic import BaseModel, Extra
 
@@ -13,6 +13,7 @@
     InvokeFrom,
 )
 from core.app.entities.queue_entities import (
+    QueueAdvancedChatMessageEndEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
     QueueMessageFileEvent,
@@ -34,6 +35,8 @@
 from core.moderation.output_moderation import ModerationRule, OutputModeration
 from core.tools.tool_file_manager import ToolFileManager
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
+from core.workflow.nodes.answer.answer_node import AnswerNode
+from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from models.account import Account
@@ -51,15 +54,26 @@
 logger = logging.getLogger(__name__)
 
 
+class StreamGenerateRoute(BaseModel):
+    """
+    StreamGenerateRoute entity
+    """
+    answer_node_id: str
+    generate_route: list[GenerateRouteChunk]
+    current_route_position: int = 0
+
+
 class TaskState(BaseModel):
     """
     TaskState entity
     """
+
     class NodeExecutionInfo(BaseModel):
         """
         NodeExecutionInfo entity
         """
         workflow_node_execution_id: str
+        node_type: NodeType
         start_at: float
 
         class Config:
@@ -77,9 +91,11 @@ class Config:
     total_tokens: int = 0
     total_steps: int = 0
 
-    running_node_execution_infos: dict[str, NodeExecutionInfo] = {}
+    ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
     latest_node_execution_info: Optional[NodeExecutionInfo] = None
 
+    current_stream_generate_state: Optional[StreamGenerateRoute] = None
+
     class Config:
         """Configuration for this pydantic object."""
 
@@ -122,6 +138,11 @@ def __init__(self, application_generate_entity: AdvancedChatAppGenerateEntity,
         self._output_moderation_handler = self._init_output_moderation()
         self._stream = stream
 
+        if stream:
+            self._stream_generate_routes = self._get_stream_generate_routes()
+        else:
+            self._stream_generate_routes = None
+
     def process(self) -> Union[dict, Generator]:
         """
         Process generate task pipeline.
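# --- Illustrative sketch (not part of the diff) ------------------------------
# Rough shape of the new streaming state added above: _stream_generate_routes
# maps the node where answer streaming may begin to a StreamGenerateRoute, and
# current_route_position is a cursor into that route while chunks are emitted.
# Toy pydantic stand-ins; the node ids and chunks below are made up.
from pydantic import BaseModel

class ToyRouteChunk(BaseModel):
    type: str
    text: str = ""

class ToyStreamGenerateRoute(BaseModel):
    answer_node_id: str
    generate_route: list[ToyRouteChunk]
    current_route_position: int = 0   # index of the next chunk to stream

stream_generate_routes = {
    "start-1": ToyStreamGenerateRoute(            # keyed by the "answer start at" node
        answer_node_id="answer-1",
        generate_route=[ToyRouteChunk(type="text", text="Sure, "),
                        ToyRouteChunk(type="var")],
    )
}
state = stream_generate_routes["start-1"]
state.current_route_position += 1                 # the text chunk has been streamed
print(state.current_route_position)               # -> 1
# ------------------------------------------------------------------------------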
@@ -290,6 +311,11 @@ def _process_stream_response(self) -> Generator:
                     yield self._yield_response(data)
                     break
 
+                self._queue_manager.publish(
+                    QueueAdvancedChatMessageEndEvent(),
+                    PublishFrom.TASK_PIPELINE
+                )
+
                 workflow_run_response = {
                     'event': 'workflow_finished',
                     'task_id': self._application_generate_entity.task_id,
@@ -309,7 +335,7 @@ def _process_stream_response(self) -> Generator:
                 }
 
                 yield self._yield_response(workflow_run_response)
-
+            elif isinstance(event, QueueAdvancedChatMessageEndEvent):
                 # response moderation
                 if self._output_moderation_handler:
                     self._output_moderation_handler.stop_thread()
@@ -390,6 +416,11 @@ def _process_stream_response(self) -> Generator:
                 yield self._yield_response(response)
 
             elif isinstance(event, QueueTextChunkEvent):
+                if not self._is_stream_out_support(
+                    event=event
+                ):
+                    continue
+
                 delta_text = event.text
                 if delta_text is None:
                     continue
@@ -467,20 +498,28 @@ def _on_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
 
         latest_node_execution_info = TaskState.NodeExecutionInfo(
             workflow_node_execution_id=workflow_node_execution.id,
+            node_type=event.node_type,
             start_at=time.perf_counter()
         )
 
-        self._task_state.running_node_execution_infos[event.node_id] = latest_node_execution_info
+        self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
         self._task_state.latest_node_execution_info = latest_node_execution_info
 
         self._task_state.total_steps += 1
 
         db.session.close()
 
+        # search stream_generate_routes if node id is answer start at node
+        if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_routes:
+            self._task_state.current_stream_generate_state = self._stream_generate_routes[event.node_id]
+
+        # stream outputs from start
+        self._generate_stream_outputs_when_node_start()
+
         return workflow_node_execution
 
     def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
-        current_node_execution = self._task_state.running_node_execution_infos[event.node_id]
+        current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
         workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
             WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
         if isinstance(event, QueueNodeSucceededEvent):
@@ -508,8 +547,8 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven
                 error=event.error
             )
 
-        # remove running node execution info
-        del self._task_state.running_node_execution_infos[event.node_id]
+        # stream outputs when node finished
+        self._generate_stream_outputs_when_node_finished()
 
         db.session.close()
 
@@ -517,7 +556,8 @@ def _on_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEven
 
     def _on_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
             -> WorkflowRun:
-        workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
+        workflow_run = (db.session.query(WorkflowRun)
+                        .filter(WorkflowRun.id == self._task_state.workflow_run_id).first())
         if isinstance(event, QueueStopEvent):
             workflow_run = self._workflow_run_failed(
                 workflow_run=workflow_run,
@@ -642,7 +682,7 @@ def _error_to_stream_response_data(self, e: Exception) -> dict:
             QuotaExceededError: {
                 'code': 'provider_quota_exceeded',
                 'message': "Your quota for Dify Hosted Model Provider has been exhausted. "
-                            "Please go to Settings -> Model Provider to complete your own provider credentials.",
+                           "Please go to Settings -> Model Provider to complete your own provider credentials.",
                 'status': 400
             },
             ModelCurrentlyNotSupportError: {'code': 'model_currently_not_support', 'status': 400},
@@ -660,10 +700,10 @@ def _error_to_stream_response_data(self, e: Exception) -> dict:
         else:
             logging.error(e)
             data = {
-                'code': 'internal_server_error', 
+                'code': 'internal_server_error',
                 'message': 'Internal Server Error, please contact support.',
                 'status': 500
-            } 
+            }
 
         return {
             'event': 'error',
@@ -730,3 +770,218 @@ def _init_output_moderation(self) -> Optional[OutputModeration]:
                 ),
                 queue_manager=self._queue_manager
             )
+
+    def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
+        """
+        Get stream generate routes.
+        :return:
+        """
+        # find all answer nodes
+        graph = self._workflow.graph_dict
+        answer_node_configs = [
+            node for node in graph['nodes']
+            if node.get('data', {}).get('type') == NodeType.ANSWER.value
+        ]
+
+        # parse stream output node value selectors of answer nodes
+        stream_generate_routes = {}
+        for node_config in answer_node_configs:
+            # get generate route for stream output
+            answer_node_id = node_config['id']
+            generate_route = AnswerNode.extract_generate_route_selectors(node_config)
+            start_node_id = self._get_answer_start_at_node_id(graph, answer_node_id)
+            if not start_node_id:
+                continue
+
+            stream_generate_routes[start_node_id] = StreamGenerateRoute(
+                answer_node_id=answer_node_id,
+                generate_route=generate_route
+            )
+
+        return stream_generate_routes
+
+    def _get_answer_start_at_node_id(self, graph: dict, target_node_id: str) \
+            -> Optional[str]:
+        """
+        Get answer start at node id.
+        :param graph: graph
+        :param target_node_id: target node ID
+        :return:
+        """
+        nodes = graph.get('nodes')
+        edges = graph.get('edges')
+
+        # fetch all ingoing edges from source node
+        ingoing_edge = None
+        for edge in edges:
+            if edge.get('target') == target_node_id:
+                ingoing_edge = edge
+                break
+
+        if not ingoing_edge:
+            return None
+
+        source_node_id = ingoing_edge.get('source')
+        source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
+        if not source_node:
+            return None
+
+        node_type = source_node.get('data', {}).get('type')
+        if node_type in [
+            NodeType.ANSWER.value,
+            NodeType.IF_ELSE.value,
+            NodeType.QUESTION_CLASSIFIER
+        ]:
+            start_node_id = target_node_id
+        elif node_type == NodeType.START.value:
+            start_node_id = source_node_id
+        else:
+            start_node_id = self._get_answer_start_at_node_id(graph, source_node_id)
+
+        return start_node_id
+
+    def _generate_stream_outputs_when_node_start(self) -> None:
+        """
+        Generate stream outputs.
+        :return:
+        """
+        if not self._task_state.current_stream_generate_state:
+            return
+
+        for route_chunk in self._task_state.current_stream_generate_state.generate_route:
+            if route_chunk.type == 'text':
+                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                for token in route_chunk.text:
+                    self._queue_manager.publish(
+                        QueueTextChunkEvent(
+                            text=token
+                        ), PublishFrom.TASK_PIPELINE
+                    )
+                    time.sleep(0.01)
+
+                self._task_state.current_stream_generate_state.current_route_position += 1
+            else:
+                break
+
+        # all route chunks are generated
+        if self._task_state.current_stream_generate_state.current_route_position == len(
+                self._task_state.current_stream_generate_state.generate_route):
+            self._task_state.current_stream_generate_state = None
+
+    def _generate_stream_outputs_when_node_finished(self) -> None:
+        """
+        Generate stream outputs.
+        :return:
+        """
+        if not self._task_state.current_stream_generate_state:
+            return
+
+        route_chunks = self._task_state.current_stream_generate_state.generate_route[
+            self._task_state.current_stream_generate_state.current_route_position:]
+
+        for route_chunk in route_chunks:
+            if route_chunk.type == 'text':
+                route_chunk = cast(TextGenerateRouteChunk, route_chunk)
+                for token in route_chunk.text:
+                    self._queue_manager.publish(
+                        QueueTextChunkEvent(
+                            text=token
+                        ), PublishFrom.TASK_PIPELINE
+                    )
+                    time.sleep(0.01)
+            else:
+                route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+                value_selector = route_chunk.value_selector
+                route_chunk_node_id = value_selector[0]
+
+                # check chunk node id is before current node id or equal to current node id
+                if route_chunk_node_id not in self._task_state.ran_node_execution_infos:
+                    break
+
+                latest_node_execution_info = self._task_state.latest_node_execution_info
+
+                # get route chunk node execution info
+                route_chunk_node_execution_info = self._task_state.ran_node_execution_infos[route_chunk_node_id]
+                if (route_chunk_node_execution_info.node_type == NodeType.LLM
+                        and latest_node_execution_info.node_type == NodeType.LLM):
+                    # only LLM support chunk stream output
+                    self._task_state.current_stream_generate_state.current_route_position += 1
+                    continue
+
+                # get route chunk node execution
+                route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
+                    WorkflowNodeExecution.id == route_chunk_node_execution_info.workflow_node_execution_id).first()
+
+                outputs = route_chunk_node_execution.outputs_dict
+
+                # get value from outputs
+                value = None
+                for key in value_selector[1:]:
+                    if not value:
+                        value = outputs.get(key)
+                    else:
+                        value = value.get(key)
+
+                if value:
+                    text = None
+                    if isinstance(value, str | int | float):
+                        text = str(value)
+                    elif isinstance(value, object):  # TODO FILE
+                        # convert file to markdown
+                        text = f'![]({value.get("url")})'
+                        pass
+
+                    if text:
+                        for token in text:
+                            self._queue_manager.publish(
+                                QueueTextChunkEvent(
+                                    text=token
+                                ), PublishFrom.TASK_PIPELINE
+                            )
+                            time.sleep(0.01)
+
+            self._task_state.current_stream_generate_state.current_route_position += 1
+
+        # all route chunks are generated
+        if self._task_state.current_stream_generate_state.current_route_position == len(
+                self._task_state.current_stream_generate_state.generate_route):
+            self._task_state.current_stream_generate_state = None
+
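# --- Illustrative sketch (not part of the diff) ------------------------------
# What the var-chunk branch of _generate_stream_outputs_when_node_finished
# boils down to: walk a value selector ["node_id", "key", ...] into the
# finished node's outputs and emit the resolved text one character at a time.
# The outputs dict and selector below are made up for illustration.
outputs_by_node = {"llm-1": {"text": "Hi there"}}
value_selector = ["llm-1", "text"]

value = outputs_by_node[value_selector[0]]
for key in value_selector[1:]:
    value = value.get(key) if value else None    # same nested lookup idea as the diff

emitted = []
if isinstance(value, (str, int, float)):
    for token in str(value):                     # char-by-char, like the 0.01s sleep loop above
        emitted.append(token)

print("".join(emitted))  # -> "Hi there"
# ------------------------------------------------------------------------------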
+    def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
+        """
+        Is stream out support
+        :param event: queue text chunk event
+        :return:
+        """
+        if not event.metadata:
+            return True
+
+        if 'node_id' not in event.metadata:
+            return True
+
+        node_type = event.metadata.get('node_type')
+        stream_output_value_selector = event.metadata.get('value_selector')
+        if not stream_output_value_selector:
+            return False
+
+        if not self._task_state.current_stream_generate_state:
+            return False
+
+        route_chunk = self._task_state.current_stream_generate_state.generate_route[
+            self._task_state.current_stream_generate_state.current_route_position]
+
+        if route_chunk.type != 'var':
+            return False
+
+        if node_type != NodeType.LLM:
+            # only LLM support chunk stream output
+            return False
+
+        route_chunk = cast(VarGenerateRouteChunk, route_chunk)
+        value_selector = route_chunk.value_selector
+
+        # check chunk node id is before current node id or equal to current node id
+        if value_selector != stream_output_value_selector:
+            return False
+
+        return True
diff --git a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
index b4a6a9602f6c51..972fda2d49a66c 100644
--- a/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
+++ b/api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
@@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
 
     def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
         self._queue_manager = queue_manager
-        self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)
 
     def on_workflow_run_started(self) -> None:
         """
@@ -114,34 +113,16 @@ def on_workflow_node_execute_failed(self, node_id: str,
             PublishFrom.APPLICATION_MANAGER
         )
 
-    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
         """
         Publish text chunk
         """
-        if node_id in self._streamable_node_ids:
-            self._queue_manager.publish(
-                QueueTextChunkEvent(
-                    text=text
-                ), PublishFrom.APPLICATION_MANAGER
-            )
-
-    def _fetch_streamable_node_ids(self, graph: dict) -> list[str]:
-        """
-        Fetch streamable node ids
-        When the Workflow type is chat, only the nodes before END Node are LLM or Direct Answer can be streamed output
-        When the Workflow type is workflow, only the nodes before END Node (only Plain Text mode) are LLM can be streamed output
-
-        :param graph: workflow graph
-        :return:
-        """
-        streamable_node_ids = []
-        end_node_ids = []
-        for node_config in graph.get('nodes'):
-            if node_config.get('data', {}).get('type') == NodeType.END.value:
-                end_node_ids.append(node_config.get('id'))
-
-        for edge_config in graph.get('edges'):
-            if edge_config.get('target') in end_node_ids:
-                streamable_node_ids.append(edge_config.get('source'))
-
-        return streamable_node_ids
+        self._queue_manager.publish(
+            QueueTextChunkEvent(
+                text=text,
+                metadata={
+                    "node_id": node_id,
+                    **metadata
+                }
+            ), PublishFrom.APPLICATION_MANAGER
+        )
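# --- Illustrative sketch (not part of the diff) ------------------------------
# Rough shape of what the reworked on_node_text_chunk above now publishes:
# every chunk carries a metadata dict with the emitting node id merged in.
# QueueTextChunkEvent is simplified to a plain dict here; note the merge in the
# diff assumes `metadata` is a dict even though its default value is None.
def build_text_chunk_event(node_id: str, text: str, metadata: dict) -> dict:
    return {
        "event": "text_chunk",
        "text": text,
        "metadata": {"node_id": node_id, **metadata},
    }

event = build_text_chunk_event(
    node_id="llm-1",
    text="Hello",
    metadata={"node_type": "llm", "value_selector": ["llm-1", "text"]},
)
print(event["metadata"]["node_id"])  # -> "llm-1"
# ------------------------------------------------------------------------------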
diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py
index 6d0a71f495e328..f4ff44dddac9ef 100644
--- a/api/core/app/apps/message_based_app_queue_manager.py
+++ b/api/core/app/apps/message_based_app_queue_manager.py
@@ -3,12 +3,11 @@
 from core.app.entities.queue_entities import (
     AppQueueEvent,
     MessageQueueMessage,
+    QueueAdvancedChatMessageEndEvent,
     QueueErrorEvent,
     QueueMessage,
     QueueMessageEndEvent,
     QueueStopEvent,
-    QueueWorkflowFailedEvent,
-    QueueWorkflowSucceededEvent,
 )
 
 
@@ -54,8 +53,7 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
         if isinstance(event, QueueStopEvent
                 | QueueErrorEvent
                 | QueueMessageEndEvent
-                | QueueWorkflowSucceededEvent
-                | QueueWorkflowFailedEvent):
+                | QueueAdvancedChatMessageEndEvent):
             self.stop_listen()
 
         if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
diff --git a/api/core/app/apps/workflow/workflow_event_trigger_callback.py b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
index 59ef44cd2e4ade..e5a8e8d3747c42 100644
--- a/api/core/app/apps/workflow/workflow_event_trigger_callback.py
+++ b/api/core/app/apps/workflow/workflow_event_trigger_callback.py
@@ -112,7 +112,7 @@ def on_workflow_node_execute_failed(self, node_id: str,
             PublishFrom.APPLICATION_MANAGER
         )
 
-    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
         """
         Publish text chunk
         """
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 153607e1b4473e..5c31996fd345a6 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -17,6 +17,7 @@ class QueueEvent(Enum):
     AGENT_MESSAGE = "agent_message"
     MESSAGE_REPLACE = "message_replace"
     MESSAGE_END = "message_end"
+    ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end"
    WORKFLOW_STARTED = "workflow_started"
     WORKFLOW_SUCCEEDED = "workflow_succeeded"
     WORKFLOW_FAILED = "workflow_failed"
@@ -53,6 +54,7 @@ class QueueTextChunkEvent(AppQueueEvent):
     """
     event = QueueEvent.TEXT_CHUNK
     text: str
+    metadata: Optional[dict] = None
 
 
 class QueueAgentMessageEvent(AppQueueEvent):
@@ -92,7 +94,14 @@ class QueueMessageEndEvent(AppQueueEvent):
     QueueMessageEndEvent entity
     """
     event = QueueEvent.MESSAGE_END
-    llm_result: LLMResult
+    llm_result: Optional[LLMResult] = None
+
+
+class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
+    """
+    QueueAdvancedChatMessageEndEvent entity
+    """
+    event = QueueEvent.ADVANCED_CHAT_MESSAGE_END
 
 
 class QueueWorkflowStartedEvent(AppQueueEvent):
diff --git a/api/core/workflow/callbacks/base_workflow_callback.py b/api/core/workflow/callbacks/base_workflow_callback.py
index 9594fa20372064..1f5472b430c96a 100644
--- a/api/core/workflow/callbacks/base_workflow_callback.py
+++ b/api/core/workflow/callbacks/base_workflow_callback.py
@@ -64,7 +64,7 @@ def on_workflow_node_execute_failed(self, node_id: str,
         raise NotImplementedError
 
     @abstractmethod
-    def on_node_text_chunk(self, node_id: str, text: str) -> None:
+    def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
         """
         Publish text chunk
         """
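# --- Illustrative sketch (not part of the diff) ------------------------------
# The queue manager change above swaps the workflow-succeeded/failed events for
# QueueAdvancedChatMessageEndEvent in its stop condition. A simplified stand-in
# of that check, using plain placeholder classes instead of the real entities:
class QueueStopEvent: ...
class QueueErrorEvent: ...
class QueueMessageEndEvent: ...
class QueueAdvancedChatMessageEndEvent: ...
class QueueTextChunkEvent: ...

TERMINAL_EVENTS = (
    QueueStopEvent,
    QueueErrorEvent,
    QueueMessageEndEvent,
    QueueAdvancedChatMessageEndEvent,   # new: advanced chat message end ends the listen loop
)

def should_stop_listening(event) -> bool:
    return isinstance(event, TERMINAL_EVENTS)

print(should_stop_listening(QueueAdvancedChatMessageEndEvent()))  # -> True
print(should_stop_listening(QueueTextChunkEvent()))               # -> False
# ------------------------------------------------------------------------------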
diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py
index 97ddafad019470..d8ff5cb6f630d1 100644
--- a/api/core/workflow/nodes/answer/answer_node.py
+++ b/api/core/workflow/nodes/answer/answer_node.py
@@ -4,7 +4,12 @@
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.entities.variable_pool import ValueType, VariablePool
-from core.workflow.nodes.answer.entities import AnswerNodeData
+from core.workflow.nodes.answer.entities import (
+    AnswerNodeData,
+    GenerateRouteChunk,
+    TextGenerateRouteChunk,
+    VarGenerateRouteChunk,
+)
 from core.workflow.nodes.base_node import BaseNode
 from models.workflow import WorkflowNodeExecutionStatus
 
@@ -22,49 +27,29 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         node_data = self.node_data
         node_data = cast(self._node_data_cls, node_data)
 
-        variable_values = {}
-        for variable_selector in node_data.variables:
-            value = variable_pool.get_variable_value(
-                variable_selector=variable_selector.value_selector,
-                target_value_type=ValueType.STRING
-            )
-
-            variable_values[variable_selector.variable] = value
-
-        variable_keys = list(variable_values.keys())
-
-        # format answer template
-        template_parser = PromptTemplateParser(node_data.answer)
-        template_variable_keys = template_parser.variable_keys
-
-        # Take the intersection of variable_keys and template_variable_keys
-        variable_keys = list(set(variable_keys) & set(template_variable_keys))
-
-        template = node_data.answer
-        for var in variable_keys:
-            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
-
-        split_template = [
-            {
-                "type": "var" if self._is_variable(part, variable_keys) else "text",
-                "value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
-            }
-            for part in template.split('Ω') if part
-        ]
+        # generate routes
+        generate_routes = self.extract_generate_route_from_node_data(node_data)
 
         answer = []
-        for part in split_template:
-            if part["type"] == "var":
-                value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
+        for part in generate_routes:
+            if part.type == "var":
+                part = cast(VarGenerateRouteChunk, part)
+                value_selector = part.value_selector
+                value = variable_pool.get_variable_value(
+                    variable_selector=value_selector,
+                    target_value_type=ValueType.STRING
+                )
+
                 answer_part = {
                     "type": "text",
                     "text": value
                 }
                 # TODO File
             else:
+                part = cast(TextGenerateRouteChunk, part)
                 answer_part = {
                     "type": "text",
-                    "text": part["value"]
+                    "text": part.text
                 }
 
             if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
@@ -75,6 +60,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         if len(answer) == 1 and answer[0]["type"] == "text":
             answer = answer[0]["text"]
 
+        # re-fetch variable values
+        variable_values = {}
+        for variable_selector in node_data.variables:
+            value = variable_pool.get_variable_value(
+                variable_selector=variable_selector.value_selector,
+                target_value_type=ValueType.STRING
+            )
+
+            variable_values[variable_selector.variable] = value
+
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=variable_values,
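# --- Illustrative sketch (not part of the diff) ------------------------------
# The reworked _run above assembles the answer from route chunks and merges
# adjacent text parts. A simplified version of that assembly step; the chunk
# list is made up for illustration (a var chunk already resolved to a string).
parts = [
    {"type": "text", "text": "Sure, "},
    {"type": "text", "text": "the capital is Paris"},   # resolved var chunk
    {"type": "text", "text": "."},
]

answer = []
for part in parts:
    if answer and answer[-1]["type"] == "text" and part["type"] == "text":
        answer[-1]["text"] += part["text"]        # merge with the previous text part
    else:
        answer.append(part)

if len(answer) == 1 and answer[0]["type"] == "text":
    answer = answer[0]["text"]                    # collapse to a plain string

print(answer)  # -> "Sure, the capital is Paris."
# ------------------------------------------------------------------------------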
@@ -83,7 +78,61 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
             }
         )
 
-    def _is_variable(self, part, variable_keys):
+    @classmethod
+    def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route selectors
+        :param config: node config
+        :return:
+        """
+        node_data = cls._node_data_cls(**config.get("data", {}))
+        node_data = cast(cls._node_data_cls, node_data)
+
+        return cls.extract_generate_route_from_node_data(node_data)
+
+    @classmethod
+    def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
+        """
+        Extract generate route from node data
+        :param node_data: node data object
+        :return:
+        """
+        value_selector_mapping = {
+            variable_selector.variable: variable_selector.value_selector
+            for variable_selector in node_data.variables
+        }
+
+        variable_keys = list(value_selector_mapping.keys())
+
+        # format answer template
+        template_parser = PromptTemplateParser(node_data.answer)
+        template_variable_keys = template_parser.variable_keys
+
+        # Take the intersection of variable_keys and template_variable_keys
+        variable_keys = list(set(variable_keys) & set(template_variable_keys))
+
+        template = node_data.answer
+        for var in variable_keys:
+            template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
+
+        generate_routes = []
+        for part in template.split('Ω'):
+            if part:
+                if cls._is_variable(part, variable_keys):
+                    var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
+                    value_selector = value_selector_mapping[var_key]
+                    generate_routes.append(VarGenerateRouteChunk(
+                        value_selector=value_selector
+                    ))
+                else:
+                    generate_routes.append(TextGenerateRouteChunk(
+                        text=part
+                    ))
+
+        return generate_routes
+
+    @classmethod
+    def _is_variable(cls, part, variable_keys):
         cleaned_part = part.replace('{{', '').replace('}}', '')
         return part.startswith('{{') and cleaned_part in variable_keys
 
diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py
index 7c6fed3e4ea6f4..8aed752ccb55e6 100644
--- a/api/core/workflow/nodes/answer/entities.py
+++ b/api/core/workflow/nodes/answer/entities.py
@@ -1,3 +1,6 @@
+
+from pydantic import BaseModel
+
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.variable_entities import VariableSelector
 
@@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData):
     """
     variables: list[VariableSelector] = []
     answer: str
+
+
+class GenerateRouteChunk(BaseModel):
+    """
+    Generate Route Chunk.
+    """
+    type: str
+
+
+class VarGenerateRouteChunk(GenerateRouteChunk):
+    """
+    Var Generate Route Chunk.
+    """
+    type: str = "var"
+    value_selector: list[str]
+
+
+class TextGenerateRouteChunk(GenerateRouteChunk):
+    """
+    Text Generate Route Chunk.
+    """
+    type: str = "text"
+    text: str
diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py
index 2da19bc409d379..7cc9c6ee3dba81 100644
--- a/api/core/workflow/nodes/base_node.py
+++ b/api/core/workflow/nodes/base_node.py
@@ -86,17 +86,22 @@ def run(self, variable_pool: VariablePool) -> NodeRunResult:
         self.node_run_result = result
         return result
 
-    def publish_text_chunk(self, text: str) -> None:
+    def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
         """
         Publish text chunk
         :param text: chunk text
+        :param value_selector: value selector
         :return:
         """
         if self.callbacks:
             for callback in self.callbacks:
                 callback.on_node_text_chunk(
                     node_id=self.node_id,
-                    text=text
+                    text=text,
+                    metadata={
+                        "node_type": self.node_type,
+                        "value_selector": value_selector
+                    }
                 )
 
     @classmethod
diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index 9285bbe74e87a1..cb5a33309141dd 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -169,7 +169,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
             text = result.delta.message.content
             full_text += text
 
-            self.publish_text_chunk(text=text)
+            self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])
 
             if not model:
                 model = result.model
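# --- Illustrative sketch (not part of the diff) ------------------------------
# A dependency-free rendition of the template split performed by
# extract_generate_route_from_node_data above: variables are wrapped with an
# "Ω" sentinel, the template is split on it, and each piece becomes either a
# text chunk or a var chunk. The template and selector mapping are examples.
template = "Sure, {{answer}} Anything else?"
selectors = {"answer": ["llm-1", "text"]}

marked = template
for var in selectors:
    marked = marked.replace("{{" + var + "}}", "Ω{{" + var + "}}Ω")

chunks = []
for part in marked.split("Ω"):
    if not part:
        continue
    key = part.replace("{{", "").replace("}}", "")
    if part.startswith("{{") and key in selectors:
        chunks.append({"type": "var", "value_selector": selectors[key]})
    else:
        chunks.append({"type": "text", "text": part})

print(chunks)
# -> [{'type': 'text', 'text': 'Sure, '},
#     {'type': 'var', 'value_selector': ['llm-1', 'text']},
#     {'type': 'text', 'text': ' Anything else?'}]
# ------------------------------------------------------------------------------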