-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat/add zhipu CogView 3 tool (#6210)
- Loading branch information
Showing
8 changed files
with
291 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" Provide the input parameters type for the cogview provider class """ | ||
from typing import Any | ||
|
||
from core.tools.errors import ToolProviderCredentialValidationError | ||
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool | ||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController | ||
|
||
|
||
class COGVIEWProvider(BuiltinToolProviderController): | ||
""" cogview provider """ | ||
def _validate_credentials(self, credentials: dict[str, Any]) -> None: | ||
try: | ||
CogView3Tool().fork_tool_runtime( | ||
runtime={ | ||
"credentials": credentials, | ||
} | ||
).invoke( | ||
user_id='', | ||
tool_parameters={ | ||
"prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。", | ||
"size": "square", | ||
"n": 1 | ||
}, | ||
) | ||
except Exception as e: | ||
raise ToolProviderCredentialValidationError(str(e)) from e | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
identity: | ||
author: Waffle | ||
name: cogview | ||
label: | ||
en_US: CogView | ||
zh_Hans: CogView 绘画 | ||
pt_BR: CogView | ||
description: | ||
en_US: CogView art | ||
zh_Hans: CogView 绘画 | ||
pt_BR: CogView art | ||
icon: icon.png | ||
tags: | ||
- image | ||
- productivity | ||
credentials_for_provider: | ||
zhipuai_api_key: | ||
type: secret-input | ||
required: true | ||
label: | ||
en_US: ZhipuAI API key | ||
zh_Hans: ZhipuAI API key | ||
pt_BR: ZhipuAI API key | ||
help: | ||
en_US: Please input your ZhipuAI API key | ||
zh_Hans: 请输入你的 ZhipuAI API key | ||
pt_BR: Please input your ZhipuAI API key | ||
placeholder: | ||
en_US: Please input your ZhipuAI API key | ||
zh_Hans: 请输入你的 ZhipuAI API key | ||
pt_BR: Please input your ZhipuAI API key | ||
zhipuai_organizaion_id: | ||
type: text-input | ||
required: false | ||
label: | ||
en_US: ZhipuAI organization ID | ||
zh_Hans: ZhipuAI organization ID | ||
pt_BR: ZhipuAI organization ID | ||
help: | ||
en_US: Please input your ZhipuAI organization ID | ||
zh_Hans: 请输入你的 ZhipuAI organization ID | ||
pt_BR: Please input your ZhipuAI organization ID | ||
placeholder: | ||
en_US: Please input your ZhipuAI organization ID | ||
zh_Hans: 请输入你的 ZhipuAI organization ID | ||
pt_BR: Please input your ZhipuAI organization ID | ||
zhipuai_base_url: | ||
type: text-input | ||
required: false | ||
label: | ||
en_US: ZhipuAI base URL | ||
zh_Hans: ZhipuAI base URL | ||
pt_BR: ZhipuAI base URL | ||
help: | ||
en_US: Please input your ZhipuAI base URL | ||
zh_Hans: 请输入你的 ZhipuAI base URL | ||
pt_BR: Please input your ZhipuAI base URL | ||
placeholder: | ||
en_US: Please input your ZhipuAI base URL | ||
zh_Hans: 请输入你的 ZhipuAI base URL | ||
pt_BR: Please input your ZhipuAI base URL |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import random | ||
from typing import Any, Union | ||
|
||
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI | ||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
|
||
class CogView3Tool(BuiltinTool): | ||
""" CogView3 Tool """ | ||
|
||
def _invoke(self, | ||
user_id: str, | ||
tool_parameters: dict[str, Any] | ||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: | ||
""" | ||
Invoke CogView3 tool | ||
""" | ||
client = ZhipuAI( | ||
base_url=self.runtime.credentials['zhipuai_base_url'], | ||
api_key=self.runtime.credentials['zhipuai_api_key'], | ||
) | ||
size_mapping = { | ||
'square': '1024x1024', | ||
'vertical': '1024x1792', | ||
'horizontal': '1792x1024', | ||
} | ||
# prompt | ||
prompt = tool_parameters.get('prompt', '') | ||
if not prompt: | ||
return self.create_text_message('Please input prompt') | ||
# get size | ||
print(tool_parameters.get('prompt', 'square')) | ||
size = size_mapping[tool_parameters.get('size', 'square')] | ||
# get n | ||
n = tool_parameters.get('n', 1) | ||
# get quality | ||
quality = tool_parameters.get('quality', 'standard') | ||
if quality not in ['standard', 'hd']: | ||
return self.create_text_message('Invalid quality') | ||
# get style | ||
style = tool_parameters.get('style', 'vivid') | ||
if style not in ['natural', 'vivid']: | ||
return self.create_text_message('Invalid style') | ||
# set extra body | ||
seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) | ||
extra_body = {'seed': seed_id} | ||
response = client.images.generations( | ||
prompt=prompt, | ||
model="cogview-3", | ||
size=size, | ||
n=n, | ||
extra_body=extra_body, | ||
style=style, | ||
quality=quality, | ||
response_format='b64_json' | ||
) | ||
result = [] | ||
for image in response.data: | ||
result.append(self.create_image_message(image=image.url)) | ||
result.append(self.create_text_message( | ||
f'\nGenerate image source to Seed ID: {seed_id}')) | ||
return result | ||
|
||
@staticmethod | ||
def _generate_random_id(length=8): | ||
characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' | ||
random_id = ''.join(random.choices(characters, k=length)) | ||
return random_id |
123 changes: 123 additions & 0 deletions
123
api/core/tools/provider/builtin/cogview/tools/cogview3.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
identity: | ||
name: cogview3 | ||
author: Waffle | ||
label: | ||
en_US: CogView 3 | ||
zh_Hans: CogView 3 绘画 | ||
pt_BR: CogView 3 | ||
description: | ||
en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt | ||
zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像 | ||
pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt | ||
description: | ||
human: | ||
en_US: CogView 3 is a text to image tool | ||
zh_Hans: CogView 3 是一个文本到图像的工具 | ||
pt_BR: CogView 3 is a text to image tool | ||
llm: CogView 3 is a tool used to generate images from text | ||
parameters: | ||
- name: prompt | ||
type: string | ||
required: true | ||
label: | ||
en_US: Prompt | ||
zh_Hans: 提示词 | ||
pt_BR: Prompt | ||
human_description: | ||
en_US: Image prompt, you can check the official documentation of CogView 3 | ||
zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档 | ||
pt_BR: Image prompt, you can check the official documentation of CogView 3 | ||
llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed | ||
form: llm | ||
- name: size | ||
type: select | ||
required: true | ||
human_description: | ||
en_US: selecting the image size | ||
zh_Hans: 选择图像大小 | ||
pt_BR: selecting the image size | ||
label: | ||
en_US: Image size | ||
zh_Hans: 图像大小 | ||
pt_BR: Image size | ||
form: form | ||
options: | ||
- value: square | ||
label: | ||
en_US: Squre(1024x1024) | ||
zh_Hans: 方(1024x1024) | ||
pt_BR: Squre(1024x1024) | ||
- value: vertical | ||
label: | ||
en_US: Vertical(1024x1792) | ||
zh_Hans: 竖屏(1024x1792) | ||
pt_BR: Vertical(1024x1792) | ||
- value: horizontal | ||
label: | ||
en_US: Horizontal(1792x1024) | ||
zh_Hans: 横屏(1792x1024) | ||
pt_BR: Horizontal(1792x1024) | ||
default: square | ||
- name: n | ||
type: number | ||
required: true | ||
human_description: | ||
en_US: selecting the number of images | ||
zh_Hans: 选择图像数量 | ||
pt_BR: selecting the number of images | ||
label: | ||
en_US: Number of images | ||
zh_Hans: 图像数量 | ||
pt_BR: Number of images | ||
form: form | ||
min: 1 | ||
max: 1 | ||
default: 1 | ||
- name: quality | ||
type: select | ||
required: true | ||
human_description: | ||
en_US: selecting the image quality | ||
zh_Hans: 选择图像质量 | ||
pt_BR: selecting the image quality | ||
label: | ||
en_US: Image quality | ||
zh_Hans: 图像质量 | ||
pt_BR: Image quality | ||
form: form | ||
options: | ||
- value: standard | ||
label: | ||
en_US: Standard | ||
zh_Hans: 标准 | ||
pt_BR: Standard | ||
- value: hd | ||
label: | ||
en_US: HD | ||
zh_Hans: 高清 | ||
pt_BR: HD | ||
default: standard | ||
- name: style | ||
type: select | ||
required: true | ||
human_description: | ||
en_US: selecting the image style | ||
zh_Hans: 选择图像风格 | ||
pt_BR: selecting the image style | ||
label: | ||
en_US: Image style | ||
zh_Hans: 图像风格 | ||
pt_BR: Image style | ||
form: form | ||
options: | ||
- value: vivid | ||
label: | ||
en_US: Vivid | ||
zh_Hans: 生动 | ||
pt_BR: Vivid | ||
- value: natural | ||
label: | ||
en_US: Natural | ||
zh_Hans: 自然 | ||
pt_BR: Natural | ||
default: vivid |