Skip to content

Commit

Permalink
Merge branch 'fix/workflow-app-run' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
zxhlyh committed May 31, 2024
2 parents 6b1af38 + bc26667 commit 4a301a2
Show file tree
Hide file tree
Showing 28 changed files with 1,740 additions and 1,148 deletions.
5 changes: 3 additions & 2 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,11 @@ RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
SMTP_SERVER=smtp.gmail.com
SMTP_PORT=587
SMTP_PORT=465
SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=false
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false

# Sentry configuration
SENTRY_DSN=
Expand Down
1 change: 1 addition & 0 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def __init__(self):
self.SMTP_USERNAME = get_env('SMTP_USERNAME')
self.SMTP_PASSWORD = get_env('SMTP_PASSWORD')
self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS')
self.SMTP_OPPORTUNISTIC_TLS = get_bool_env('SMTP_OPPORTUNISTIC_TLS')

# ------------------------
# Workspace Configurations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,49 +13,76 @@
from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.tool.builtin_tool import BuiltinTool

# All commented out parameters default to null
DRAW_TEXT_OPTIONS = {
# Prompts
"prompt": "",
"negative_prompt": "",
# "styles": [],
# Seeds
"seed": -1,
"subseed": -1,
"subseed_strength": 0,
"seed_resize_from_h": -1,
'sampler_index': 'DPM++ SDE Karras',
"seed_resize_from_w": -1,

# Samplers
# "sampler_name": "DPM++ 2M",
# "scheduler": "",
# "sampler_index": "Automatic",

# Latent Space Options
"batch_size": 1,
"n_iter": 1,
"steps": 10,
"cfg_scale": 7,
"width": 1024,
"height": 1024,
"restore_faces": False,
"width": 512,
"height": 512,
# "restore_faces": True,
# "tiling": True,
"do_not_save_samples": False,
"do_not_save_grid": False,
"eta": 0,
"denoising_strength": 0,
"s_min_uncond": 0,
"s_churn": 0,
"s_tmax": 0,
"s_tmin": 0,
"s_noise": 0,
# "eta": 0,
# "denoising_strength": 0.75,
# "s_min_uncond": 0,
# "s_churn": 0,
# "s_tmax": 0,
# "s_tmin": 0,
# "s_noise": 0,
"override_settings": {},
"override_settings_restore_afterwards": True,
# Refinement Options
"refiner_checkpoint": "",
"refiner_switch_at": 0,
"disable_extra_networks": False,
"comments": {},
# "firstpass_image": "",
# "comments": "",
# High-Resolution Options
"enable_hr": False,
"firstphase_width": 0,
"firstphase_height": 0,
"hr_scale": 2,
# "hr_upscaler": "",
"hr_second_pass_steps": 0,
"hr_resize_x": 0,
"hr_resize_y": 0,
# "hr_checkpoint_name": "",
# "hr_sampler_name": "",
# "hr_scheduler": "",
"hr_prompt": "",
"hr_negative_prompt": "",
# Task Options
# "force_task_id": "",

# Script Options
# "script_name": "",
"script_args": [],
# Output Options
"send_images": True,
"save_images": False,
"alwayson_scripts": {}
"alwayson_scripts": {},
# "infotext": "",

}


Expand Down Expand Up @@ -88,60 +115,15 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
except Exception as e:
raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model')


# prompt
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')

# get negative prompt
negative_prompt = tool_parameters.get('negative_prompt', '')

# get size
width = tool_parameters.get('width', 1024)
height = tool_parameters.get('height', 1024)

# get steps
steps = tool_parameters.get('steps', 1)

# get lora
lora = tool_parameters.get('lora', '')

# get image id
# get image id and image variable
image_id = tool_parameters.get('image_id', '')
if image_id.strip():
image_variable = self.get_default_image_variable()
if image_variable:
image_binary = self.get_variable_file(image_variable.name)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')

# convert image to RGB
image = Image.open(io.BytesIO(image_binary))
image = image.convert("RGB")
buffer = io.BytesIO()
image.save(buffer, format="PNG")
image_binary = buffer.getvalue()
image.close()
image_variable = self.get_default_image_variable()
# Return text2img if there's no image ID or no image variable
if not image_id or not image_variable:
return self.text2img(base_url=base_url,tool_parameters=tool_parameters)

return self.img2img(base_url=base_url,
lora=lora,
image_binary=image_binary,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
model=model)

return self.text2img(base_url=base_url,
lora=lora,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps,
model=model)
# Proceed with image-to-image generation
return self.img2img(base_url=base_url,tool_parameters=tool_parameters)

def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Expand Down Expand Up @@ -197,35 +179,67 @@ def get_sd_models(self) -> list[str]:
except Exception as e:
return []

def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str,
width: int, height: int, steps: int, model: str) \
def img2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
generate image
"""
draw_options = {

# Fetch the binary data of the image
image_variable = self.get_default_image_variable()
image_binary = self.get_variable_file(image_variable.name)
if not image_binary:
return self.create_text_message('Image not found, please request user to generate image firstly.')

# Convert image to RGB and save as PNG
try:
with Image.open(io.BytesIO(image_binary)) as image:
with io.BytesIO() as buffer:
image.convert("RGB").save(buffer, format="PNG")
image_binary = buffer.getvalue()
except Exception as e:
return self.create_text_message(f"Failed to process the image: {str(e)}")

# copy draw options
draw_options = deepcopy(DRAW_TEXT_OPTIONS)
# set image options
model = tool_parameters.get('model', '')
draw_options_image = {
"init_images": [b64encode(image_binary).decode('utf-8')],
"prompt": "",
"negative_prompt": negative_prompt,
"denoising_strength": 0.9,
"width": width,
"height": height,
"cfg_scale": 7,
"sampler_name": "Euler a",
"restore_faces": False,
"steps": steps,
"script_args": ["outpainting mk2"],
"override_settings": {"sd_model_checkpoint": model}
"script_args": [],
"override_settings": {"sd_model_checkpoint": model},
"resize_mode":0,
"image_cfg_scale": 0,
# "mask": None,
"mask_blur_x": 4,
"mask_blur_y": 4,
"mask_blur": 0,
"mask_round": True,
"inpainting_fill": 0,
"inpaint_full_res": True,
"inpaint_full_res_padding": 0,
"inpainting_mask_invert": 0,
"initial_noise_multiplier": 0,
# "latent_mask": None,
"include_init_images": True,
}
# update key and values
draw_options.update(draw_options_image)
draw_options.update(tool_parameters)

# get prompt lora model
prompt = tool_parameters.get('prompt', '')
lora = tool_parameters.get('lora', '')
model = tool_parameters.get('model', '')
if lora:
draw_options['prompt'] = f'{lora},{prompt}'
else:
draw_options['prompt'] = prompt

try:
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img')
response = post(url, data=json.dumps(draw_options), timeout=120)
if response.status_code != 200:
return self.create_text_message('Failed to generate image')
Expand All @@ -239,24 +253,24 @@ def img2img(self, base_url: str, lora: str, image_binary: bytes,
except Exception as e:
return self.create_text_message('Failed to generate image')

def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \
def text2img(self, base_url: str, tool_parameters: dict[str, Any]) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
generate image
"""
# copy draw options
draw_options = deepcopy(DRAW_TEXT_OPTIONS)

draw_options.update(tool_parameters)
# get prompt lora model
prompt = tool_parameters.get('prompt', '')
lora = tool_parameters.get('lora', '')
model = tool_parameters.get('model', '')
if lora:
draw_options['prompt'] = f'{lora},{prompt}'
else:
draw_options['prompt'] = prompt

draw_options['width'] = width
draw_options['height'] = height
draw_options['steps'] = steps
draw_options['negative_prompt'] = negative_prompt
draw_options['override_settings']['sd_model_checkpoint'] = model


try:
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
Expand Down
4 changes: 2 additions & 2 deletions api/core/workflow/nodes/variable_aggregator/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData


class AdvancedSetting(BaseModel):
class AdvancedSettings(BaseModel):
"""
Advanced setting.
"""
Expand All @@ -30,4 +30,4 @@ class VariableAssignerNodeData(BaseNodeData):
type: str = 'variable-assigner'
output_type: str
variables: list[list[str]]
advanced_setting: Optional[AdvancedSetting]
advanced_settings: Optional[AdvancedSettings]
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
outputs = {}
inputs = {}

if not node_data.advanced_setting or node_data.advanced_setting.group_enabled:
if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled:
for variable in node_data.variables:
value = variable_pool.get_variable_value(variable)

Expand All @@ -32,12 +32,14 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
}
break
else:
for group in node_data.advanced_setting.groups:
for group in node_data.advanced_settings.groups:
for variable in group.variables:
value = variable_pool.get_variable_value(variable)

if value is not None:
outputs[f'{group.group_name}_output'] = value
outputs[group.group_name] = {
'output': value
}
inputs['.'.join(variable[1:])] = value
break

Expand Down
5 changes: 4 additions & 1 deletion api/extensions/ext_mail.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ def init_app(self, app: Flask):
from libs.smtp import SMTPClient
if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'):
raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type')
if not app.config.get('SMTP_USE_TLS') and app.config.get('SMTP_OPPORTUNISTIC_TLS'):
raise ValueError('SMTP_OPPORTUNISTIC_TLS is not supported without enabling SMTP_USE_TLS')
self._client = SMTPClient(
server=app.config.get('SMTP_SERVER'),
port=app.config.get('SMTP_PORT'),
username=app.config.get('SMTP_USERNAME'),
password=app.config.get('SMTP_PASSWORD'),
_from=app.config.get('MAIL_DEFAULT_SEND_FROM'),
use_tls=app.config.get('SMTP_USE_TLS')
use_tls=app.config.get('SMTP_USE_TLS'),
opportunistic_tls=app.config.get('SMTP_OPPORTUNISTIC_TLS')
)
else:
raise ValueError('Unsupported mail type {}'.format(app.config.get('MAIL_TYPE')))
Expand Down
17 changes: 12 additions & 5 deletions api/libs/smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,27 @@


class SMTPClient:
def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False):
def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False):
self.server = server
self.port = port
self._from = _from
self.username = username
self.password = password
self._use_tls = use_tls
self.use_tls = use_tls
self.opportunistic_tls = opportunistic_tls

def send(self, mail: dict):
smtp = None
try:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
if self._use_tls:
smtp.starttls()
if self.use_tls:
if self.opportunistic_tls:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
smtp.starttls()
else:
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
else:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)

if self.username and self.password:
smtp.login(self.username, self.password)

Expand Down
3 changes: 2 additions & 1 deletion api/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def latest_process_rule(self):

@property
def app_count(self):
return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id).scalar()
return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id,
App.id == AppDatasetJoin.app_id).scalar()

@property
def document_count(self):
Expand Down
Loading

0 comments on commit 4a301a2

Please sign in to comment.