diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 21ce9cb6afa7f9..4b648a4e28755f 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -27,7 +27,9 @@ from libs.login import login_required from models.model import App, AppModelConfig, Site from services.app_model_config_service import AppModelConfigService - +from core.tools.utils.configuration import ToolParameterConfigurationManager +from core.tools.tool_manager import ToolManager +from core.entities.application_entities import AgentToolEntity def _get_app(app_id, tenant_id): app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() @@ -236,7 +238,39 @@ class AppApi(Resource): def get(self, app_id): """Get app detail""" app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) + app: App = _get_app(app_id, current_user.current_tenant_id) + + # get original app model config + model_config: AppModelConfig = app.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + # get tool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} + + # override tool parameters + tool['tool_parameters'] = masked_parameter + + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) return app diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f67fff4b0627bc..117007d055aa6b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,3 +1,4 @@ +import json from flask import request from flask_login import current_user @@ -7,6 +8,9 @@ from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AgentToolEntity +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required @@ -38,6 +42,82 @@ def post(self, app_id): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) + # get original app model config + original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( + AppModelConfig.id == app.app_model_config_id + ).first() + agent_mode = original_app_model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + parameter_map = {} + masked_parameter_map = {} + tool_map = {} + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + # get tool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + parameters = {} + masked_parameter = {} + + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + masked_parameter_map[key] = masked_parameter + parameter_map[key] = parameters + tool_map[key] = tool_runtime + + # encrypt agent tool parameters if it's secret-input + agent_mode = new_app_model_config.agent_mode_dict + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + + # get tool + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + if key in tool_map: + tool_runtime = tool_map[key] + else: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() + + # override parameters if it equals to masked parameters + if agent_tool_entity.tool_parameters: + if key not in masked_parameter_map: + continue + + if agent_tool_entity.tool_parameters == masked_parameter_map[key]: + agent_tool_entity.tool_parameters = parameter_map[key] + + # encrypt parameters + if agent_tool_entity.tool_parameters: + tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + + # update app model config + new_app_model_config.agent_mode = json.dumps(agent_mode) + db.session.add(new_app_model_config) db.session.flush() diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 2a4ae7e1356343..0ee6436d1195ea 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -154,9 +154,9 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P """ convert tool to prompt message tool """ - tool_entity = ToolManager.get_tool_runtime( - provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, - tenant_id=self.application_generate_entity.tenant_id, + tool_entity = ToolManager.get_agent_tool_runtime( + tenant_id=self.tenant_id, + agent_tool=tool, agent_callback=self.agent_callback ) tool_entity.load_variables(self.variables_pool) @@ -171,33 +171,11 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P } ) - runtime_parameters = {} - - parameters = tool_entity.parameters or [] - user_parameters = tool_entity.get_runtime_parameters() or [] - - # override parameters - for parameter in user_parameters: - # check if parameter in tool parameters - found = False - for tool_parameter in parameters: - if tool_parameter.name == parameter.name: - found = True - break - - if found: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description - else: - # add new parameter - parameters.append(parameter) - + parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + parameter_type = 'string' enum = [] if parameter.type == ToolParameter.ToolParameterType.STRING: @@ -213,59 +191,16 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P else: raise ValueError(f"parameter type {parameter.type} is not supported") - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # get tool parameter from form - tool_parameter_config = tool.tool_parameters.get(parameter.name) - if not tool_parameter_config: - # get default value - tool_parameter_config = parameter.default - if not tool_parameter_config and parameter.required: - raise ValueError(f"tool parameter {parameter.name} not found in tool config") - - if parameter.type == ToolParameter.ToolParameterType.SELECT: - # check if tool_parameter_config in options - options = list(map(lambda x: x.value, parameter.options)) - if tool_parameter_config not in options: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") - - # convert tool parameter config to correct type - try: - if parameter.type == ToolParameter.ToolParameterType.NUMBER: - # check if tool parameter is integer - if isinstance(tool_parameter_config, int): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, float): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, str): - if '.' in tool_parameter_config: - tool_parameter_config = float(tool_parameter_config) - else: - tool_parameter_config = int(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - tool_parameter_config = bool(tool_parameter_config) - elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: - tool_parameter_config = str(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType: - tool_parameter_config = str(tool_parameter_config) - except Exception as e: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") - - # save tool parameter to tool entity memory - runtime_parameters[parameter.name] = tool_parameter_config - - elif parameter.form == ToolParameter.ToolParameterForm.LLM: - message_tool.parameters['properties'][parameter.name] = { - "type": parameter_type, - "description": parameter.llm_description or '', - } - - if len(enum) > 0: - message_tool.parameters['properties'][parameter.name]['enum'] = enum + message_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } - if parameter.required: - message_tool.parameters['required'].append(parameter.name) + if len(enum) > 0: + message_tool.parameters['properties'][parameter.name]['enum'] = enum - tool_entity.runtime.runtime_parameters.update(runtime_parameters) + if parameter.required: + message_tool.parameters['required'].append(parameter.name) return message_tool, tool_entity @@ -305,6 +240,9 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) tool_runtime_parameters = tool.get_runtime_parameters() or [] for parameter in tool_runtime_parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + parameter_type = 'string' enum = [] if parameter.type == ToolParameter.ToolParameterType.STRING: @@ -320,18 +258,17 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) else: raise ValueError(f"parameter type {parameter.type} is not supported") - if parameter.form == ToolParameter.ToolParameterForm.LLM: - prompt_tool.parameters['properties'][parameter.name] = { - "type": parameter_type, - "description": parameter.llm_description or '', - } - - if len(enum) > 0: - prompt_tool.parameters['properties'][parameter.name]['enum'] = enum - - if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + prompt_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } + + if len(enum) > 0: + prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + + if parameter.required: + if parameter.name not in prompt_tool.parameters['required']: + prompt_tool.parameters['required'].append(parameter.name) return prompt_tool diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py new file mode 100644 index 00000000000000..db05eb18750636 --- /dev/null +++ b/api/core/helper/tool_parameter_cache.py @@ -0,0 +1,54 @@ +import json +from enum import Enum +from json import JSONDecodeError +from typing import Optional + +from extensions.ext_redis import redis_client + + +class ToolParameterCacheType(Enum): + PARAMETER = "tool_parameter" + +class ToolParameterCache: + def __init__(self, + tenant_id: str, + provider: str, + tool_name: str, + cache_type: ToolParameterCacheType + ): + self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + + def get(self) -> Optional[dict]: + """ + Get cached model provider credentials. + + :return: + """ + cached_tool_parameter = redis_client.get(self.cache_key) + if cached_tool_parameter: + try: + cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = json.loads(cached_tool_parameter) + except JSONDecodeError: + return None + + return cached_tool_parameter + else: + return None + + def set(self, parameters: dict) -> None: + """ + Cache model provider credentials. + + :param credentials: provider credentials + :return: + """ + redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) + + def delete(self) -> None: + """ + Delete cached model provider credentials. + + :return: + """ + redis_client.delete(self.cache_key) \ No newline at end of file diff --git a/api/core/tools/docs/en_US/tool_scale_out.md b/api/core/tools/docs/en_US/tool_scale_out.md index 589a3c881091ec..e0269e02095878 100644 --- a/api/core/tools/docs/en_US/tool_scale_out.md +++ b/api/core/tools/docs/en_US/tool_scale_out.md @@ -119,7 +119,7 @@ parameters: # Parameter list - The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc. - `parameters` Parameter list - `name` Parameter name, unique, no duplication with other parameters - - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box + - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type - `required` Required or not - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts diff --git a/api/core/tools/docs/zh_Hans/tool_scale_out.md b/api/core/tools/docs/zh_Hans/tool_scale_out.md index be146a5aebd156..20bb5e6dbc9348 100644 --- a/api/core/tools/docs/zh_Hans/tool_scale_out.md +++ b/api/core/tools/docs/zh_Hans/tool_scale_out.md @@ -119,7 +119,7 @@ parameters: # 参数列表 - `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等 - `parameters` 参数列表 - `name` 参数名称,唯一,不允许和其他参数重名 - - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框 + - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型 - `required` 是否必填 - 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数 - 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数 diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 61b41f9cf4b211..f7a61b0b0cca6a 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -100,6 +100,7 @@ class ToolParameterType(Enum): NUMBER = "number" BOOLEAN = "boolean" SELECT = "select" + SECRET_INPUT = "secret-input" class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index e61cf6b1e2df60..615033f5d96f4b 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -23,6 +23,8 @@ class AIPPTGenerateTool(BuiltinTool): _api_base_url = URL('https://co.aippt.cn/api') _api_token_cache = {} _api_token_cache_lock = Lock() + _style_cache = {} + _style_cache_lock = Lock() _task = {} _task_type_map = { @@ -390,20 +392,31 @@ def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> st ).digest() ).decode('utf-8') - def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: + @classmethod + def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: """ Get styles - - :param credentials: the credentials - :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ + + # check cache + with cls._style_cache_lock: + # clear expired styles + now = time() + for key in list(cls._style_cache.keys()): + if cls._style_cache[key]['expire'] < now: + del cls._style_cache[key] + + key = f'{credentials["aippt_access_key"]}#@#{user_id}' + if key in cls._style_cache: + return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + headers = { 'x-channel': '', - 'x-api-key': self.runtime.credentials['aippt_access_key'], - 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id) + 'x-api-key': credentials['aippt_access_key'], + 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) } response = get( - str(self._api_base_url / 'template_component' / 'suit' / 'select'), + str(cls._api_base_url / 'template_component' / 'suit' / 'select'), headers=headers ) @@ -425,7 +438,26 @@ def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: 'name': item.get('title'), } for item in response.get('data', {}).get('suit_style') or []] + with cls._style_cache_lock: + cls._style_cache[key] = { + 'colors': colors, + 'styles': styles, + 'expire': now + 60 * 60 + } + return colors, styles + + def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: + """ + Get styles + + :param credentials: the credentials + :return: Tuple[list[dict[id, color]], list[dict[id, style]] + """ + if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'): + return [], [] + + return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) def _get_suit(self, style_id: int, colour_id: int) -> int: """ diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml index 52f3cf573102b9..ece1bbc9272ce7 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml @@ -14,7 +14,7 @@ description: llm: A tool for sending messages to a chat group on Wecom(企业微信) . parameters: - name: hook_key - type: string + type: secret-input required: true label: en_US: Wecom Group bot webhook key diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 192793897e7aea..351ae4362ef819 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -266,6 +266,40 @@ def get_runtime_parameters(self) -> list[ToolParameter]: """ return self.parameters + def get_all_runtime_parameters(self) -> list[ToolParameter]: + """ + get all runtime parameters + + :return: all runtime parameters + """ + parameters = self.parameters or [] + parameters = parameters.copy() + user_parameters = self.get_runtime_parameters() or [] + user_parameters = user_parameters.copy() + + # override parameters + for parameter in user_parameters: + # check if parameter in tool parameters + found = False + for tool_parameter in parameters: + if tool_parameter.name == parameter.name: + found = True + break + + if found: + # override parameter + tool_parameter.type = parameter.type + tool_parameter.form = parameter.form + tool_parameter.required = parameter.required + tool_parameter.default = parameter.default + tool_parameter.options = parameter.options + tool_parameter.llm_description = parameter.llm_description + else: + # add new parameter + parameters.append(parameter) + + return parameters + def is_tool_available(self) -> bool: """ check if the tool is available diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index acfea4cd3fc82b..2ac8f27bab7421 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,11 +6,17 @@ from typing import Any, Union from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.entities.application_entities import AgentToolEntity from core.model_runtime.entities.message_entities import PromptMessage from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.constant import DEFAULT_PROVIDERS -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeMessage, + ToolParameter, + ToolProviderCredentials, +) from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiBasedToolProviderController @@ -21,7 +27,12 @@ from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration +from core.tools.tool.tool import Tool +from core.tools.utils.configuration import ( + ModelToolConfigurationManager, + ToolConfigurationManager, + ToolParameterConfigurationManager, +) from core.tools.utils.encoder import serialize_base_model_dict from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -172,7 +183,7 @@ def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, ten # decrypt the credentials credentials = builtin_provider.credentials controller = ToolManager.get_builtin_provider(provider_name) - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) @@ -189,7 +200,7 @@ def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, ten api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) # decrypt the credentials - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ @@ -214,6 +225,71 @@ def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, ten else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + @staticmethod + def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: + """ + get the agent tool runtime + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, + tenant_id=tenant_id, + agent_callback=agent_callback + ) + runtime_parameters = {} + parameters = tool_entity.get_all_runtime_parameters() + for parameter in parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM: + # get tool parameter from form + tool_parameter_config = agent_tool.tool_parameters.get(parameter.name) + if not tool_parameter_config: + # get default value + tool_parameter_config = parameter.default + if not tool_parameter_config and parameter.required: + raise ValueError(f"tool parameter {parameter.name} not found in tool config") + + if parameter.type == ToolParameter.ToolParameterType.SELECT: + # check if tool_parameter_config in options + options = list(map(lambda x: x.value, parameter.options)) + if tool_parameter_config not in options: + raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") + + # convert tool parameter config to correct type + try: + if parameter.type == ToolParameter.ToolParameterType.NUMBER: + # check if tool parameter is integer + if isinstance(tool_parameter_config, int): + tool_parameter_config = tool_parameter_config + elif isinstance(tool_parameter_config, float): + tool_parameter_config = tool_parameter_config + elif isinstance(tool_parameter_config, str): + if '.' in tool_parameter_config: + tool_parameter_config = float(tool_parameter_config) + else: + tool_parameter_config = int(tool_parameter_config) + elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: + tool_parameter_config = bool(tool_parameter_config) + elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: + tool_parameter_config = str(tool_parameter_config) + elif parameter.type == ToolParameter.ToolParameterType: + tool_parameter_config = str(tool_parameter_config) + except Exception as e: + raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") + + # save tool parameter to tool entity memory + runtime_parameters[parameter.name] = tool_parameter_config + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_entity, + provider_name=agent_tool.provider_id, + provider_type=agent_tool.provider_type, + ) + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity + @staticmethod def get_builtin_provider_icon(provider: str) -> tuple[str, str]: """ @@ -396,7 +472,7 @@ def user_list_providers( controller = ToolManager.get_builtin_provider(provider_name) # init tool configuration - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) @@ -463,7 +539,7 @@ def user_list_providers( ) # init tool configuration - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -523,7 +599,7 @@ def user_get_api_provider(provider: str, tenant_id: str) -> dict: provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) # init tool configuration - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 8f795fd245f0c7..927af1f5be5e86 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -5,16 +5,19 @@ from yaml import FullLoader, load from core.helper import encrypter +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.entities.tool_entities import ( ModelToolConfiguration, ModelToolProviderConfiguration, + ToolParameter, ToolProviderCredentials, ) from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.tool import Tool -class ToolConfiguration(BaseModel): +class ToolConfigurationManager(BaseModel): tenant_id: str provider_controller: ToolProviderController @@ -101,6 +104,128 @@ def delete_tool_credentials_cache(self): ) cache.delete() +class ToolParameterConfigurationManager(BaseModel): + """ + Tool parameter configuration manager + """ + tenant_id: str + tool_runtime: Tool + provider_name: str + provider_type: str + + def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + deep copy parameters + """ + return {key: value for key, value in parameters.items()} + + def _merge_parameters(self) -> list[ToolParameter]: + """ + merge parameters + """ + # get tool parameters + tool_parameters = self.tool_runtime.parameters or [] + # get tool runtime parameters + runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] + # override parameters + current_parameters = tool_parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return current_parameters + + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + mask tool parameters + + return a deep copy of parameters with masked values + """ + parameters = self._deep_copy(parameters) + + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if parameter.name in parameters: + if len(parameters[parameter.name]) > 6: + parameters[parameter.name] = \ + parameters[parameter.name][:2] + \ + '*' * (len(parameters[parameter.name]) - 4) +\ + parameters[parameter.name][-2:] + else: + parameters[parameter.name] = '*' * len(parameters[parameter.name]) + + return parameters + + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + encrypt tool parameters with tenant id + + return a deep copy of parameters with encrypted values + """ + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if parameter.name in parameters: + encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) + parameters[parameter.name] = encrypted + + return parameters + + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + decrypt tool parameters with tenant id + + return a deep copy of parameters with decrypted values + """ + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f'{self.provider_type}.{self.provider_name}', + tool_name=self.tool_runtime.identity.name, + cache_type=ToolParameterCacheType.PARAMETER + ) + cached_parameters = cache.get() + if cached_parameters: + return cached_parameters + + # override parameters + current_parameters = self._merge_parameters() + has_secret_input = False + + for parameter in current_parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if parameter.name in parameters: + try: + has_secret_input = True + parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) + except: + pass + + if has_secret_input: + cache.set(parameters) + + return parameters + + def delete_tool_parameters_cache(self): + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f'{self.provider_type}.{self.provider_name}', + tool_name=self.tool_runtime.identity.name, + cache_type=ToolParameterCacheType.PARAMETER + ) + cache.delete() + class ModelToolConfigurationManager: """ Model as tool configuration diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index 30b8047435373a..ff618e5d2bde4e 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -17,7 +17,7 @@ from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ToolConfiguration +from core.tools.utils.configuration import ToolConfigurationManager from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db @@ -77,7 +77,7 @@ def list_builtin_tool_provider_tools( provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # check if user has added the provider builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -279,7 +279,7 @@ def create_api_tool_provider( provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) db_provider.credentials_str = json.dumps(encrypted_credentials) @@ -366,7 +366,7 @@ def update_builtin_tool_provider( provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: raise ValueError(f'provider {provider_name} does not need credentials') - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -450,7 +450,7 @@ def update_api_tool_provider( provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) @@ -490,7 +490,7 @@ def delete_builtin_tool_provider( # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name) - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() return { 'result': 'success' } @@ -632,7 +632,7 @@ def test_api_tool_preview( # decrypt credentials if db_provider.id: - tool_configuration = ToolConfiguration( + tool_configuration = ToolConfigurationManager( tenant_id=tenant_id, provider_controller=provider_controller )