Skip to content

Commit

Permalink
Feat/add zhipu CogView 3 tool (#6210)
Browse files Browse the repository at this point in the history
  • Loading branch information
ox01024 authored Jul 13, 2024
1 parent a7b33b5 commit 07add06
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import httpx

from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Headers, NotGiven
from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven
from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded

Expand All @@ -28,7 +28,9 @@ def generations(
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
extra_body: Body | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded:
Expand All @@ -46,9 +48,12 @@ def generations(
"size": size,
"style": style,
"user": user,
"request_id": request_id,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
extra_headers=extra_headers,
extra_body=extra_body,
timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from tenacity.stop import stop_after_attempt

from . import _errors
from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
Expand Down Expand Up @@ -358,6 +358,7 @@ def make_user_request_input(
max_retries: int | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
extra_headers: Headers = None,
extra_body: Body | None = None,
query: Query | None = None,
) -> UserRequestInput:
options: UserRequestInput = {}
Expand All @@ -370,5 +371,7 @@ def make_user_request_input(
options['timeout'] = timeout
if query is not None:
options["params"] = query
if extra_body is not None:
options["extra_json"] = cast(AnyMapping, extra_body)

return options
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 27 additions & 0 deletions api/core/tools/provider/builtin/cogview/cogview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
""" Provide the input parameters type for the cogview provider class """
from typing import Any

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController


class COGVIEWProvider(BuiltinToolProviderController):
    """Provider controller for the CogView builtin tools."""

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        Validate the credentials by running one CogView 3 generation with them.

        :param credentials: provider credentials handed to the tool runtime
        :raises ToolProviderCredentialValidationError: when the trial invocation
            raises any exception; the original message is preserved
        """
        # Fixed sample request used only to exercise the credentials.
        probe_parameters = {
            "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。",
            "size": "square",
            "n": 1
        }
        try:
            tool = CogView3Tool().fork_tool_runtime(
                runtime={"credentials": credentials}
            )
            tool.invoke(user_id='', tool_parameters=probe_parameters)
        except Exception as e:
            raise ToolProviderCredentialValidationError(str(e)) from e

61 changes: 61 additions & 0 deletions api/core/tools/provider/builtin/cogview/cogview.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
identity:
author: Waffle
name: cogview
label:
en_US: CogView
zh_Hans: CogView 绘画
pt_BR: CogView
description:
en_US: CogView art
zh_Hans: CogView 绘画
pt_BR: CogView art
icon: icon.png
tags:
- image
- productivity
credentials_for_provider:
zhipuai_api_key:
type: secret-input
required: true
label:
en_US: ZhipuAI API key
zh_Hans: ZhipuAI API key
pt_BR: ZhipuAI API key
help:
en_US: Please input your ZhipuAI API key
zh_Hans: 请输入你的 ZhipuAI API key
pt_BR: Please input your ZhipuAI API key
placeholder:
en_US: Please input your ZhipuAI API key
zh_Hans: 请输入你的 ZhipuAI API key
pt_BR: Please input your ZhipuAI API key
zhipuai_organizaion_id:
type: text-input
required: false
label:
en_US: ZhipuAI organization ID
zh_Hans: ZhipuAI organization ID
pt_BR: ZhipuAI organization ID
help:
en_US: Please input your ZhipuAI organization ID
zh_Hans: 请输入你的 ZhipuAI organization ID
pt_BR: Please input your ZhipuAI organization ID
placeholder:
en_US: Please input your ZhipuAI organization ID
zh_Hans: 请输入你的 ZhipuAI organization ID
pt_BR: Please input your ZhipuAI organization ID
zhipuai_base_url:
type: text-input
required: false
label:
en_US: ZhipuAI base URL
zh_Hans: ZhipuAI base URL
pt_BR: ZhipuAI base URL
help:
en_US: Please input your ZhipuAI base URL
zh_Hans: 请输入你的 ZhipuAI base URL
pt_BR: Please input your ZhipuAI base URL
placeholder:
en_US: Please input your ZhipuAI base URL
zh_Hans: 请输入你的 ZhipuAI base URL
pt_BR: Please input your ZhipuAI base URL
69 changes: 69 additions & 0 deletions api/core/tools/provider/builtin/cogview/tools/cogview3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import random
from typing import Any, Union

from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class CogView3Tool(BuiltinTool):
    """Builtin tool that generates images through the ZhipuAI ``cogview-3`` model."""

    def _invoke(self,
                user_id: str,
                tool_parameters: dict[str, Any]
                ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke the CogView3 tool.

        :param user_id: id of the invoking user (not used by this tool)
        :param tool_parameters: may contain ``prompt``, ``size``, ``n``,
            ``quality``, ``style`` and ``seed_id``
        :return: a single text message when a parameter is missing or invalid,
            otherwise one image message per generated image followed by a text
            message that reports the seed id used
        """
        credentials = self.runtime.credentials
        client = ZhipuAI(
            # The base URL credential is optional, so it may be absent or blank.
            # NOTE(review): passing None is assumed to let the client pick its
            # own default endpoint - confirm against the ZhipuAI client.
            base_url=credentials.get('zhipuai_base_url') or None,
            api_key=credentials['zhipuai_api_key'],
        )
        size_mapping = {
            'square': '1024x1024',
            'vertical': '1024x1792',
            'horizontal': '1792x1024',
        }
        # prompt
        prompt = tool_parameters.get('prompt', '')
        if not prompt:
            return self.create_text_message('Please input prompt')
        # get size; reject unknown values the same way quality/style are rejected
        size = size_mapping.get(tool_parameters.get('size', 'square'))
        if size is None:
            return self.create_text_message('Invalid size')
        # get n
        n = tool_parameters.get('n', 1)
        # get quality
        quality = tool_parameters.get('quality', 'standard')
        if quality not in ['standard', 'hd']:
            return self.create_text_message('Invalid quality')
        # get style
        style = tool_parameters.get('style', 'vivid')
        if style not in ['natural', 'vivid']:
            return self.create_text_message('Invalid style')
        # set extra body; a random seed id is used when the caller gives none
        seed_id = tool_parameters.get('seed_id', self._generate_random_id(8))
        extra_body = {'seed': seed_id}
        response = client.images.generations(
            prompt=prompt,
            model="cogview-3",
            size=size,
            n=n,
            extra_body=extra_body,
            style=style,
            quality=quality,
            response_format='b64_json'
        )
        result = [
            self.create_image_message(image=image.url)
            for image in response.data
        ]
        result.append(self.create_text_message(
            f'\nGenerate image source to Seed ID: {seed_id}'))
        return result

    @staticmethod
    def _generate_random_id(length=8):
        """Return a random string of ``length`` ASCII letters and digits."""
        characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
        random_id = ''.join(random.choices(characters, k=length))
        return random_id
123 changes: 123 additions & 0 deletions api/core/tools/provider/builtin/cogview/tools/cogview3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
identity:
name: cogview3
author: Waffle
label:
en_US: CogView 3
zh_Hans: CogView 3 绘画
pt_BR: CogView 3
description:
en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像
pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt
description:
human:
en_US: CogView 3 is a text to image tool
zh_Hans: CogView 3 是一个文本到图像的工具
pt_BR: CogView 3 is a text to image tool
llm: CogView 3 is a tool used to generate images from text
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
pt_BR: Prompt
human_description:
en_US: Image prompt, you can check the official documentation of CogView 3
zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档
pt_BR: Image prompt, you can check the official documentation of CogView 3
llm_description: Image prompt for CogView 3. Describe the image you want to generate as a list of words, in as much detail as possible
form: llm
- name: size
type: select
required: true
human_description:
en_US: selecting the image size
zh_Hans: 选择图像大小
pt_BR: selecting the image size
label:
en_US: Image size
zh_Hans: 图像大小
pt_BR: Image size
form: form
options:
- value: square
label:
en_US: Square(1024x1024)
zh_Hans: 方(1024x1024)
pt_BR: Square(1024x1024)
- value: vertical
label:
en_US: Vertical(1024x1792)
zh_Hans: 竖屏(1024x1792)
pt_BR: Vertical(1024x1792)
- value: horizontal
label:
en_US: Horizontal(1792x1024)
zh_Hans: 横屏(1792x1024)
pt_BR: Horizontal(1792x1024)
default: square
- name: n
type: number
required: true
human_description:
en_US: selecting the number of images
zh_Hans: 选择图像数量
pt_BR: selecting the number of images
label:
en_US: Number of images
zh_Hans: 图像数量
pt_BR: Number of images
form: form
min: 1
max: 1
default: 1
- name: quality
type: select
required: true
human_description:
en_US: selecting the image quality
zh_Hans: 选择图像质量
pt_BR: selecting the image quality
label:
en_US: Image quality
zh_Hans: 图像质量
pt_BR: Image quality
form: form
options:
- value: standard
label:
en_US: Standard
zh_Hans: 标准
pt_BR: Standard
- value: hd
label:
en_US: HD
zh_Hans: 高清
pt_BR: HD
default: standard
- name: style
type: select
required: true
human_description:
en_US: selecting the image style
zh_Hans: 选择图像风格
pt_BR: selecting the image style
label:
en_US: Image style
zh_Hans: 图像风格
pt_BR: Image style
form: form
options:
- value: vivid
label:
en_US: Vivid
zh_Hans: 生动
pt_BR: Vivid
- value: natural
label:
en_US: Natural
zh_Hans: 自然
pt_BR: Natural
default: vivid

0 comments on commit 07add06

Please sign in to comment.