Skip to content

Commit

Permalink
Feat/tool secret parameter (#2760)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly authored Mar 8, 2024
1 parent bbc0d33 commit ce58f06
Show file tree
Hide file tree
Showing 13 changed files with 492 additions and 119 deletions.
38 changes: 36 additions & 2 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService

from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity

def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
Expand Down Expand Up @@ -236,7 +238,39 @@ class AppApi(Resource):
def get(self, app_id):
"""Get app detail"""
app_id = str(app_id)
app = _get_app(app_id, current_user.current_tenant_id)
app: App = _get_app(app_id, current_user.current_tenant_id)

# get original app model config
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)

# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}

# override tool parameters
tool['tool_parameters'] = masked_parameter

# override agent mode
model_config.agent_mode = json.dumps(agent_mode)

return app

Expand Down
80 changes: 80 additions & 0 deletions api/controllers/console/app/model_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json

from flask import request
from flask_login import current_user
Expand All @@ -7,6 +8,9 @@
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
Expand Down Expand Up @@ -38,6 +42,82 @@ def post(self, app_id):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)

# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)

# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}

key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime

# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)

# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)

manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()

# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue

if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]

# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})

# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)

db.session.add(new_app_model_config)
db.session.flush()

Expand Down
121 changes: 29 additions & 92 deletions api/core/features/assistant_base_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.application_generate_entity.tenant_id,
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
Expand All @@ -171,33 +171,11 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P
}
)

runtime_parameters = {}

parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []

# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break

if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)

parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue

parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
Expand All @@ -213,59 +191,16 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P
else:
raise ValueError(f"parameter type {parameter.type} is not supported")

if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")

if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")

# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")

# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config

elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}

if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}

if parameter.required:
message_tool.parameters['required'].append(parameter.name)
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum

tool_entity.runtime.runtime_parameters.update(runtime_parameters)
if parameter.required:
message_tool.parameters['required'].append(parameter.name)

return message_tool, tool_entity

Expand Down Expand Up @@ -305,6 +240,9 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool)
tool_runtime_parameters = tool.get_runtime_parameters() or []

for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue

parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
Expand All @@ -320,18 +258,17 @@ def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool)
else:
raise ValueError(f"parameter type {parameter.type} is not supported")

if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}

if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}

if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum

if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)

return prompt_tool

Expand Down
54 changes: 54 additions & 0 deletions api/core/helper/tool_parameter_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional

from extensions.ext_redis import redis_client


class ToolParameterCacheType(Enum):
PARAMETER = "tool_parameter"

class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"

def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None

return cached_tool_parameter
else:
return None

def set(self, parameters: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))

def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)
2 changes: 1 addition & 1 deletion api/core/tools/docs/en_US/tool_scale_out.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ parameters: # Parameter list
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
- `name` Parameter name, unique, no duplication with other parameters
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
- `required` Required or not
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
Expand Down
2 changes: 1 addition & 1 deletion api/core/tools/docs/zh_Hans/tool_scale_out.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ parameters: # 参数列表
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表
- `name` 参数名称,唯一,不允许和其他参数重名
- `type` 参数类型,目前支持`string``number``boolean``select` 四种类型,分别对应字符串、数字、布尔值、下拉框
- `type` 参数类型,目前支持`string``number``boolean``select``secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
- `required` 是否必填
-`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数
-`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
Expand Down
1 change: 1 addition & 0 deletions api/core/tools/entities/tool_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class ToolParameterType(Enum):
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"

class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool
Expand Down
Loading

0 comments on commit ce58f06

Please sign in to comment.