From 98d85e6b747d87d5f9cdceb140caf425a1f19a6d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 25 Nov 2024 18:16:55 +0800 Subject: [PATCH] fix: WorkflowNodeExecution.created_at may be earlier than WorkflowRun.created_at (#11070) --- .../task_pipeline/workflow_cycle_manage.py | 49 ++++++++++--------- api/models/workflow.py | 10 ++-- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 9229cbcc0a7c8e..d45726af466538 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -3,6 +3,7 @@ from collections.abc import Mapping, Sequence from datetime import UTC, datetime from typing import Any, Optional, Union, cast +from uuid import uuid4 from sqlalchemy.orm import Session @@ -80,38 +81,38 @@ def _handle_workflow_run_start(self) -> WorkflowRun: inputs[f"sys.{key.value}"] = value - inputs = WorkflowEntry.handle_special_values(inputs) - triggered_from = ( WorkflowRunTriggeredFrom.DEBUGGING if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN ) + # handle special values + inputs = WorkflowEntry.handle_special_values(inputs) + # init workflow run - workflow_run = WorkflowRun() - workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - if workflow_run_id: - workflow_run.id = workflow_run_id - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING.value - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value - ) - workflow_run.created_by = self._user.id + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = WorkflowRun() + system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] + workflow_run.id = system_id or str(uuid4()) + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER + ) + workflow_run.created_by = self._user.id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - db.session.add(workflow_run) - db.session.commit() - db.session.refresh(workflow_run) - db.session.close() + session.add(workflow_run) + session.commit() return workflow_run diff --git a/api/models/workflow.py b/api/models/workflow.py index 5b0617828d0f7f..fd53f137f906bf 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping, Sequence from datetime import UTC, datetime -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Optional, Union import sqlalchemy as sa @@ -314,7 +314,7 @@ def conversation_variables(self, value: Sequence[Variable]) -> None: ) -class WorkflowRunStatus(Enum): +class WorkflowRunStatus(StrEnum): """ Workflow Run Status Enum """ @@ -393,13 +393,13 @@ class WorkflowRun(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - outputs: Mapped[str] = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped + outputs: Mapped[str] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) + created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime)