From 76f676ba5a8b4de00e78f9736465f147e1aadaf0 Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 17 Oct 2023 19:54:59 +0800
Subject: [PATCH] feat: optimize completion model agent (#1364)

---
 .../agent/agent/multi_dataset_router_agent.py |  2 +-
 .../structed_multi_dataset_router_agent.py    | 90 ++++++++++++++++---
 api/core/agent/agent/structured_chat.py       | 86 +++++++++++++++---
 3 files changed, 156 insertions(+), 22 deletions(-)

diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py
index 5636d91b51fc2b..16b4a2ab248786 100644
--- a/api/core/agent/agent/multi_dataset_router_agent.py
+++ b/api/core/agent/agent/multi_dataset_router_agent.py
@@ -76,7 +76,7 @@ def plan(
             agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
             if isinstance(agent_decision, AgentAction):
                 tool_inputs = agent_decision.tool_input
-                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
+                if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
                     tool_inputs['query'] = kwargs['input']
                     agent_decision.tool_input = tool_inputs
             else:
diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py
index 0ba774a88fc34c..84c0553625b8ff 100644
--- a/api/core/agent/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py
@@ -1,7 +1,7 @@
 import re
 from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
-from langchain import BasePromptTemplate
+from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
 from langchain.callbacks.base import BaseCallbackManager
@@ -12,6 +12,7 @@
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
@@ -92,6 +93,10 @@ def plan(
             rst = tool.run(tool_input={'query': kwargs['input']})
             return AgentFinish(return_values={"output": rst}, log=rst)
 
+        if intermediate_steps:
+            _, observation = intermediate_steps[-1]
+            return AgentFinish(return_values={"output": observation}, log=observation)
+
         full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
 
         try:
@@ -107,6 +112,8 @@ def plan(
                 if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                     tool_inputs['query'] = kwargs['input']
                     agent_decision.tool_input = tool_inputs
+                elif isinstance(tool_inputs, str):
+                    agent_decision.tool_input = kwargs['input']
             else:
                 agent_decision.return_values['output'] = ''
             return agent_decision
@@ -143,6 +150,61 @@ def create_prompt(
         ]
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)
 
+    @classmethod
+    def create_completion_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+    ) -> PromptTemplate:
+        """Create prompt in the style of the zero shot agent.
+
+        Args:
+            tools: List of tools the agent will have access to, used to format the
+                prompt.
+            prefix: String to put before the list of tools.
+            input_variables: List of input variables the final prompt will expect.
+
+        Returns:
+            A PromptTemplate with the template assembled from the pieces here.
+        """
+        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Question: {input}
+Thought: {agent_scratchpad}
+"""
+
+        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+        tool_names = ", ".join([tool.name for tool in tools])
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        return PromptTemplate(template=template, input_variables=input_variables)
+
+    def _construct_scratchpad(
+            self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> str:
+        agent_scratchpad = ""
+        for action, observation in intermediate_steps:
+            agent_scratchpad += action.log
+            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+
+        if not isinstance(agent_scratchpad, str):
+            raise ValueError("agent_scratchpad should be of type string.")
+        if agent_scratchpad:
+            llm_chain = cast(LLMChain, self.llm_chain)
+            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+                return (
+                    f"This was your previous work "
+                    f"(but I haven't seen any of it! I only see what "
+                    f"you return as final answer):\n{agent_scratchpad}"
+                )
+            else:
+                return agent_scratchpad
+        else:
+            return agent_scratchpad
+
     @classmethod
     def from_llm_and_tools(
             cls,
@@ -160,15 +222,23 @@ def from_llm_and_tools(
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prefix=prefix,
-            suffix=suffix,
-            human_message_template=human_message_template,
-            format_instructions=format_instructions,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
+        if model_instance.model_mode == ModelMode.CHAT:
+            prompt = cls.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                human_message_template=human_message_template,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+                memory_prompts=memory_prompts,
+            )
+        else:
+            prompt = cls.create_completion_prompt(
+                tools,
+                prefix=prefix,
+                format_instructions=format_instructions,
+                input_variables=input_variables
+            )
         llm_chain = LLMChain(
             model_instance=model_instance,
             prompt=prompt,
diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py
index 2c36bc38d37300..aca6de79b08106 100644
--- a/api/core/agent/agent/structured_chat.py
+++ b/api/core/agent/agent/structured_chat.py
@@ -1,7 +1,7 @@
 import re
-from typing import List, Tuple, Any, Union, Sequence, Optional
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
 
-from langchain import BasePromptTemplate
+from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
 from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
 from langchain.callbacks.base import BaseCallbackManager
@@ -15,6 +15,7 @@
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
 from core.chain.llm_chain import LLMChain
+from core.model_providers.models.entity.model_params import ModelMode
 from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -184,6 +185,61 @@ def create_prompt(
         ]
         return ChatPromptTemplate(input_variables=input_variables, messages=messages)
 
+    @classmethod
+    def create_completion_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+    ) -> PromptTemplate:
+        """Create prompt in the style of the zero shot agent.
+
+        Args:
+            tools: List of tools the agent will have access to, used to format the
+                prompt.
+            prefix: String to put before the list of tools.
+            input_variables: List of input variables the final prompt will expect.
+
+        Returns:
+            A PromptTemplate with the template assembled from the pieces here.
+        """
+        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
+Question: {input}
+Thought: {agent_scratchpad}
+"""
+
+        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
+        tool_names = ", ".join([tool.name for tool in tools])
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        return PromptTemplate(template=template, input_variables=input_variables)
+
+    def _construct_scratchpad(
+            self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> str:
+        agent_scratchpad = ""
+        for action, observation in intermediate_steps:
+            agent_scratchpad += action.log
+            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+
+        if not isinstance(agent_scratchpad, str):
+            raise ValueError("agent_scratchpad should be of type string.")
+        if agent_scratchpad:
+            llm_chain = cast(LLMChain, self.llm_chain)
+            if llm_chain.model_instance.model_mode == ModelMode.CHAT:
+                return (
+                    f"This was your previous work "
+                    f"(but I haven't seen any of it! I only see what "
+                    f"you return as final answer):\n{agent_scratchpad}"
+                )
+            else:
+                return agent_scratchpad
+        else:
+            return agent_scratchpad
+
     @classmethod
     def from_llm_and_tools(
             cls,
@@ -201,15 +257,23 @@ def from_llm_and_tools(
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
         cls._validate_tools(tools)
-        prompt = cls.create_prompt(
-            tools,
-            prefix=prefix,
-            suffix=suffix,
-            human_message_template=human_message_template,
-            format_instructions=format_instructions,
-            input_variables=input_variables,
-            memory_prompts=memory_prompts,
-        )
+        if model_instance.model_mode == ModelMode.CHAT:
+            prompt = cls.create_prompt(
+                tools,
+                prefix=prefix,
+                suffix=suffix,
+                human_message_template=human_message_template,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+                memory_prompts=memory_prompts,
+            )
+        else:
+            prompt = cls.create_completion_prompt(
+                tools,
+                prefix=prefix,
+                format_instructions=format_instructions,
+                input_variables=input_variables,
+            )
         llm_chain = LLMChain(
             model_instance=model_instance,
             prompt=prompt,