class ToolModelProviderIconApi(Resource):
    """
    Console endpoint that serves the icon of a model tool provider.
    """
    @setup_required
    def get(self, provider):
        """
        Return the icon of `provider` as a file response.

        :param provider: the provider name taken from the route
        :return: the icon bytes sent with the mimetype reported by the service
        """
        payload, content_type = ToolManageService.get_model_tool_provider_icon(provider)
        stream = io.BytesIO(payload)
        return send_file(stream, mimetype=content_type)

class ToolModelProviderListToolsApi(Resource):
    """
    Console endpoint that lists the tools of a model tool provider.
    """
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """
        List the tools of the provider named by the `provider` query argument.

        :return: whatever ToolManageService.list_model_tool_provider_tools returns
        """
        account_id = current_user.id
        workspace_id = current_user.current_tenant_id

        # `provider` is a mandatory, non-null query-string argument
        query_parser = reqparse.RequestParser()
        query_parser.add_argument('provider', type=str, required=True, nullable=False, location='args')
        query = query_parser.parse_args()

        return ToolManageService.list_model_tool_provider_tools(account_id, workspace_id, query['provider'])
class Text2ImageModel(AIModel):
    """
    Base class for text2img models.

    Subclasses implement `_invoke`; callers use `invoke`, which routes any
    failure through `_transform_invoke_error`.
    """
    model_type: ModelType = ModelType.TEXT2IMG

    def invoke(self, model: str, credentials: dict, prompt: str,
               model_parameters: dict, user: Optional[str] = None) -> list[IO[bytes]]:
        """
        Generate images from a text prompt.

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: the generated images as binary streams
        """
        try:
            images = self._invoke(model, credentials, prompt, model_parameters, user)
        except Exception as error:
            # re-raise as whatever error the base class maps this failure to
            raise self._transform_invoke_error(error)
        return images

    @abstractmethod
    def _invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) -> list[IO[bytes]]:
        """
        Provider-specific image generation, to be implemented by subclasses.

        :param model: model name
        :param credentials: model credentials
        :param prompt: prompt for image generation
        :param model_parameters: model parameters
        :param user: unique user id

        :return: the generated images as binary streams
        """
        raise NotImplementedError
""" zh_Hans: Optional[str] = None + pt_BR: Optional[str] = None en_US: str def __init__(self, **data): super().__init__(**data) if not self.zh_Hans: self.zh_Hans = self.en_US + if not self.pt_BR: + self.pt_BR = self.en_US def to_dict(self) -> dict: return { 'zh_Hans': self.zh_Hans, 'en_US': self.en_US, - } \ No newline at end of file + 'pt_BR': self.pt_BR + } diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index ad27706c3a2982..61b41f9cf4b211 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -304,4 +304,24 @@ def set_file(self, tool_name: str, value: str, name: str = None) -> None: value=value, ) - self.pool.append(variable) \ No newline at end of file + self.pool.append(variable) + +class ModelToolPropertyKey(Enum): + IMAGE_PARAMETER_NAME = "image_parameter_name" + +class ModelToolConfiguration(BaseModel): + """ + Model tool configuration + """ + type: str = Field(..., description="The type of the model tool") + model: str = Field(..., description="The model") + label: I18nObject = Field(..., description="The label of the model tool") + properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + +class ModelToolProviderConfiguration(BaseModel): + """ + Model tool provider configuration + """ + provider: str = Field(..., description="The provider of the model tool") + models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") + label: I18nObject = Field(..., description="The label of the model tool") \ No newline at end of file diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/user_entities.py index 2641079333f35b..8a5589da274f0c 100644 --- a/api/core/tools/entities/user_entities.py +++ b/api/core/tools/entities/user_entities.py @@ -13,6 +13,7 @@ class ProviderType(Enum): BUILTIN = "builtin" APP = "app" API = "api" + MODEL = "model" id: str author: str 
diff --git a/api/core/tools/model_tools/anthropic.yaml b/api/core/tools/model_tools/anthropic.yaml new file mode 100644 index 00000000000000..4ccb973df5d457 --- /dev/null +++ b/api/core/tools/model_tools/anthropic.yaml @@ -0,0 +1,20 @@ +provider: anthropic +label: + en_US: Anthropic Model Tools + zh_Hans: Anthropic 模型能力 + pt_BR: Anthropic Model Tools +models: + - type: llm + model: claude-3-sonnet-20240229 + label: + zh_Hans: Claude3 Sonnet 视觉 + en_US: Claude3 Sonnet Vision + properties: + image_parameter_name: image_id + - type: llm + model: claude-3-opus-20240229 + label: + zh_Hans: Claude3 Opus 视觉 + en_US: Claude3 Opus Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/model_tools/google.yaml b/api/core/tools/model_tools/google.yaml new file mode 100644 index 00000000000000..d81e1b0735984a --- /dev/null +++ b/api/core/tools/model_tools/google.yaml @@ -0,0 +1,13 @@ +provider: google +label: + en_US: Google Model Tools + zh_Hans: Google 模型能力 + pt_BR: Google Model Tools +models: + - type: llm + model: gemini-pro-vision + label: + zh_Hans: Gemini Pro 视觉 + en_US: Gemini Pro Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/model_tools/openai.yaml b/api/core/tools/model_tools/openai.yaml new file mode 100644 index 00000000000000..45cbb295a98598 --- /dev/null +++ b/api/core/tools/model_tools/openai.yaml @@ -0,0 +1,13 @@ +provider: openai +label: + en_US: OpenAI Model Tools + zh_Hans: OpenAI 模型能力 + pt_BR: OpenAI Model Tools +models: + - type: llm + model: gpt-4-vision-preview + label: + zh_Hans: GPT-4 视觉 + en_US: GPT-4 Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/model_tools/zhipuai.yaml b/api/core/tools/model_tools/zhipuai.yaml new file mode 100644 index 00000000000000..19a932eb89aa4b --- /dev/null +++ b/api/core/tools/model_tools/zhipuai.yaml @@ -0,0 +1,13 @@ +provider: zhipuai +label: + en_US: ZhipuAI Model Tools + zh_Hans: ZhipuAI 模型能力 + pt_BR: ZhipuAI Model Tools 
+models: + - type: llm + model: glm-4v + label: + zh_Hans: GLM-4 视觉 + en_US: GLM-4 Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index a418e678f21173..ece9dbe1596851 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,14 +1,18 @@ - google - bing - duckduckgo -- yahoo +- dalle +- azuredalle - wikipedia +- model.openai +- model.google +- model.anthropic +- yahoo - arxiv - pubmed -- dalle -- azuredalle - stablediffusion - webscraper +- model.zhipuai - aippt - youtube - wolframalpha diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index c6cad0187eb50a..fa2c5d27ef823e 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -4,24 +4,24 @@ from core.tools.entities.user_entities import UserToolProvider -position = {} class BuiltinToolProviderSort: - @staticmethod - def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]: - global position - if not position: + _position = {} + + @classmethod + def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: + if not cls._position: tmp_position = {} file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') with open(file_path) as f: for pos, val in enumerate(load(f, Loader=FullLoader)): tmp_position[val] = pos - position = tmp_position + cls._position = tmp_position def sort_compare(provider: UserToolProvider) -> int: - # if provider.type == UserToolProvider.ProviderType.MODEL: - # return position.get(f'model_provider.{provider.name}', 10000) - return position.get(provider.name, 10000) + if provider.type == UserToolProvider.ProviderType.MODEL: + return cls._position.get(f'model.{provider.name}', 10000) + return cls._position.get(provider.name, 10000) sorted_providers = sorted(providers, key=sort_compare) diff --git 
class ModelToolProviderController(ToolProviderController):
    """
    Tool provider that exposes the vision-capable LLMs of a model provider as tools.
    """
    configuration: ProviderConfiguration = None
    is_active: bool = False

    def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
        """
        init the provider

        :param configuration: the model provider configuration backing this tool provider
        :param kwargs: remaining fields, forwarded to ToolProviderController
        """
        super().__init__(**kwargs)
        self.configuration = configuration

    @staticmethod
    def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
        """
        init the provider from db

        NOTE: the label of `configuration.provider` is overwritten in place with the
        label from the model tool configuration, when one is set.

        :param configuration: the configuration of the provider
        :return: the controller, or None when `configuration` is None
        :raises RuntimeError: when no model tool configuration exists for the provider
        """
        if configuration is None:
            return None

        # the provider is active only when every one of its models is active
        is_active = all(
            model.status == ModelStatus.ACTIVE
            for model in configuration.get_provider_models()
        )

        # get the provider configuration
        model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
        if model_tool_configuration is None:
            raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')

        # override the configuration
        if model_tool_configuration.label:
            if model_tool_configuration.label.en_US:
                configuration.provider.label.en_US = model_tool_configuration.label.en_US
            if model_tool_configuration.label.zh_Hans:
                configuration.provider.label.zh_Hans = model_tool_configuration.label.zh_Hans

        return ModelToolProviderController(
            is_active=is_active,
            identity=ToolProviderIdentity(
                author='Dify',
                name=configuration.provider.provider,
                description=I18nObject(
                    zh_Hans=f'{configuration.provider.label.zh_Hans} 模型能力提供商',
                    en_US=f'{configuration.provider.label.en_US} model capability provider'
                ),
                label=I18nObject(
                    zh_Hans=configuration.provider.label.zh_Hans,
                    en_US=configuration.provider.label.en_US
                ),
                icon=configuration.provider.icon_small.en_US,
            ),
            configuration=configuration,
            credentials_schema={},
        )

    @staticmethod
    def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
        """
        check if the configuration has a model can be used as a tool

        :param configuration: the configuration of the provider
        :return: True when at least one model is an LLM with the vision feature
        """
        return any(
            model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or [])
            for model in configuration.get_provider_models()
        )

    def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
        """
        returns a list of tools that the provider can provide

        :param tenant_id: the tenant used to look up the provider configuration when
                          the controller was created without one
        :return: list of tools
        :raises RuntimeError: when no model tool configuration exists for the provider
        """
        tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
        provider_manager = ProviderManager()
        if self.configuration is None:
            configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
            # `x.provider` is the provider entity; its `.provider` attribute is the
            # name string that `identity.name` was built from in `from_db`
            self.configuration = next(
                filter(lambda x: x.provider.provider == self.identity.name, configurations),
                None
            )
        # get all tools
        tools: list[ModelTool] = []
        # get all models
        if not self.configuration:
            return tools
        configuration = self.configuration

        provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
        if provider_configuration is None:
            raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')

        for model in configuration.get_provider_models():
            # only models declared in the model tool configuration become tools
            model_configuration = ModelToolConfigurationManager.get_model_configuration(
                configuration.provider.provider, model.model
            )
            if model_configuration is None:
                continue

            if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
                provider_instance = configuration.get_provider_instance()
                model_type_instance = provider_instance.get_model_instance(model.model_type)
                provider_model_bundle = ProviderModelBundle(
                    configuration=configuration,
                    provider_instance=provider_instance,
                    model_type_instance=model_type_instance
                )

                try:
                    model_instance = ModelInstance(provider_model_bundle, model.model)
                except ProviderTokenNotInitError:
                    # credentials are not set up yet: still list the tool, it reports
                    # a configuration problem when invoked
                    model_instance = None

                tools.append(ModelTool(
                    identity=ToolIdentity(
                        author='Dify',
                        name=model.model,
                        label=model_configuration.label,
                    ),
                    parameters=[
                        ToolParameter(
                            name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
                            label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
                            human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
                            type=ToolParameter.ToolParameterType.STRING,
                            form=ToolParameter.ToolParameterForm.LLM,
                            required=True,
                            default=Tool.VARIABLE_KEY.IMAGE.value
                        )
                    ],
                    description=ToolDescription(
                        human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
                        llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
                    ),
                    is_team_authorization=model.status == ModelStatus.ACTIVE,
                    tool_type=ModelTool.ModelToolType.VISION,
                    model_instance=model_instance,
                    model=model.model,
                    # hand the configured properties to the tool so it can resolve
                    # the image parameter name
                    properties=model_configuration.properties,
                ))

        self.tools = tools
        return tools

    def _load_tools(self) -> list[ModelTool]:
        """
        returns the cached tools, building them on first use

        :return: list of tools
        """
        if self.tools is None:
            # without a configuration there is no tenant to read; _get_model_tools
            # then falls back to its placeholder tenant id
            tenant_id = self.configuration.tenant_id if self.configuration else None
            self.get_tools(user_id='', tenant_id=tenant_id)
        return self.tools

    def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
        """
        returns the credentials schema of the provider

        :return: the credentials schema, always empty for model providers
        """
        return {}

    def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
        """
        returns a list of tools that the provider can provide

        :param user_id: not used by this provider
        :param tenant_id: the tenant whose provider configuration is used
        :return: list of tools
        """
        return self._get_model_tools(tenant_id=tenant_id)

    def get_tool(self, tool_name: str) -> ModelTool:
        """
        get tool by name

        :param tool_name: the name of the tool
        :return: the tool
        :raises ValueError: when the provider has no tool with that name
        """
        for tool in self._load_tools():
            if tool.identity.name == tool_name:
                return tool

        raise ValueError(f'tool {tool_name} not found')

    def get_parameters(self, tool_name: str) -> list[ToolParameter]:
        """
        returns the parameters of the tool

        :param tool_name: the name of the tool, defined in `get_tools`
        :return: list of parameters
        :raises ToolNotFoundError: when the provider has no tool with that name
        """
        # `get_tools` needs a user id and a tenant id, so go through the cached list
        tool = next(filter(lambda x: x.identity.name == tool_name, self._load_tools()), None)
        if tool is None:
            raise ToolNotFoundError(f'tool {tool_name} not found')
        return tool.parameters

    @property
    def app_type(self) -> ToolProviderType:
        """
        returns the type of the provider

        :return: type of the provider
        """
        return ToolProviderType.MODEL

    def validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        Does nothing: this provider declares no credentials of its own.

        :param credentials: the credentials of the tool
        """
        pass

    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
        validate the credentials of the provider

        Does nothing: this provider declares no credentials of its own.

        :param credentials: the credentials of the tool
        """
        pass
VISION_PROMPT = """## Image Recognition Task
### Task Description
I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions.
### Specific Requirements
1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color.
2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout.
3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details.
4. **Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects.
5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features.
### Additional Considerations
- Ensure that the extracted information is as comprehensive and accurate as possible.
- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results.
- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information."""

class ModelTool(Tool):
    """
    Model tool

    A tool backed by a model instance; the only tool type defined here is VISION,
    which sends an image to the model and returns the text it produces.
    """
    class ModelToolType(Enum):
        """
        the type of the model tool
        """
        VISION = 'vision'

    # holds 'model_instance', 'model' and 'properties' as passed to __init__
    model_configuration: dict[str, Any] = None
    tool_type: ModelToolType

    def __init__(self, model_instance: ModelInstance = None, model: str = None,
                 tool_type: ModelToolType = ModelToolType.VISION,
                 properties: dict[ModelToolPropertyKey, Any] = None,
                 **kwargs):
        """
        init the tool

        :param model_instance: the model instance to invoke, None when it could not be created
        :param model: the model name
        :param tool_type: the type of the model tool
        :param properties: the properties of the model tool, may be None
        :param kwargs: remaining fields, forwarded to Tool
        """
        kwargs['model_configuration'] = {
            'model_instance': model_instance,
            'model': model,
            'properties': properties
        }
        kwargs['tool_type'] = tool_type
        super().__init__(**kwargs)

    def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
        """
        fork a new tool with meta data

        :param meta: the meta data of a tool call processing, tenant_id is required
        :return: the new tool
        """
        return self.__class__(
            identity=self.identity.copy() if self.identity else None,
            parameters=self.parameters.copy() if self.parameters else None,
            description=self.description.copy() if self.description else None,
            model_instance=self.model_configuration['model_instance'],
            model=self.model_configuration['model'],
            tool_type=self.tool_type,
            # carry the properties over so the fork resolves the same image parameter name
            properties=self.model_configuration.get('properties'),
            runtime=Tool.Runtime(**meta)
        )

    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None:
        """
        validate the credentials for Model tool

        Does nothing for model tools.
        """
        pass

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        invoke the tool

        :param user_id: the id of the user, forwarded to the model
        :param tool_parameters: the parameters of the tool call
        :return: the model output, or a text message describing why nothing was invoked
        """
        model_instance = self.model_configuration['model_instance']
        if not model_instance:
            return self.create_text_message('the tool is not configured correctly')

        if self.tool_type == ModelTool.ModelToolType.VISION:
            return self._invoke_llm_vision(user_id, tool_parameters)
        else:
            return self.create_text_message('the tool is not configured correctly')

    def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        send the referenced image to the vision model and return its answer as text

        NOTE: the image parameter is removed from `tool_parameters`; whatever remains
        is passed to the model as its model parameters.

        :param user_id: the id of the user, forwarded to the model
        :param tool_parameters: the parameters of the tool call
        :return: a text message with the model output or with the failure reason
        """
        # properties is None when the tool was created without any
        properties = self.model_configuration.get('properties') or {}
        image_parameter_name = properties.get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id')
        image_id = tool_parameters.pop(image_parameter_name, '')

        # resolve the image variable, falling back to the default image variable
        image = self.get_variable(image_id) if image_id else None
        if not image:
            image = self.get_default_image_variable()
        if not image:
            return self.create_text_message('Please upload an image or input image_id')

        # get image
        image_binary = self.get_variable_file(image.name)
        if not image_binary:
            return self.create_text_message('Failed to get image')

        # organize prompt messages
        prompt_messages = [
            SystemPromptMessage(
                content=VISION_PROMPT
            ),
            UserPromptMessage(
                content=[
                    PromptMessageContent(
                        type=PromptMessageContentType.TEXT,
                        data='Recognize the image and extract the information from the image.'
                    ),
                    PromptMessageContent(
                        type=PromptMessageContentType.IMAGE,
                        # NOTE(review): the data URL always declares image/png, whatever
                        # the real format of the file is
                        data=f'data:image/png;base64,{b64encode(image_binary).decode("utf-8")}'
                    )
                ]
            )
        ]

        llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance'])
        result: LLMResult = llm_instance.invoke(
            model=self.model_configuration['model'],
            credentials=self.runtime.credentials,
            prompt_messages=prompt_messages,
            model_parameters=tool_parameters,
            tools=[],
            stop=[],
            stream=False,
            user=user_id,
        )

        if not result:
            return self.create_text_message('Failed to extract information from the image')

        # get result
        content = result.message.content
        if not content:
            return self.create_text_message('Failed to extract information from the image')

        return self.create_text_message(content)
core.tools.utils.configuration import ToolConfiguration +from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration from core.tools.utils.encoder import serialize_base_model_dict from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -135,7 +137,7 @@ def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: st raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @staticmethod - def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id, + def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, agent_callback: DifyAgentCallbackHandler = None) \ -> Union[BuiltinTool, ApiTool]: """ @@ -194,6 +196,19 @@ def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, ten 'tenant_id': tenant_id, 'credentials': decrypted_credentials, }) + elif provider_type == 'model': + if tenant_id is None: + raise ValueError('tenant id is required for model provider') + # get model provider + model_provider = ToolManager.get_model_provider(tenant_id, provider_name) + + # get tool + model_tool = model_provider.get_tool(tool_name) + + return model_tool.fork_tool_runtime(meta={ + 'tenant_id': tenant_id, + 'credentials': model_tool.model_configuration['model_instance'].credentials + }) elif provider_type == 'app': raise NotImplementedError('app provider not implemented') else: @@ -266,6 +281,49 @@ def list_builtin_providers() -> list[BuiltinToolProviderController]: return builtin_providers + @staticmethod + def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]: + """ + list all the model providers + + :return: the list of the model providers + """ + tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' + # get configurations + model_configurations = ModelToolConfigurationManager.get_all_configuration() + # get all providers + provider_manager = 
ProviderManager() + configurations = provider_manager.get_configurations(tenant_id).values() + # get model providers + model_providers: list[ModelToolProviderController] = [] + for configuration in configurations: + # all the model tool should be configurated + if configuration.provider.provider not in model_configurations: + continue + if not ModelToolProviderController.is_configuration_valid(configuration): + continue + model_providers.append(ModelToolProviderController.from_db(configuration)) + + return model_providers + + @staticmethod + def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController: + """ + get the model provider + + :param provider_name: the name of the provider + + :return: the provider + """ + # get configurations + provider_manager = ProviderManager() + configurations = provider_manager.get_configurations(tenant_id) + configuration = configurations.get(provider_name) + if configuration is None: + raise ToolProviderNotFoundError(f'model provider {provider_name} not found') + + return ModelToolProviderController.from_db(configuration) + @staticmethod def get_tool_label(tool_name: str) -> Union[I18nObject, None]: """ @@ -345,6 +403,28 @@ def user_list_providers( result_providers[provider_name].team_credentials = masked_credentials + # get model tool providers + model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) + # append model providers + for provider in model_providers: + result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider( + id=provider.identity.name, + author=provider.identity.author, + name=provider.identity.name, + description=I18nObject( + en_US=provider.identity.description.en_US, + zh_Hans=provider.identity.description.zh_Hans, + ), + icon=provider.identity.icon, + label=I18nObject( + en_US=provider.identity.label.en_US, + zh_Hans=provider.identity.label.zh_Hans, + ), + type=UserToolProvider.ProviderType.MODEL, + team_credentials={}, + 
class ModelToolConfigurationManager:
    """
    Model as tool configuration

    Loads every `.yaml` file of the sibling `model_tools` directory on first access
    and serves the parsed configurations from class-level caches.
    """
    # provider configurations, keyed by the yaml file name without its extension
    _configurations: dict[str, ModelToolProviderConfiguration] = {}
    # model configurations, keyed by '<provider>.<model>'
    _model_configurations: dict[str, ModelToolConfiguration] = {}
    _inited = False

    @classmethod
    def _init_configuration(cls):
        """
        init configuration

        The caches are built in local dicts and published only after every file has
        been parsed, so a failing file never leaves half-filled class state behind.
        """
        absolute_path = os.path.abspath(os.path.dirname(__file__))
        model_tools_path = os.path.join(absolute_path, '..', 'model_tools')

        configurations: dict[str, ModelToolProviderConfiguration] = {}
        model_configurations: dict[str, ModelToolConfiguration] = {}

        for file in os.listdir(model_tools_path):
            if not file.endswith('.yaml'):
                continue

            # strip only the extension, so a dotted file name keeps its full stem
            provider = os.path.splitext(file)[0]
            with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
                # FullLoader is applied to files shipped with the code base only
                raw_configuration = load(f, Loader=FullLoader)
            if not raw_configuration:
                # an empty document declares no provider
                continue

            provider_configuration = ModelToolProviderConfiguration(**raw_configuration)
            for model in provider_configuration.models or []:
                model_configurations[f'{provider}.{model.model}'] = model

            configurations[provider] = provider_configuration

        cls._configurations = configurations
        cls._model_configurations = model_configurations
        cls._inited = True

    @classmethod
    def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
        """
        get configuration by provider

        :param provider: the name of the provider
        :return: the provider configuration, or None when the provider has none
        """
        if not cls._inited:
            cls._init_configuration()
        return cls._configurations.get(provider, None)

    @classmethod
    def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
        """
        get all configurations

        :return: the cached mapping itself, keyed by provider name
        """
        if not cls._inited:
            cls._init_configuration()
        return cls._configurations

    @classmethod
    def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
        """
        get model configuration

        :param provider: the name of the provider
        :param model: the name of the model
        :return: the model configuration, or None when the model is not configured
        """
        key = f'{provider}.{model}'

        if not cls._inited:
            cls._init_configuration()

        return cls._model_configurations.get(key, None)
provider['name'] + '/icon' elif provider['type'] == UserToolProvider.ProviderType.API.value: try: provider['icon'] = json.loads(provider['icon']) @@ -505,6 +508,46 @@ def get_builtin_tool_provider_icon( return icon_bytes, mime_type + @staticmethod + def get_model_tool_provider_icon( + provider: str + ): + """ + get tool provider icon and it's mimetype + """ + + service = ModelProviderService() + icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US') + + if icon_bytes is None: + raise ValueError(f'provider {provider} does not exists') + + return icon_bytes, mime_type + + @staticmethod + def list_model_tool_provider_tools( + user_id: str, tenant_id: str, provider: str + ): + """ + list model tool provider tools + """ + provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) + + result = [ + UserTool( + author=tool.identity.author, + name=tool.identity.name, + label=tool.identity.label, + description=tool.description.human, + parameters=tool.parameters or [] + ) for tool in tools + ] + + return json.loads( + serialize_base_model_array(result) + ) + @staticmethod def delete_api_tool_provider( user_id: str, tenant_id: str, provider_name: str diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index cd13636c941fca..95858d95406d0a 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -34,7 +34,7 @@ const AgentTools: FC = () => { const [selectedProviderId, setSelectedProviderId] = useState(undefined) const [isShowSettingTool, setIsShowSettingTool] = useState(false) const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => { - const collection = 
collectionList.find(collection => collection.id === item.provider_id) + const collection = collectionList.find(collection => collection.id === item.provider_id && collection.type === item.provider_type) const icon = collection?.icon return { ...item, diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index 9892018ab6fb40..378054aae6fcc5 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import type { Collection, Tool } from '@/app/components/tools/types' -import { fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' +import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools' import I18n from '@/context/i18n' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' @@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon' type Props = { collection: Collection isBuiltIn?: boolean + isModel?: boolean toolName: string setting?: Record readonly?: boolean @@ -29,6 +30,7 @@ type Props = { const SettingBuiltInTool: FC = ({ collection, isBuiltIn = true, + isModel = true, toolName, setting = {}, readonly, @@ -56,7 +58,11 @@ const SettingBuiltInTool: FC = ({ (async () => { setIsLoading(true) try { - const list = isBuiltIn ? await fetchBuiltInToolList(collection.name) : await fetchCustomToolList(collection.name) + const list = isBuiltIn + ? await fetchBuiltInToolList(collection.name) + : isModel + ? 
await fetchModelToolList(collection.name) + : await fetchCustomToolList(collection.name) setTools(list) const currTool = list.find(tool => tool.name === toolName) if (currTool) { diff --git a/web/app/components/tools/index.tsx b/web/app/components/tools/index.tsx index 9e9bbd68020fcb..b467dcbad50b39 100644 --- a/web/app/components/tools/index.tsx +++ b/web/app/components/tools/index.tsx @@ -18,7 +18,7 @@ import NoSearchRes from './info/no-search-res' import NoCustomToolPlaceholder from './no-custom-tool-placeholder' import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import TabSlider from '@/app/components/base/tab-slider' -import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' +import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools' import type { AgentTool } from '@/types/app' type Props = { @@ -89,9 +89,11 @@ const Tools: FC = ({ const showCollectionList = (() => { let typeFilteredList: Collection[] = [] if (collectionType === CollectionType.all) - typeFilteredList = collectionList - else - typeFilteredList = collectionList.filter(item => item.type === collectionType) + typeFilteredList = collectionList.filter(item => item.type !== CollectionType.model) + else if (collectionType === CollectionType.builtIn) + typeFilteredList = collectionList.filter(item => item.type === CollectionType.builtIn) + else if (collectionType === CollectionType.custom) + typeFilteredList = collectionList.filter(item => item.type === CollectionType.custom) if (query) return typeFilteredList.filter(item => item.name.includes(query)) @@ -122,6 +124,10 @@ const Tools: FC = ({ const list = await fetchBuiltInToolList(currCollection.name) setCurrentTools(list) } + else if (currCollection.type === CollectionType.model) { + const list = await fetchModelToolList(currCollection.name) + 
setCurrentTools(list) + } else { const list = await fetchCustomToolList(currCollection.name) setCurrentTools(list) @@ -130,7 +136,7 @@ const Tools: FC = ({ catch (e) { } setIsDetailLoading(false) })() - }, [currCollection?.name]) + }, [currCollection?.name, currCollection?.type]) const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false) const handleCreateToolCollection = () => { @@ -197,7 +203,7 @@ const Tools: FC = ({ (showCollectionList.length > 0 || !query) ? diff --git a/web/app/components/tools/tool-list/header.tsx b/web/app/components/tools/tool-list/header.tsx index 5a243a0a2be7e9..bf564f320f7a52 100644 --- a/web/app/components/tools/tool-list/header.tsx +++ b/web/app/components/tools/tool-list/header.tsx @@ -29,9 +29,8 @@ const Header: FC = ({ const { t } = useTranslation() const isInToolsPage = loc === LOC.tools const isInDebugPage = !isInToolsPage - const needAuth = collection?.allow_delete - // const isBuiltIn = collection.type === CollectionType.builtIn + const needAuth = collection?.allow_delete || collection?.type === CollectionType.model const isAuthed = collection.is_team_authorization return (
@@ -50,10 +49,13 @@ const Header: FC = ({ )}
- {collection.type === CollectionType.builtIn && needAuth && ( + {(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && needAuth && (
onShowAuth()} + onClick={() => { + if (collection.type === CollectionType.builtIn || collection.type === CollectionType.model) + onShowAuth() + }} >
{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}
diff --git a/web/app/components/tools/tool-list/index.tsx b/web/app/components/tools/tool-list/index.tsx index 3bee3292e6e65f..9228a028a54926 100644 --- a/web/app/components/tools/tool-list/index.tsx +++ b/web/app/components/tools/tool-list/index.tsx @@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types' import Loading from '../../base/loading' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' import Toast from '../../base/toast' +import { ConfigurateMethodEnum } from '../../header/account-setting/model-provider-page/declarations' import Header from './header' import Item from './item' import AppIcon from '@/app/components/base/app-icon' @@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal' import type { AgentTool } from '@/types/app' import { MAX_TOOLS_NUM } from '@/config' +import { useModalContext } from '@/context/modal-context' +import { useProviderContext } from '@/context/provider-context' type Props = { collection: Collection | null @@ -42,9 +45,32 @@ const ToolList: FC = ({ const { t } = useTranslation() const isInToolsPage = loc === LOC.tools const isBuiltIn = collection?.type === CollectionType.builtIn + const isModel = collection?.type === CollectionType.model const needAuth = collection?.allow_delete + const { setShowModelModal } = useModalContext() const [showSettingAuth, setShowSettingAuth] = useState(false) + const { modelProviders: providers } = useProviderContext() + const showSettingAuthModal = () => { + if (isModel) { + const provider = providers.find(item => item.provider === collection?.id) + if (provider) { + setShowModelModal({ + payload: { + currentProvider: provider, + currentConfigurateMethod: ConfigurateMethodEnum.predefinedModel, + currentCustomConfigrationModelFixedFields: undefined, + }, + onSaveCallback: () => { + onRefreshData() + }, + }) + } + 
} + else { + setShowSettingAuth(true) + } + } const [customCollection, setCustomCollection] = useState(null) useEffect(() => { @@ -116,7 +142,7 @@ const ToolList: FC = ({ icon={icon} collection={collection} loc={loc} - onShowAuth={() => setShowSettingAuth(true)} + onShowAuth={() => showSettingAuthModal()} onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)} />
@@ -124,12 +150,12 @@ const ToolList: FC = ({
{t('tools.includeToolNum', { num: list.length, })}
- {needAuth && isBuiltIn && !collection.is_team_authorization && ( + {needAuth && (isBuiltIn || isModel) && !collection.is_team_authorization && ( <>
·
setShowSettingAuth(true)} + onClick={() => showSettingAuthModal()} >
{t('tools.auth.setup')}
@@ -149,7 +175,7 @@ const ToolList: FC = ({ collection={collection} isInToolsPage={isInToolsPage} isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM} - added={!!addedTools?.find(v => v.provider_id === collection.id && v.tool_name === item.name)} + added={!!addedTools?.find(v => v.provider_id === collection.id && v.provider_type === collection.type && v.tool_name === item.name)} onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined} /> ))} diff --git a/web/app/components/tools/tool-list/item.tsx b/web/app/components/tools/tool-list/item.tsx index c53aba61d388a6..e6e07cd5c7b45a 100644 --- a/web/app/components/tools/tool-list/item.tsx +++ b/web/app/components/tools/tool-list/item.tsx @@ -35,6 +35,7 @@ const Item: FC = ({ const language = getLanguage(locale) const isBuiltIn = collection.type === CollectionType.builtIn + const isModel = collection.type === CollectionType.model const canShowDetail = isInToolsPage const [showDetail, setShowDetail] = useState(false) const addBtn = @@ -73,6 +74,7 @@ const Item: FC = ({ setShowDetail(false) }} isBuiltIn={isBuiltIn} + isModel={isModel} /> )} diff --git a/web/app/components/tools/tool-nav-list/index.tsx b/web/app/components/tools/tool-nav-list/index.tsx index 1fab9de7a36332..3a8fd4088bc6d5 100644 --- a/web/app/components/tools/tool-nav-list/index.tsx +++ b/web/app/components/tools/tool-nav-list/index.tsx @@ -6,21 +6,21 @@ import Item from './item' import type { Collection } from '@/app/components/tools/types' type Props = { className?: string - currentName: string + currentIndex: number list: Collection[] onChosen: (index: number) => void } const ToolNavList: FC = ({ className, - currentName, + currentIndex, list, onChosen, }) => { return (
{list.map((item, index) => ( - onChosen(index)}> + onChosen(index)}> ))}
) diff --git a/web/app/components/tools/types.ts b/web/app/components/tools/types.ts index 389276e81ce12c..6de8d8aa76e712 100644 --- a/web/app/components/tools/types.ts +++ b/web/app/components/tools/types.ts @@ -26,6 +26,7 @@ export enum CollectionType { all = 'all', builtIn = 'builtin', custom = 'api', + model = 'model', } export type Emoji = { diff --git a/web/service/tools.ts b/web/service/tools.ts index 008de4a55738b9..ac59e2e50871d0 100644 --- a/web/service/tools.ts +++ b/web/service/tools.ts @@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => { export const fetchCustomToolList = (collectionName: string) => { return get(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`) } + +export const fetchModelToolList = (collectionName: string) => { + return get(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`) +} + export const fetchBuiltInToolCredentialSchema = (collectionName: string) => { return get(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`) }