Skip to content

Commit

Permalink
feat: Add global variable support to workflows
Browse files Browse the repository at this point in the history
  • Loading branch information
laipz8200 committed Jun 14, 2024
1 parent 9ac6778 commit 8c6fb80
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 2 deletions.
16 changes: 16 additions & 0 deletions api/core/workflow/entities/variable_entities.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
from core.workflow.entities.variable_pool import ValueType


class VariableSelector(BaseModel):
"""
Variable Selector.
"""
variable: str
value_selector: list[str]


class GlobalVariable(BaseModel):
"""
Global Variable.
"""
name: str
value: str
value_type: "ValueType"
is_secret: bool
exportable: bool
56 changes: 56 additions & 0 deletions api/core/workflow/utils/variable_template_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,48 @@

class VariableTemplateParser:
"""
A class for parsing and manipulating template variables in a string.
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only be: #node_id.var1.var2#.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
Example usage:
template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}."
parser = VariableTemplateParser(template)
# Extract template variable keys
variable_keys = parser.extract()
print(variable_keys)
# Output: ['#node_id.query.name#', '#node_id.query.age#']
# Extract variable selectors
variable_selectors = parser.extract_variable_selectors()
print(variable_selectors)
# Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']),
# VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])]
# Format the template string
inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}}
formatted_string = parser.format(inputs)
print(formatted_string)
# Output: "Hello, John! Your age is 25."
"""

def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()

def extract(self) -> list:
"""
Extracts all the template variable keys from the template string.
Returns:
A list of template variable keys.
"""
# Regular expression to match the template rules
matches = re.findall(REGEX, self.template)

Expand All @@ -27,6 +57,12 @@ def extract(self) -> list:
return list(set(first_group_matches))

def extract_variable_selectors(self) -> list[VariableSelector]:
"""
Extracts the variable selectors from the template variable keys.
Returns:
A list of VariableSelector objects representing the variable selectors.
"""
variable_selectors = []
for variable_key in self.variable_keys:
remove_hash = variable_key.replace('#', '')
Expand All @@ -42,9 +78,20 @@ def extract_variable_selectors(self) -> list[VariableSelector]:
return variable_selectors

def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
"""
Formats the template string by replacing the template variables with their corresponding values.
Args:
inputs: A dictionary containing the values for the template variables.
remove_template_variables: A boolean indicating whether to remove the template variables from the output.
Returns:
The formatted string with template variables replaced by their values.
"""
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
print(key, value)
# convert the value to string
if isinstance(value, list | dict | bool | int | float):
value = str(value)
Expand All @@ -59,4 +106,13 @@ def replacer(match):

@classmethod
def remove_template_variables(cls, text: str):
"""
Removes the template variables from the given text.
Args:
text: The text from which to remove the template variables.
Returns:
The text with template variables removed.
"""
return re.sub(REGEX, r'{\1}', text)
36 changes: 36 additions & 0 deletions api/migrations/versions/5f6291109057_add_to_workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""add to workflows
Revision ID: 5f6291109057
Revises: 4e99a8df00ff
Create Date: 2024-06-14 06:30:41.326661
"""
import sqlalchemy as sa
from alembic import op

import models as models

# revision identifiers, used by Alembic.
revision = '5f6291109057'
down_revision = '4e99a8df00ff'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.add_column(sa.Column('global_variables', sa.Text(), nullable=True))
op.execute("UPDATE workflows SET global_variables = '{}' WHERE global_variables IS NULL")
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.alter_column('global_variables', nullable=False)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflows', schema=None) as batch_op:
batch_op.drop_column('global_variables')

# ### end Alembic commands ###
58 changes: 56 additions & 2 deletions api/models/workflow.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import json
from enum import Enum
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union

from core.workflow.entities.variable_entities import GlobalVariable
from extensions.ext_database import db
from libs import helper
from models import StringUUID
from models.account import Account

if TYPE_CHECKING:
from models.model import AppMode, Message


class CreatedByRole(Enum):
"""
Expand Down Expand Up @@ -108,6 +112,7 @@ class Workflow(db.Model):
version = db.Column(db.String(255), nullable=False)
graph = db.Column(db.Text)
features = db.Column(db.Text)
global_variables = db.Column(db.Text, nullable=False, default='{}')
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
updated_by = db.Column(StringUUID)
Expand Down Expand Up @@ -135,7 +140,7 @@ def user_input_form(self, to_old_structure: bool = False) -> list:
return []

graph_dict = self.graph_dict
if 'nodes' not in graph_dict:
if not graph_dict or 'nodes' not in graph_dict:
return []

start_node = next((node for node in graph_dict['nodes'] if node['data']['type'] == 'start'), None)
Expand Down Expand Up @@ -177,6 +182,55 @@ def tool_published(self) -> bool:
WorkflowToolProvider.app_id == self.app_id
).first() is not None

@property
def global_variables_dict(self) -> dict[str, GlobalVariable]:
"""
Converts the global_variables attribute from JSON to a dictionary of GlobalVariable objects.
Returns:
A dictionary containing the converted global variables, where the keys are the variable names
and the values are instances of the GlobalVariable class.
"""
dict_ = json.loads(self.global_variables)
return {
k: GlobalVariable.model_validate(v) for k, v in dict_.items()
}

def get_global_variable(self, name: str) -> Optional[GlobalVariable]:
"""
Get a global variable by name.
Args:
name: The name of the global variable.
Returns:
The global variable with the given name, or None if it does not exist.
"""
return self.global_variables_dict.get(name)

def update_global_variable(self, value: GlobalVariable):
"""
Create or update a global variable.
Args:
value: The new value of the global variable.
"""
name = value.name
global_variables = self.global_variables_dict
global_variables[name] = value
self.global_variables = json.dumps(global_variables)

def remove_global_variable(self, name: str):
"""
Remove a global variable by name.
Args:
name: The name of the global variable.
"""
global_variables = self.global_variables_dict
global_variables.pop(name, None)
self.global_variables = json.dumps(global_variables)

class WorkflowRunTriggeredFrom(Enum):
"""
Workflow Run Triggered From Enum
Expand Down

0 comments on commit 8c6fb80

Please sign in to comment.