diff --git a/api/core/workflow/entities/variable_entities.py b/api/core/workflow/entities/variable_entities.py index 19d9af2a6171a4..169904ba2924c5 100644 --- a/api/core/workflow/entities/variable_entities.py +++ b/api/core/workflow/entities/variable_entities.py @@ -1,5 +1,10 @@ +from typing import TYPE_CHECKING + from pydantic import BaseModel +if TYPE_CHECKING: + from core.workflow.entities.variable_pool import ValueType + class VariableSelector(BaseModel): """ @@ -7,3 +12,14 @@ class VariableSelector(BaseModel): """ variable: str value_selector: list[str] + + +class GlobalVariable(BaseModel): + """ + Global Variable. + """ + name: str + value: str + value_type: "ValueType" + is_secret: bool + exportable: bool diff --git a/api/core/workflow/utils/variable_template_parser.py b/api/core/workflow/utils/variable_template_parser.py index 9c1dc6086aa99b..018203d4034616 100644 --- a/api/core/workflow/utils/variable_template_parser.py +++ b/api/core/workflow/utils/variable_template_parser.py @@ -7,11 +7,35 @@ class VariableTemplateParser: """ + A class for parsing and manipulating template variables in a string. + Rules: 1. Template variables must be enclosed in `{{}}`. 2. The template variable Key can only be: #node_id.var1.var2#. 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. + + Example usage: + + template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}." + parser = VariableTemplateParser(template) + + # Extract template variable keys + variable_keys = parser.extract() + print(variable_keys) + # Output: ['#node_id.query.name#', '#node_id.query.age#'] + + # Extract variable selectors + variable_selectors = parser.extract_variable_selectors() + print(variable_selectors) + # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']), + # VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])] + + # Format the template string + inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}} + formatted_string = parser.format(inputs) + print(formatted_string) + # Output: "Hello, John! Your age is 25." """ def __init__(self, template: str): @@ -19,6 +43,12 @@ def __init__(self, template: str): self.variable_keys = self.extract() def extract(self) -> list: + """ + Extracts all the template variable keys from the template string. + + Returns: + A list of template variable keys. + """ # Regular expression to match the template rules matches = re.findall(REGEX, self.template) @@ -27,6 +57,12 @@ def extract(self) -> list: return list(set(first_group_matches)) def extract_variable_selectors(self) -> list[VariableSelector]: + """ + Extracts the variable selectors from the template variable keys. + + Returns: + A list of VariableSelector objects representing the variable selectors. + """ variable_selectors = [] for variable_key in self.variable_keys: remove_hash = variable_key.replace('#', '') @@ -42,9 +78,20 @@ def extract_variable_selectors(self) -> list[VariableSelector]: return variable_selectors def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + """ + Formats the template string by replacing the template variables with their corresponding values. + + Args: + inputs: A dictionary containing the values for the template variables. + remove_template_variables: A boolean indicating whether to remove the template variables from the output. + + Returns: + The formatted string with template variables replaced by their values. + """ def replacer(match): key = match.group(1) value = inputs.get(key, match.group(0)) # return original matched string if key not found + print(key, value) # convert the value to string if isinstance(value, list | dict | bool | int | float): value = str(value) @@ -59,4 +106,13 @@ def replacer(match): @classmethod def remove_template_variables(cls, text: str): + """ + Removes the template variables from the given text. + + Args: + text: The text from which to remove the template variables. + + Returns: + The text with template variables removed. + """ return re.sub(REGEX, r'{\1}', text) diff --git a/api/migrations/versions/5f6291109057_add_to_workflows.py b/api/migrations/versions/5f6291109057_add_to_workflows.py new file mode 100644 index 00000000000000..997284d12e23ab --- /dev/null +++ b/api/migrations/versions/5f6291109057_add_to_workflows.py @@ -0,0 +1,36 @@ +"""add to workflows + +Revision ID: 5f6291109057 +Revises: 4e99a8df00ff +Create Date: 2024-06-14 06:30:41.326661 + +""" +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = '5f6291109057' +down_revision = '4e99a8df00ff' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.add_column(sa.Column('global_variables', sa.Text(), nullable=True)) + op.execute("UPDATE workflows SET global_variables = '{}' WHERE global_variables IS NULL") + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.alter_column('global_variables', nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflows', schema=None) as batch_op: + batch_op.drop_column('global_variables') + + # ### end Alembic commands ### diff --git a/api/models/workflow.py b/api/models/workflow.py index d9bc7848787925..a1a2b629e2f15f 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,12 +1,16 @@ import json from enum import Enum -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union +from core.workflow.entities.variable_entities import GlobalVariable from extensions.ext_database import db from libs import helper from models import StringUUID from models.account import Account +if TYPE_CHECKING: + from models.model import AppMode, Message + class CreatedByRole(Enum): """ @@ -108,6 +112,7 @@ class Workflow(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) features = db.Column(db.Text) + global_variables = db.Column(db.Text, nullable=False, default='{}') created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_by = db.Column(StringUUID) @@ -135,7 +140,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list: return [] graph_dict = self.graph_dict - if 'nodes' not in graph_dict: + if not graph_dict or 'nodes' not in graph_dict: return [] start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None) @@ -177,6 +182,55 @@ def tool_published(self) -> bool: WorkflowToolProvider.app_id == self.app_id ).first() is not None + @property + def global_variables_dict(self) -> dict[str, GlobalVariable]: + """ + Converts the global_variables attribute from JSON to a dictionary of GlobalVariable objects. + + Returns: + A dictionary containing the converted global variables, where the keys are the variable names + and the values are instances of the GlobalVariable class. + """ + dict_ = json.loads(self.global_variables) + return { + k: GlobalVariable.model_validate(v) for k, v in dict_.items() + } + + def get_global_variable(self, name: str) -> Optional[GlobalVariable]: + """ + Get a global variable by name. + + Args: + name: The name of the global variable. + + Returns: + The global variable with the given name, or None if it does not exist. + """ + return self.global_variables_dict.get(name) + + def update_global_variable(self, value: GlobalVariable): + """ + Create or update a global variable. + + Args: + value: The new value of the global variable. + """ + name = value.name + global_variables = self.global_variables_dict + global_variables[name] = value + self.global_variables = json.dumps(global_variables) + + def remove_global_variable(self, name: str): + """ + Remove a global variable by name. + + Args: + name: The name of the global variable. + """ + global_variables = self.global_variables_dict + global_variables.pop(name, None) + self.global_variables = json.dumps(global_variables) + class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum