Skip to content

Commit

Permalink
Feat/model as tool (#2744)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly authored Mar 8, 2024
1 parent 3231a8c commit 40c646c
Show file tree
Hide file tree
Showing 26 changed files with 839 additions and 42 deletions.
26 changes: 26 additions & 0 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,30 @@ def get(self, provider):
icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=minetype)

class ToolModelProviderIconApi(Resource):
@setup_required
def get(self, provider):
icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider)
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype)

class ToolModelProviderListToolsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user_id = current_user.id
tenant_id = current_user.current_tenant_id

parser = reqparse.RequestParser()
parser.add_argument('provider', type=str, required=True, nullable=False, location='args')

args = parser.parse_args()

return ToolManageService.list_model_tool_provider_tools(
user_id,
tenant_id,
args['provider'],
)

class ToolApiProviderAddApi(Resource):
@setup_required
Expand Down Expand Up @@ -283,6 +307,8 @@ def post(self):
api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update')
api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema')
api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon')
api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon')
api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools')
api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add')
api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote')
api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools')
Expand Down
7 changes: 5 additions & 2 deletions api/core/model_runtime/entities/model_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ModelType(Enum):
SPEECH2TEXT = "speech2text"
MODERATION = "moderation"
TTS = "tts"
# TEXT2IMG = "text2img"
TEXT2IMG = "text2img"

@classmethod
def value_of(cls, origin_model_type: str) -> "ModelType":
Expand All @@ -36,6 +36,8 @@ def value_of(cls, origin_model_type: str) -> "ModelType":
return cls.SPEECH2TEXT
elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value:
return cls.TTS
elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value:
return cls.TEXT2IMG
elif origin_model_type == cls.MODERATION.value:
return cls.MODERATION
else:
Expand All @@ -59,10 +61,11 @@ def to_origin_model_type(self) -> str:
return 'tts'
elif self == self.MODERATION:
return 'moderation'
elif self == self.TEXT2IMG:
return 'text2img'
else:
raise ValueError(f'invalid model type {self}')


class FetchFrom(Enum):
"""
Enum class for fetch from.
Expand Down
48 changes: 48 additions & 0 deletions api/core/model_runtime/model_providers/__base/text2img_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from abc import abstractmethod
from typing import IO, Optional

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class Text2ImageModel(AIModel):
"""
Model class for text2img model.
"""
model_type: ModelType = ModelType.TEXT2IMG

def invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
try:
return self._invoke(model, credentials, prompt, model_parameters, user)
except Exception as e:
raise self._transform_invoke_error(e)

@abstractmethod
def _invoke(self, model: str, credentials: dict, prompt: str,
model_parameters: dict, user: Optional[str] = None) \
-> list[IO[bytes]]:
"""
Invoke Text2Image model
:param model: model name
:param credentials: model credentials
:param prompt: prompt for image generation
:param model_parameters: model parameters
:param user: unique user id
:return: image bytes
"""
raise NotImplementedError
6 changes: 5 additions & 1 deletion api/core/tools/entities/common_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ class I18nObject(BaseModel):
Model class for i18n object.
"""
zh_Hans: Optional[str] = None
pt_BR: Optional[str] = None
en_US: str

def __init__(self, **data):
super().__init__(**data)
if not self.zh_Hans:
self.zh_Hans = self.en_US
if not self.pt_BR:
self.pt_BR = self.en_US

def to_dict(self) -> dict:
return {
'zh_Hans': self.zh_Hans,
'en_US': self.en_US,
}
'pt_BR': self.pt_BR
}
22 changes: 21 additions & 1 deletion api/core/tools/entities/tool_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,24 @@ def set_file(self, tool_name: str, value: str, name: str = None) -> None:
value=value,
)

self.pool.append(variable)
self.pool.append(variable)

class ModelToolPropertyKey(Enum):
IMAGE_PARAMETER_NAME = "image_parameter_name"

class ModelToolConfiguration(BaseModel):
"""
Model tool configuration
"""
type: str = Field(..., description="The type of the model tool")
model: str = Field(..., description="The model")
label: I18nObject = Field(..., description="The label of the model tool")
properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool")

class ModelToolProviderConfiguration(BaseModel):
"""
Model tool provider configuration
"""
provider: str = Field(..., description="The provider of the model tool")
models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool")
label: I18nObject = Field(..., description="The label of the model tool")
1 change: 1 addition & 0 deletions api/core/tools/entities/user_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class ProviderType(Enum):
BUILTIN = "builtin"
APP = "app"
API = "api"
MODEL = "model"

id: str
author: str
Expand Down
20 changes: 20 additions & 0 deletions api/core/tools/model_tools/anthropic.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
provider: anthropic
label:
en_US: Anthropic Model Tools
zh_Hans: Anthropic 模型能力
pt_BR: Anthropic Model Tools
models:
- type: llm
model: claude-3-sonnet-20240229
label:
zh_Hans: Claude3 Sonnet 视觉
en_US: Claude3 Sonnet Vision
properties:
image_parameter_name: image_id
- type: llm
model: claude-3-opus-20240229
label:
zh_Hans: Claude3 Opus 视觉
en_US: Claude3 Opus Vision
properties:
image_parameter_name: image_id
13 changes: 13 additions & 0 deletions api/core/tools/model_tools/google.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
provider: google
label:
en_US: Google Model Tools
zh_Hans: Google 模型能力
pt_BR: Google Model Tools
models:
- type: llm
model: gemini-pro-vision
label:
zh_Hans: Gemini Pro 视觉
en_US: Gemini Pro Vision
properties:
image_parameter_name: image_id
13 changes: 13 additions & 0 deletions api/core/tools/model_tools/openai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
provider: openai
label:
en_US: OpenAI Model Tools
zh_Hans: OpenAI 模型能力
pt_BR: OpenAI Model Tools
models:
- type: llm
model: gpt-4-vision-preview
label:
zh_Hans: GPT-4 视觉
en_US: GPT-4 Vision
properties:
image_parameter_name: image_id
13 changes: 13 additions & 0 deletions api/core/tools/model_tools/zhipuai.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
provider: zhipuai
label:
en_US: ZhipuAI Model Tools
zh_Hans: ZhipuAI 模型能力
pt_BR: ZhipuAI Model Tools
models:
- type: llm
model: glm-4v
label:
zh_Hans: GLM-4 视觉
en_US: GLM-4 Vision
properties:
image_parameter_name: image_id
10 changes: 7 additions & 3 deletions api/core/tools/provider/_position.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
- google
- bing
- duckduckgo
- yahoo
- dalle
- azuredalle
- wikipedia
- model.openai
- model.google
- model.anthropic
- yahoo
- arxiv
- pubmed
- dalle
- azuredalle
- stablediffusion
- webscraper
- model.zhipuai
- aippt
- youtube
- wolframalpha
Expand Down
18 changes: 9 additions & 9 deletions api/core/tools/provider/builtin/_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@

from core.tools.entities.user_entities import UserToolProvider

position = {}

class BuiltinToolProviderSort:
@staticmethod
def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]:
global position
if not position:
_position = {}

@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
position = tmp_position
cls._position = tmp_position

def sort_compare(provider: UserToolProvider) -> int:
# if provider.type == UserToolProvider.ProviderType.MODEL:
# return position.get(f'model_provider.{provider.name}', 10000)
return position.get(provider.name, 10000)
if provider.type == UserToolProvider.ProviderType.MODEL:
return cls._position.get(f'model.{provider.name}', 10000)
return cls._position.get(provider.name, 10000)

sorted_providers = sorted(providers, key=sort_compare)

Expand Down
Loading

0 comments on commit 40c646c

Please sign in to comment.