Skip to content

Commit

Permalink
enhance: use override_settings for concurrent stable diffusion (#2818)
Browse files Browse the repository at this point in the history
  • Loading branch information
QunBB authored Mar 14, 2024
1 parent 4fe585a commit 1e5455e
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,17 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps)
steps=steps,
model=model)

return self.text2img(base_url=base_url,
lora=lora,
prompt=prompt,
negative_prompt=negative_prompt,
width=width,
height=height,
steps=steps)
steps=steps,
model=model)

def validate_models(self) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Expand Down Expand Up @@ -197,7 +199,7 @@ def get_sd_models(self) -> list[str]:

def img2img(self, base_url: str, lora: str, image_binary: bytes,
prompt: str, negative_prompt: str,
width: int, height: int, steps: int) \
width: int, height: int, steps: int, model: str) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
generate image
Expand All @@ -213,7 +215,8 @@ def img2img(self, base_url: str, lora: str, image_binary: bytes,
"sampler_name": "Euler a",
"restore_faces": False,
"steps": steps,
"script_args": ["outpainting mk2"]
"script_args": ["outpainting mk2"],
"override_settings": {"sd_model_checkpoint": model}
}

if lora:
Expand All @@ -236,7 +239,7 @@ def img2img(self, base_url: str, lora: str, image_binary: bytes,
except Exception as e:
return self.create_text_message('Failed to generate image')

def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \
def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \
-> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
generate image
Expand All @@ -253,6 +256,7 @@ def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str,
draw_options['height'] = height
draw_options['steps'] = steps
draw_options['negative_prompt'] = negative_prompt
draw_options['override_settings']['sd_model_checkpoint'] = model

try:
url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img')
Expand Down

0 comments on commit 1e5455e

Please sign in to comment.