Skip to content

Commit

Permalink
Merge branch 'feat/workflow-backend' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Mar 14, 2024
2 parents fd95393 + e6b8b13 commit e8ccfe3
Show file tree
Hide file tree
Showing 10 changed files with 408 additions and 85 deletions.
277 changes: 266 additions & 11 deletions api/core/app/apps/advanced_chat/generate_task_pipeline.py

Large diffs are not rendered by default.

39 changes: 10 additions & 29 deletions api/core/app/apps/advanced_chat/workflow_event_trigger_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):

def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
self._queue_manager = queue_manager
self._streamable_node_ids = self._fetch_streamable_node_ids(workflow.graph_dict)

def on_workflow_run_started(self) -> None:
"""
Expand Down Expand Up @@ -114,34 +113,16 @@ def on_workflow_node_execute_failed(self, node_id: str,
PublishFrom.APPLICATION_MANAGER
)

def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
    """
    Publish a streamed text chunk to the app queue.

    :param node_id: id of the workflow node emitting the chunk
    :param text: chunk text
    :param metadata: optional extra metadata merged into the event metadata
    :return:
    """
    self._queue_manager.publish(
        QueueTextChunkEvent(
            text=text,
            metadata={
                "node_id": node_id,
                # metadata defaults to None; unpacking it directly with
                # **metadata would raise TypeError, so fall back to {}
                **(metadata or {})
            }
        ), PublishFrom.APPLICATION_MANAGER
    )
6 changes: 2 additions & 4 deletions api/core/app/apps/message_based_app_queue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueErrorEvent,
QueueMessage,
QueueMessageEndEvent,
QueueStopEvent,
QueueWorkflowFailedEvent,
QueueWorkflowSucceededEvent,
)


Expand Down Expand Up @@ -54,8 +53,7 @@ def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
if isinstance(event, QueueStopEvent
| QueueErrorEvent
| QueueMessageEndEvent
| QueueWorkflowSucceededEvent
| QueueWorkflowFailedEvent):
| QueueAdvancedChatMessageEndEvent):
self.stop_listen()

if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def on_workflow_node_execute_failed(self, node_id: str,
PublishFrom.APPLICATION_MANAGER
)

def on_node_text_chunk(self, node_id: str, text: str) -> None:
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
Expand Down
11 changes: 10 additions & 1 deletion api/core/app/entities/queue_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class QueueEvent(Enum):
AGENT_MESSAGE = "agent_message"
MESSAGE_REPLACE = "message_replace"
MESSAGE_END = "message_end"
ADVANCED_CHAT_MESSAGE_END = "advanced_chat_message_end"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_SUCCEEDED = "workflow_succeeded"
WORKFLOW_FAILED = "workflow_failed"
Expand Down Expand Up @@ -53,6 +54,7 @@ class QueueTextChunkEvent(AppQueueEvent):
"""
event = QueueEvent.TEXT_CHUNK
text: str
metadata: Optional[dict] = None


class QueueAgentMessageEvent(AppQueueEvent):
Expand Down Expand Up @@ -92,7 +94,14 @@ class QueueMessageEndEvent(AppQueueEvent):
QueueMessageEndEvent entity
"""
event = QueueEvent.MESSAGE_END
llm_result: LLMResult
llm_result: Optional[LLMResult] = None


class QueueAdvancedChatMessageEndEvent(AppQueueEvent):
    """
    Queue event signalling that an advanced-chat message has finished.

    Carries no payload beyond its event type; queue managers treat it as a
    terminal event (the message-based queue manager stops listening on it).
    """
    # fixed discriminator value for this event type
    event = QueueEvent.ADVANCED_CHAT_MESSAGE_END


class QueueWorkflowStartedEvent(AppQueueEvent):
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/callbacks/base_workflow_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def on_workflow_node_execute_failed(self, node_id: str,
raise NotImplementedError

@abstractmethod
def on_node_text_chunk(self, node_id: str, text: str) -> None:
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
"""
Publish text chunk
"""
Expand Down
119 changes: 84 additions & 35 deletions api/core/workflow/nodes/answer/answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import ValueType, VariablePool
from core.workflow.nodes.answer.entities import AnswerNodeData
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.nodes.base_node import BaseNode
from models.workflow import WorkflowNodeExecutionStatus

Expand All @@ -22,49 +27,29 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)

variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector,
target_value_type=ValueType.STRING
)

variable_values[variable_selector.variable] = value

variable_keys = list(variable_values.keys())

# format answer template
template_parser = PromptTemplateParser(node_data.answer)
template_variable_keys = template_parser.variable_keys

# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))

template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

split_template = [
{
"type": "var" if self._is_variable(part, variable_keys) else "text",
"value": part.replace('Ω', '') if self._is_variable(part, variable_keys) else part
}
for part in template.split('Ω') if part
]
# generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data)

answer = []
for part in split_template:
if part["type"] == "var":
value = variable_values.get(part["value"].replace('{{', '').replace('}}', ''))
for part in generate_routes:
if part.type == "var":
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = variable_pool.get_variable_value(
variable_selector=value_selector,
target_value_type=ValueType.STRING
)

answer_part = {
"type": "text",
"text": value
}
# TODO File
else:
part = cast(TextGenerateRouteChunk, part)
answer_part = {
"type": "text",
"text": part["value"]
"text": part.text
}

if len(answer) > 0 and answer[-1]["type"] == "text" and answer_part["type"] == "text":
Expand All @@ -75,6 +60,16 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
if len(answer) == 1 and answer[0]["type"] == "text":
answer = answer[0]["text"]

# re-fetch variable values
variable_values = {}
for variable_selector in node_data.variables:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector,
target_value_type=ValueType.STRING
)

variable_values[variable_selector.variable] = value

return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variable_values,
Expand All @@ -83,7 +78,61 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
}
)

def _is_variable(self, part, variable_keys):
@classmethod
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
    """
    Build the generate-route chunk list for an answer node from its raw config.

    :param config: raw node config dict; node data is read from its "data" key
    :return: ordered list of GenerateRouteChunk entries
    """
    raw_data = config.get("data", {})
    data = cast(cls._node_data_cls, cls._node_data_cls(**raw_data))
    return cls.extract_generate_route_from_node_data(data)

@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
    """
    Split an answer template into an ordered route of text and variable chunks.

    Only variables that are both declared on the node and referenced in the
    template become variable chunks; every other span stays literal text.

    :param node_data: answer node data object
    :return: ordered list of GenerateRouteChunk entries
    """
    selector_by_variable = {
        selector.variable: selector.value_selector
        for selector in node_data.variables
    }

    # Keep only the declared variables that actually appear in the template.
    referenced_keys = PromptTemplateParser(node_data.answer).variable_keys
    active_keys = list(set(selector_by_variable) & set(referenced_keys))

    # Surround each active {{var}} with Ω sentinels so a single split on Ω
    # separates variable placeholders from literal text.
    marked = node_data.answer
    for var in active_keys:
        marked = marked.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')

    routes: list[GenerateRouteChunk] = []
    for piece in marked.split('Ω'):
        if not piece:
            continue
        if cls._is_variable(piece, active_keys):
            var_key = piece.replace('Ω', '').replace('{{', '').replace('}}', '')
            routes.append(VarGenerateRouteChunk(
                value_selector=selector_by_variable[var_key]
            ))
        else:
            routes.append(TextGenerateRouteChunk(
                text=piece
            ))

    return routes

@classmethod
def _is_variable(cls, part, variable_keys):
    # A variable placeholder starts with '{{'; its bare name (braces
    # stripped) must be one of the declared/active variable keys.
    if not part.startswith('{{'):
        return False
    bare_name = part.replace('{{', '').replace('}}', '')
    return bare_name in variable_keys

Expand Down
26 changes: 26 additions & 0 deletions api/core/workflow/nodes/answer/entities.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

from pydantic import BaseModel

from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.variable_entities import VariableSelector

Expand All @@ -8,3 +11,26 @@ class AnswerNodeData(BaseNodeData):
"""
variables: list[VariableSelector] = []
answer: str


class GenerateRouteChunk(BaseModel):
    """
    Base entity for one chunk of an answer node's generate route.
    """
    # discriminator for the concrete chunk kind; subclasses pin it
    # to "var" or "text"
    type: str


class VarGenerateRouteChunk(GenerateRouteChunk):
    """
    Generate route chunk whose value is resolved from a variable.
    """
    type: str = "var"
    # selector path used to look the value up (e.g. [node_id, 'text'])
    value_selector: list[str]


class TextGenerateRouteChunk(GenerateRouteChunk):
    """
    Generate route chunk holding literal template text.
    """
    type: str = "text"
    # literal text emitted as-is
    text: str
9 changes: 7 additions & 2 deletions api/core/workflow/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,17 +86,22 @@ def run(self, variable_pool: VariablePool) -> NodeRunResult:
self.node_run_result = result
return result

def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
    """
    Publish a text chunk to every registered workflow callback.

    :param text: chunk text
    :param value_selector: selector of the variable this chunk streams into
        (e.g. [node_id, 'text']); may be None when not applicable
    :return:
    """
    if self.callbacks:
        for callback in self.callbacks:
            callback.on_node_text_chunk(
                node_id=self.node_id,
                text=text,
                metadata={
                    "node_type": self.node_type,
                    "value_selector": value_selector
                }
            )

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/llm/llm_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage
text = result.delta.message.content
full_text += text

self.publish_text_chunk(text=text)
self.publish_text_chunk(text=text, value_selector=[self.node_id, 'text'])

if not model:
model = result.model
Expand Down

0 comments on commit e8ccfe3

Please sign in to comment.