diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 3201426dfa99a5..8eae1216d09e89 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -5,7 +5,7 @@ import httpx from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Headers, NotGiven +from ..core._base_type import NOT_GIVEN, Body, Headers, NotGiven from ..core._http_client import make_user_request_input from ..types.image import ImagesResponded @@ -28,7 +28,9 @@ def generations( size: Optional[str] | NotGiven = NOT_GIVEN, style: Optional[str] | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, + extra_body: Body | None = None, disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ImagesResponded: @@ -46,9 +48,12 @@ def generations( "size": size, "style": style, "user": user, + "request_id": request_id, }, options=make_user_request_input( - extra_headers=extra_headers, timeout=timeout + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout ), cast_type=_cast_type, enable_stream=False, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 924d00912327c4..263fe829901c83 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -11,7 +11,7 @@ from tenacity.stop import stop_after_attempt from . import _errors -from ._base_type import NOT_GIVEN, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT +from ._base_type import NOT_GIVEN, AnyMapping, Body, Data, Headers, NotGiven, Query, RequestFiles, ResponseT from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError from ._files import make_httpx_files from ._request_opt import ClientRequestParam, UserRequestInput @@ -358,6 +358,7 @@ def make_user_request_input( max_retries: int | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, extra_headers: Headers = None, + extra_body: Body | None = None, query: Query | None = None, ) -> UserRequestInput: options: UserRequestInput = {} @@ -370,5 +371,7 @@ def make_user_request_input( options['timeout'] = timeout if query is not None: options["params"] = query + if extra_body is not None: + options["extra_json"] = cast(AnyMapping, extra_body) return options diff --git a/api/core/tools/provider/builtin/cogview/__init__.py b/api/core/tools/provider/builtin/cogview/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/tools/provider/builtin/cogview/_assets/icon.png b/api/core/tools/provider/builtin/cogview/_assets/icon.png new file mode 100644 index 00000000000000..f0c1c24a02fc83 Binary files /dev/null and b/api/core/tools/provider/builtin/cogview/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/cogview/cogview.py b/api/core/tools/provider/builtin/cogview/cogview.py new file mode 100644 index 00000000000000..801817ec06ed36 --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/cogview.py @@ -0,0 +1,27 @@ +""" Provide the input parameters type for the cogview provider class """ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.cogview.tools.cogview3 import CogView3Tool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class COGVIEWProvider(BuiltinToolProviderController): + """ cogview provider """ + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + CogView3Tool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "prompt": "一个城市在水晶瓶中欢快生活的场景,水彩画风格,展现出微观与珠宝般的美丽。", + "size": "square", + "n": 1 + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) from e + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/cogview/cogview.yaml b/api/core/tools/provider/builtin/cogview/cogview.yaml new file mode 100644 index 00000000000000..374b0e98d9122c --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/cogview.yaml @@ -0,0 +1,61 @@ +identity: + author: Waffle + name: cogview + label: + en_US: CogView + zh_Hans: CogView 绘画 + pt_BR: CogView + description: + en_US: CogView art + zh_Hans: CogView 绘画 + pt_BR: CogView art + icon: icon.png + tags: + - image + - productivity +credentials_for_provider: + zhipuai_api_key: + type: secret-input + required: true + label: + en_US: ZhipuAI API key + zh_Hans: ZhipuAI API key + pt_BR: ZhipuAI API key + help: + en_US: Please input your ZhipuAI API key + zh_Hans: 请输入你的 ZhipuAI API key + pt_BR: Please input your ZhipuAI API key + placeholder: + en_US: Please input your ZhipuAI API key + zh_Hans: 请输入你的 ZhipuAI API key + pt_BR: Please input your ZhipuAI API key + zhipuai_organizaion_id: + type: text-input + required: false + label: + en_US: ZhipuAI organization ID + zh_Hans: ZhipuAI organization ID + pt_BR: ZhipuAI organization ID + help: + en_US: Please input your ZhipuAI organization ID + zh_Hans: 请输入你的 ZhipuAI organization ID + pt_BR: Please input your ZhipuAI organization ID + placeholder: + en_US: Please input your ZhipuAI organization ID + zh_Hans: 请输入你的 ZhipuAI organization ID + pt_BR: Please input your ZhipuAI organization ID + zhipuai_base_url: + type: text-input + required: false + label: + en_US: ZhipuAI base URL + zh_Hans: ZhipuAI base URL + pt_BR: ZhipuAI base URL + help: + en_US: Please input your ZhipuAI base URL + zh_Hans: 请输入你的 ZhipuAI base URL + pt_BR: Please input your ZhipuAI base URL + placeholder: + en_US: Please input your ZhipuAI base URL + zh_Hans: 请输入你的 ZhipuAI base URL + pt_BR: Please input your ZhipuAI base URL diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py new file mode 100644 index 00000000000000..bb2720196f244f --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -0,0 +1,69 @@ +import random +from typing import Any, Union + +from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class CogView3Tool(BuiltinTool): + """ CogView3 Tool """ + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke CogView3 tool + """ + client = ZhipuAI( + base_url=self.runtime.credentials['zhipuai_base_url'], + api_key=self.runtime.credentials['zhipuai_api_key'], + ) + size_mapping = { + 'square': '1024x1024', + 'vertical': '1024x1792', + 'horizontal': '1792x1024', + } + # prompt + prompt = tool_parameters.get('prompt', '') + if not prompt: + return self.create_text_message('Please input prompt') + # get size + print(tool_parameters.get('prompt', 'square')) + size = size_mapping[tool_parameters.get('size', 'square')] + # get n + n = tool_parameters.get('n', 1) + # get quality + quality = tool_parameters.get('quality', 'standard') + if quality not in ['standard', 'hd']: + return self.create_text_message('Invalid quality') + # get style + style = tool_parameters.get('style', 'vivid') + if style not in ['natural', 'vivid']: + return self.create_text_message('Invalid style') + # set extra body + seed_id = tool_parameters.get('seed_id', self._generate_random_id(8)) + extra_body = {'seed': seed_id} + response = client.images.generations( + prompt=prompt, + model="cogview-3", + size=size, + n=n, + extra_body=extra_body, + style=style, + quality=quality, + response_format='b64_json' + ) + result = [] + for image in response.data: + result.append(self.create_image_message(image=image.url)) + result.append(self.create_text_message( + f'\nGenerate image source to Seed ID: {seed_id}')) + return result + + @staticmethod + def _generate_random_id(length=8): + characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' + random_id = ''.join(random.choices(characters, k=length)) + return random_id diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml new file mode 100644 index 00000000000000..ba0b271a1c716c --- /dev/null +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.yaml @@ -0,0 +1,123 @@ +identity: + name: cogview3 + author: Waffle + label: + en_US: CogView 3 + zh_Hans: CogView 3 绘画 + pt_BR: CogView 3 + description: + en_US: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt + zh_Hans: CogView 3 是一个强大的绘画工具,它可以根据您的提示词绘制出您想要的图像 + pt_BR: CogView 3 is a powerful drawing tool that can draw the image you want based on your prompt +description: + human: + en_US: CogView 3 is a text to image tool + zh_Hans: CogView 3 是一个文本到图像的工具 + pt_BR: CogView 3 is a text to image tool + llm: CogView 3 is a tool used to generate images from text +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + pt_BR: Prompt + human_description: + en_US: Image prompt, you can check the official documentation of CogView 3 + zh_Hans: 图像提示词,您可以查看CogView 3 的官方文档 + pt_BR: Image prompt, you can check the official documentation of CogView 3 + llm_description: Image prompt of CogView 3, you should describe the image you want to generate as a list of words as possible as detailed + form: llm + - name: size + type: select + required: true + human_description: + en_US: selecting the image size + zh_Hans: 选择图像大小 + pt_BR: selecting the image size + label: + en_US: Image size + zh_Hans: 图像大小 + pt_BR: Image size + form: form + options: + - value: square + label: + en_US: Squre(1024x1024) + zh_Hans: 方(1024x1024) + pt_BR: Squre(1024x1024) + - value: vertical + label: + en_US: Vertical(1024x1792) + zh_Hans: 竖屏(1024x1792) + pt_BR: Vertical(1024x1792) + - value: horizontal + label: + en_US: Horizontal(1792x1024) + zh_Hans: 横屏(1792x1024) + pt_BR: Horizontal(1792x1024) + default: square + - name: n + type: number + required: true + human_description: + en_US: selecting the number of images + zh_Hans: 选择图像数量 + pt_BR: selecting the number of images + label: + en_US: Number of images + zh_Hans: 图像数量 + pt_BR: Number of images + form: form + min: 1 + max: 1 + default: 1 + - name: quality + type: select + required: true + human_description: + en_US: selecting the image quality + zh_Hans: 选择图像质量 + pt_BR: selecting the image quality + label: + en_US: Image quality + zh_Hans: 图像质量 + pt_BR: Image quality + form: form + options: + - value: standard + label: + en_US: Standard + zh_Hans: 标准 + pt_BR: Standard + - value: hd + label: + en_US: HD + zh_Hans: 高清 + pt_BR: HD + default: standard + - name: style + type: select + required: true + human_description: + en_US: selecting the image style + zh_Hans: 选择图像风格 + pt_BR: selecting the image style + label: + en_US: Image style + zh_Hans: 图像风格 + pt_BR: Image style + form: form + options: + - value: vivid + label: + en_US: Vivid + zh_Hans: 生动 + pt_BR: Vivid + - value: natural + label: + en_US: Natural + zh_Hans: 自然 + pt_BR: Natural + default: vivid