diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 36d04ada39a4fe..f359cf82fd2629 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -37,12 +37,13 @@ def __init__(self, tenant_id: str, app_model_config: AppModelConfig): def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, - return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]: + retriever_from: str = 'dev') -> Optional[AgentExecutor]: if not self.app_model_config.agent_mode_dict: return None agent_mode_config = self.app_model_config.agent_mode_dict model_dict = self.app_model_config.model_dict + return_resource = self.app_model_config.retriever_resource_dict.get('enabled', False) chain = None if agent_mode_config and agent_mode_config.get('enabled'): diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index d90636e71744bd..33fec157eaa36b 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -30,7 +30,7 @@ class DatasetRetrieverTool(BaseTool): dataset_id: str k: int = 3 conversation_message_task: ConversationMessageTask - return_resource: str + return_resource: bool retriever_from: str @classmethod diff --git a/api/models/model.py b/api/models/model.py index 9a0b8c6554cd10..f372f516da1e8f 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -196,7 +196,8 @@ def copy(self): user_input_form=self.user_input_form, dataset_query_variable=self.dataset_query_variable, pre_prompt=self.pre_prompt, - agent_mode=self.agent_mode + agent_mode=self.agent_mode, + retriever_resource=self.retriever_resource ) return new_app_model_config