Skip to content

Commit

Permalink
fix(workflow): handle special values for process data consistently
Browse files Browse the repository at this point in the history
- Apply `handle_special_values` to `process_data` in workflow cycle management.
- Improve template processing in `AdvancedPromptTransform` with `VariablePool`.
- Make `system_variables` and `user_inputs` optional in `VariablePool` initialization.
  • Loading branch information
laipz8200 committed Oct 14, 2024
1 parent e442d8e commit f61bf84
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
6 changes: 4 additions & 2 deletions api/core/app/task_pipeline/workflow_cycle_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)

inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata = (
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
Expand All @@ -278,7 +279,7 @@ def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.execution_metadata: execution_metadata,
WorkflowNodeExecution.finished_at: finished_at,
Expand Down Expand Up @@ -311,6 +312,7 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) ->
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)

inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds()
Expand All @@ -320,7 +322,7 @@ def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) ->
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
Expand Down
10 changes: 6 additions & 4 deletions api/core/prompt/advanced_prompt_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.variable_pool import VariablePool


class AdvancedPromptTransform(PromptTransform):
Expand Down Expand Up @@ -144,10 +145,11 @@ def _get_chat_model_prompt_messages(
raw_prompt = prompt_item.text

if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
prompt_inputs = self._set_context_variable(context=context, parser=parser, prompt_inputs=prompt_inputs)
prompt = parser.format(prompt_inputs)
vp = VariablePool()
for k, v in inputs.items():
vp.add(k[1:-1].split("."), v)
raw_prompt.replace("{{#context#}}", context or "")
prompt = vp.convert_template(raw_prompt).text
elif prompt_item.edition_type == "jinja2":
prompt = raw_prompt
prompt_inputs = inputs
Expand Down
6 changes: 4 additions & 2 deletions api/core/workflow/entities/variable_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,16 @@ class VariablePool(BaseModel):
def __init__(
self,
*,
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
system_variables: Mapping[SystemVariableKey, Any] | None = None,
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}

super().__init__(
system_variables=system_variables,
Expand Down

0 comments on commit f61bf84

Please sign in to comment.