diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index c7767d13074594..57f054f7f40d77 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -2,14 +2,18 @@ from typing import Tuple, List, Any, Union, Sequence, Optional, cast from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent +from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, SystemMessage +from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool +from pydantic import root_validator +from core.model_providers.models.entity.message import to_prompt_messages from core.model_providers.models.llm.base import BaseLLM +from core.third_party.langchain.llms.fake import FakeLLM from core.tool.dataset_retriever_tool import DatasetRetrieverTool @@ -24,6 +28,10 @@ class Config: arbitrary_types_allowed = True + @root_validator + def validate_llm(cls, values: dict) -> dict: + return values + def should_use_agent(self, query: str): """ return should use agent @@ -65,7 +73,7 @@ def plan( return AgentFinish(return_values={"output": observation}, log=observation) try: - agent_decision = super().plan(intermediate_steps, callbacks, **kwargs) + agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs) if isinstance(agent_decision, AgentAction): tool_inputs = agent_decision.tool_input if isinstance(tool_inputs, dict) and 'query' in tool_inputs: @@ -76,6 +84,44 @@ def plan( new_exception = self.model_instance.handle_exceptions(e) raise new_exception + def real_plan( + self, + intermediate_steps: List[Tuple[AgentAction, str]], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Union[AgentAction, AgentFinish]: + """Given input, decided what to do. + + Args: + intermediate_steps: Steps the LLM has taken to date, along with observations + **kwargs: User inputs. + + Returns: + Action specifying what tool to use. 
+ """ + agent_scratchpad = _format_intermediate_steps(intermediate_steps) + selected_inputs = { + k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" + } + full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) + prompt = self.prompt.format_prompt(**full_inputs) + messages = prompt.to_messages() + prompt_messages = to_prompt_messages(messages) + result = self.model_instance.run( + messages=prompt_messages, + functions=self.functions, + ) + + ai_message = AIMessage( + content=result.content, + additional_kwargs={ + 'function_call': result.function_call + } + ) + + agent_decision = _parse_ai_message(ai_message) + return agent_decision + async def aplan( self, intermediate_steps: List[Tuple[AgentAction, str]], @@ -87,7 +133,7 @@ async def aplan( @classmethod def from_llm_and_tools( cls, - llm: BaseLanguageModel, + model_instance: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, @@ -96,11 +142,15 @@ def from_llm_and_tools( ), **kwargs: Any, ) -> BaseSingleActionAgent: - return super().from_llm_and_tools( - llm=llm, - tools=tools, - callback_manager=callback_manager, + prompt = cls.create_prompt( extra_prompt_messages=extra_prompt_messages, system_message=system_message, + ) + return cls( + model_instance=model_instance, + llm=FakeLLM(response=''), + prompt=prompt, + tools=tools, + callback_manager=callback_manager, **kwargs, ) diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index addf0831ce02c7..8931bdc216a6a0 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -5,21 +5,40 @@ _format_intermediate_steps from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks +from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken +from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, SystemMessage -from langchain.schema.language_model import BaseLanguageModel +from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \ + get_buffer_string from langchain.tools import BaseTool +from pydantic import root_validator -from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError -from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin +from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin +from core.chain.llm_chain import LLMChain +from core.model_providers.models.entity.message import to_prompt_messages +from core.model_providers.models.llm.base import BaseLLM +from core.third_party.langchain.llms.fake import FakeLLM -class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin): +class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): + moving_summary_buffer: str = "" + moving_summary_index: int = 0 + summary_model_instance: BaseLLM = None + model_instance: BaseLLM + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @root_validator + def validate_llm(cls, values: dict) -> dict: + return values @classmethod def from_llm_and_tools( cls, - llm: BaseLanguageModel, + model_instance: BaseLLM, tools: 
Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, @@ -28,12 +47,16 @@ def from_llm_and_tools( ), **kwargs: Any, ) -> BaseSingleActionAgent: - return super().from_llm_and_tools( - llm=llm, + prompt = cls.create_prompt( + extra_prompt_messages=extra_prompt_messages, + system_message=system_message, + ) + return cls( + model_instance=model_instance, + llm=FakeLLM(response=''), + prompt=prompt, tools=tools, callback_manager=callback_manager, - extra_prompt_messages=extra_prompt_messages, - system_message=cls.get_system_message(), **kwargs, ) @@ -44,23 +67,26 @@ def should_use_agent(self, query: str): :param query: :return: """ - original_max_tokens = self.llm.max_tokens - self.llm.max_tokens = 40 + original_max_tokens = self.model_instance.model_kwargs.max_tokens + self.model_instance.model_kwargs.max_tokens = 40 prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) messages = prompt.to_messages() try: - predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=None + prompt_messages = to_prompt_messages(messages) + result = self.model_instance.run( + messages=prompt_messages, + functions=self.functions, + callbacks=None ) except Exception as e: new_exception = self.model_instance.handle_exceptions(e) raise new_exception - function_call = predicted_message.additional_kwargs.get("function_call", {}) + function_call = result.function_call - self.llm.max_tokens = original_max_tokens + self.model_instance.model_kwargs.max_tokens = original_max_tokens return True if function_call else False @@ -93,10 +119,19 @@ def plan( except ExceededLLMTokensLimitError as e: return AgentFinish(return_values={"output": str(e)}, log=str(e)) - predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=callbacks + prompt_messages = to_prompt_messages(messages) + result = self.model_instance.run( + messages=prompt_messages, + functions=self.functions, + ) + + ai_message = AIMessage( + content=result.content, + additional_kwargs={ + 'function_call': result.function_call + } ) - agent_decision = _parse_ai_message(predicted_message) + agent_decision = _parse_ai_message(ai_message) if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': tool_inputs = agent_decision.tool_input @@ -122,3 +157,142 @@ def return_stopped_response( return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs) except ValueError: return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") + + def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: + # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 + rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) + rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens + if rest_tokens >= 0: + return messages + + system_message = None + human_message = None + should_summary_messages = [] + for message in messages: + if isinstance(message, SystemMessage): + system_message = message + elif isinstance(message, HumanMessage): + human_message = message + else: + should_summary_messages.append(message) + + if len(should_summary_messages) > 2: + ai_message = should_summary_messages[-2] + function_message = should_summary_messages[-1] + should_summary_messages = 
should_summary_messages[self.moving_summary_index:-2] + self.moving_summary_index = len(should_summary_messages) + else: + error_msg = "Exceeded LLM tokens limit, stopped." + raise ExceededLLMTokensLimitError(error_msg) + + new_messages = [system_message, human_message] + + if self.moving_summary_index == 0: + should_summary_messages.insert(0, human_message) + + self.moving_summary_buffer = self.predict_new_summary( + messages=should_summary_messages, + existing_summary=self.moving_summary_buffer + ) + + new_messages.append(AIMessage(content=self.moving_summary_buffer)) + new_messages.append(ai_message) + new_messages.append(function_message) + + return new_messages + + def predict_new_summary( + self, messages: List[BaseMessage], existing_summary: str + ) -> str: + new_lines = get_buffer_string( + messages, + human_prefix="Human", + ai_prefix="AI", + ) + + chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT) + return chain.predict(summary=existing_summary, new_lines=new_lines) + + def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: + """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. + + Official documentation: https://github.com/openai/openai-cookbook/blob/ + main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" + if model_instance.model_provider.provider_name == 'azure_openai': + model = model_instance.base_model_name + model = model.replace("gpt-35", "gpt-3.5") + else: + model = model_instance.base_model_name + + tiktoken_ = _import_tiktoken() + try: + encoding = tiktoken_.encoding_for_model(model) + except KeyError: + model = "cl100k_base" + encoding = tiktoken_.get_encoding(model) + + if model.startswith("gpt-3.5-turbo"): + # every message follows {role/name}\n{content}\n + tokens_per_message = 4 + # if there's a name, the role is omitted + tokens_per_name = -1 + elif model.startswith("gpt-4"): + tokens_per_message = 3 + tokens_per_name = 1 + else: + raise NotImplementedError( + f"get_num_tokens_from_messages() is not presently implemented " + f"for model {model}." + "See https://github.com/openai/openai-python/blob/main/chatml.md for " + "information on how messages are converted to tokens." 
+ ) + num_tokens = 0 + for m in messages: + message = _convert_message_to_dict(m) + num_tokens += tokens_per_message + for key, value in message.items(): + if key == "function_call": + for f_key, f_value in value.items(): + num_tokens += len(encoding.encode(f_key)) + num_tokens += len(encoding.encode(f_value)) + else: + num_tokens += len(encoding.encode(value)) + + if key == "name": + num_tokens += tokens_per_name + # every reply is primed with assistant + num_tokens += 3 + + if kwargs.get('functions'): + for function in kwargs.get('functions'): + num_tokens += len(encoding.encode('name')) + num_tokens += len(encoding.encode(function.get("name"))) + num_tokens += len(encoding.encode('description')) + num_tokens += len(encoding.encode(function.get("description"))) + parameters = function.get("parameters") + num_tokens += len(encoding.encode('parameters')) + if 'title' in parameters: + num_tokens += len(encoding.encode('title')) + num_tokens += len(encoding.encode(parameters.get("title"))) + num_tokens += len(encoding.encode('type')) + num_tokens += len(encoding.encode(parameters.get("type"))) + if 'properties' in parameters: + num_tokens += len(encoding.encode('properties')) + for key, value in parameters.get('properties').items(): + num_tokens += len(encoding.encode(key)) + for field_key, field_value in value.items(): + num_tokens += len(encoding.encode(field_key)) + if field_key == 'enum': + for enum_field in field_value: + num_tokens += 3 + num_tokens += len(encoding.encode(enum_field)) + else: + num_tokens += len(encoding.encode(field_key)) + num_tokens += len(encoding.encode(str(field_value))) + if 'required' in parameters: + num_tokens += len(encoding.encode('required')) + for required_field in parameters['required']: + num_tokens += 3 + num_tokens += len(encoding.encode(required_field)) + + return num_tokens diff --git a/api/core/agent/agent/openai_function_call_summarize_mixin.py b/api/core/agent/agent/openai_function_call_summarize_mixin.py deleted file mode 100644 index f436346e24535b..00000000000000 --- a/api/core/agent/agent/openai_function_call_summarize_mixin.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import cast, List - -from langchain.chat_models import ChatOpenAI -from langchain.chat_models.openai import _convert_message_to_dict -from langchain.memory.summary import SummarizerMixin -from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage -from langchain.schema.language_model import BaseLanguageModel -from pydantic import BaseModel - -from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin -from core.model_providers.models.llm.base import BaseLLM - - -class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin): - moving_summary_buffer: str = "" - moving_summary_index: int = 0 - summary_llm: BaseLanguageModel = None - model_instance: BaseLLM - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: - # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 - rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) - rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens - if rest_tokens >= 0: - return messages - - system_message = None - human_message = None - should_summary_messages = [] - for message in messages: - if isinstance(message, SystemMessage): - system_message = 
message - elif isinstance(message, HumanMessage): - human_message = message - else: - should_summary_messages.append(message) - - if len(should_summary_messages) > 2: - ai_message = should_summary_messages[-2] - function_message = should_summary_messages[-1] - should_summary_messages = should_summary_messages[self.moving_summary_index:-2] - self.moving_summary_index = len(should_summary_messages) - else: - error_msg = "Exceeded LLM tokens limit, stopped." - raise ExceededLLMTokensLimitError(error_msg) - - new_messages = [system_message, human_message] - - if self.moving_summary_index == 0: - should_summary_messages.insert(0, human_message) - - summary_handler = SummarizerMixin(llm=self.summary_llm) - self.moving_summary_buffer = summary_handler.predict_new_summary( - messages=should_summary_messages, - existing_summary=self.moving_summary_buffer - ) - - new_messages.append(AIMessage(content=self.moving_summary_buffer)) - new_messages.append(ai_message) - new_messages.append(function_message) - - return new_messages - - def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - llm = cast(ChatOpenAI, model_instance.client) - model, encoding = llm._get_encoding_model() - if model.startswith("gpt-3.5-turbo"): - # every message follows {role/name}\n{content}\n - tokens_per_message = 4 - # if there's a name, the role is omitted - tokens_per_name = -1 - elif model.startswith("gpt-4"): - tokens_per_message = 3 - tokens_per_name = 1 - else: - raise NotImplementedError( - f"get_num_tokens_from_messages() is not presently implemented " - f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " - "information on how messages are converted to tokens." 
- ) - num_tokens = 0 - for m in messages: - message = _convert_message_to_dict(m) - num_tokens += tokens_per_message - for key, value in message.items(): - if key == "function_call": - for f_key, f_value in value.items(): - num_tokens += len(encoding.encode(f_key)) - num_tokens += len(encoding.encode(f_value)) - else: - num_tokens += len(encoding.encode(value)) - - if key == "name": - num_tokens += tokens_per_name - # every reply is primed with assistant - num_tokens += 3 - - if kwargs.get('functions'): - for function in kwargs.get('functions'): - num_tokens += len(encoding.encode('name')) - num_tokens += len(encoding.encode(function.get("name"))) - num_tokens += len(encoding.encode('description')) - num_tokens += len(encoding.encode(function.get("description"))) - parameters = function.get("parameters") - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): - num_tokens += len(encoding.encode(key)) - for field_key, field_value in value.items(): - num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': - for enum_field in field_value: - num_tokens += 3 - num_tokens += len(encoding.encode(enum_field)) - else: - num_tokens += len(encoding.encode(field_key)) - num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: - num_tokens += 3 - num_tokens += len(encoding.encode(required_field)) - - return num_tokens diff --git a/api/core/agent/agent/openai_multi_function_call.py b/api/core/agent/agent/openai_multi_function_call.py deleted file mode 100644 index cbbd24dbded5b2..00000000000000 --- a/api/core/agent/agent/openai_multi_function_call.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import List, Tuple, Any, Union, Sequence, Optional - -from langchain.agents import BaseMultiActionAgent -from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \ - _parse_ai_message -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, SystemMessage -from langchain.schema.language_model import BaseLanguageModel -from langchain.tools import BaseTool - -from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError -from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin - - -class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin): - - @classmethod - def from_llm_and_tools( - cls, - llm: BaseLanguageModel, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, - system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." 
- ), - **kwargs: Any, - ) -> BaseMultiActionAgent: - return super().from_llm_and_tools( - llm=llm, - tools=tools, - callback_manager=callback_manager, - extra_prompt_messages=extra_prompt_messages, - system_message=cls.get_system_message(), - **kwargs, - ) - - def should_use_agent(self, query: str): - """ - return should use agent - - :param query: - :return: - """ - original_max_tokens = self.llm.max_tokens - self.llm.max_tokens = 15 - - prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) - messages = prompt.to_messages() - - try: - predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=None - ) - except Exception as e: - new_exception = self.model_instance.handle_exceptions(e) - raise new_exception - - function_call = predicted_message.additional_kwargs.get("function_call", {}) - - self.llm.max_tokens = original_max_tokens - - return True if function_call else False - - def plan( - self, - intermediate_steps: List[Tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) - selected_inputs = { - k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" - } - full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) - prompt = self.prompt.format_prompt(**full_inputs) - messages = prompt.to_messages() - - # summarize messages if rest_tokens < 0 - try: - messages = self.summarize_messages_if_needed(messages, functions=self.functions) - except ExceededLLMTokensLimitError as e: - return AgentFinish(return_values={"output": str(e)}, log=str(e)) - - predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=callbacks - ) - agent_decision = _parse_ai_message(predicted_message) - return agent_decision - - @classmethod - def get_system_message(cls): - # get current time - return SystemMessage(content="You are a helpful AI assistant.\n" - "The current date or current time you know is wrong.\n" - "Respond directly if appropriate.") diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index 8d682b59d5bc08..6b2889f45a8ed8 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -4,7 +4,6 @@ from langchain import BasePromptTemplate from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate @@ -12,6 +11,7 @@ from langchain.tools import BaseTool from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX +from core.chain.llm_chain import LLMChain from core.model_providers.models.llm.base import BaseLLM from core.tool.dataset_retriever_tool import DatasetRetrieverTool @@ -49,7 +49,6 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): - model_instance: BaseLLM dataset_tools: Sequence[BaseTool] class Config: @@ -98,7 +97,7 
@@ def plan( try: full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) except Exception as e: - new_exception = self.model_instance.handle_exceptions(e) + new_exception = self.llm_chain.model_instance.handle_exceptions(e) raise new_exception try: @@ -145,7 +144,7 @@ def create_prompt( @classmethod def from_llm_and_tools( cls, - llm: BaseLanguageModel, + model_instance: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, output_parser: Optional[AgentOutputParser] = None, @@ -157,17 +156,28 @@ def from_llm_and_tools( memory_prompts: Optional[List[BasePromptTemplate]] = None, **kwargs: Any, ) -> Agent: - return super().from_llm_and_tools( - llm=llm, - tools=tools, - callback_manager=callback_manager, - output_parser=output_parser, + """Construct an agent from an LLM and tools.""" + cls._validate_tools(tools) + prompt = cls.create_prompt( + tools, prefix=prefix, suffix=suffix, human_message_template=human_message_template, format_instructions=format_instructions, input_variables=input_variables, memory_prompts=memory_prompts, + ) + llm_chain = LLMChain( + model_instance=model_instance, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + _output_parser = output_parser + return cls( + llm_chain=llm_chain, + allowed_tools=tool_names, + output_parser=_output_parser, dataset_tools=tools, **kwargs, ) diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index 77635273ea2573..2c36bc38d37300 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -4,16 +4,17 @@ from langchain import BasePromptTemplate from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks -from langchain.memory.summary import SummarizerMixin +from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate -from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException +from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \ + get_buffer_string from langchain.tools import BaseTool from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError +from core.chain.llm_chain import LLMChain from core.model_providers.models.llm.base import BaseLLM FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). 
@@ -52,8 +53,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): moving_summary_buffer: str = "" moving_summary_index: int = 0 - summary_llm: BaseLanguageModel = None - model_instance: BaseLLM + summary_model_instance: BaseLLM = None class Config: """Configuration for this pydantic object.""" @@ -95,14 +95,14 @@ def plan( if prompts: messages = prompts[0].to_messages() - rest_tokens = self.get_message_rest_tokens(self.model_instance, messages) + rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages) if rest_tokens < 0: full_inputs = self.summarize_messages(intermediate_steps, **kwargs) try: full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) except Exception as e: - new_exception = self.model_instance.handle_exceptions(e) + new_exception = self.llm_chain.model_instance.handle_exceptions(e) raise new_exception try: @@ -118,7 +118,7 @@ def plan( "I don't know how to respond to that."}, "") def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): - if len(intermediate_steps) >= 2 and self.summary_llm: + if len(intermediate_steps) >= 2 and self.summary_model_instance: should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] should_summary_messages = [AIMessage(content=observation) for _, observation in should_summary_intermediate_steps] @@ -130,11 +130,10 @@ def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], error_msg = "Exceeded LLM tokens limit, stopped." raise ExceededLLMTokensLimitError(error_msg) - summary_handler = SummarizerMixin(llm=self.summary_llm) if self.moving_summary_buffer and 'chat_history' in kwargs: kwargs["chat_history"].pop() - self.moving_summary_buffer = summary_handler.predict_new_summary( + self.moving_summary_buffer = self.predict_new_summary( messages=should_summary_messages, existing_summary=self.moving_summary_buffer ) @@ -144,6 +143,18 @@ def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], return self.get_full_inputs([intermediate_steps[-1]], **kwargs) + def predict_new_summary( + self, messages: List[BaseMessage], existing_summary: str + ) -> str: + new_lines = get_buffer_string( + messages, + human_prefix="Human", + ai_prefix="AI", + ) + + chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT) + return chain.predict(summary=existing_summary, new_lines=new_lines) + @classmethod def create_prompt( cls, @@ -176,7 +187,7 @@ def create_prompt( @classmethod def from_llm_and_tools( cls, - llm: BaseLanguageModel, + model_instance: BaseLLM, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, output_parser: Optional[AgentOutputParser] = None, @@ -188,16 +199,27 @@ def from_llm_and_tools( memory_prompts: Optional[List[BasePromptTemplate]] = None, **kwargs: Any, ) -> Agent: - return super().from_llm_and_tools( - llm=llm, - tools=tools, - callback_manager=callback_manager, - output_parser=output_parser, + """Construct an agent from an LLM and tools.""" + cls._validate_tools(tools) + prompt = cls.create_prompt( + tools, prefix=prefix, suffix=suffix, human_message_template=human_message_template, format_instructions=format_instructions, input_variables=input_variables, memory_prompts=memory_prompts, + ) + llm_chain = LLMChain( + model_instance=model_instance, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + _output_parser = output_parser + return cls( + 
llm_chain=llm_chain, + allowed_tools=tool_names, + output_parser=_output_parser, **kwargs, ) diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index 903203d87b63a6..05c4b632ffb2ed 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -10,7 +10,6 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent -from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent @@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum): REACT_ROUTER = 'react_router' REACT = 'react' FUNCTION_CALL = 'function_call' - MULTI_FUNCTION_CALL = 'multi_function_call' class AgentConfiguration(BaseModel): @@ -64,30 +62,18 @@ def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]: if self.configuration.strategy == PlanningStrategy.REACT: agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( model_instance=self.configuration.model_instance, - llm=self.configuration.model_instance.client, tools=self.configuration.tools, output_parser=StructuredChatOutputParser(), - summary_llm=self.configuration.summary_model_instance.client + summary_model_instance=self.configuration.summary_model_instance if self.configuration.summary_model_instance else None, verbose=True ) elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( model_instance=self.configuration.model_instance, - llm=self.configuration.model_instance.client, tools=self.configuration.tools, extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory - summary_llm=self.configuration.summary_model_instance.client - if self.configuration.summary_model_instance else None, - verbose=True - ) - elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL: - agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools( - model_instance=self.configuration.model_instance, - llm=self.configuration.model_instance.client, - tools=self.configuration.tools, - extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory - summary_llm=self.configuration.summary_model_instance.client + summary_model_instance=self.configuration.summary_model_instance if self.configuration.summary_model_instance else None, verbose=True ) @@ -95,7 +81,6 @@ def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]: self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] agent = MultiDatasetRouterAgent.from_llm_and_tools( model_instance=self.configuration.model_instance, - llm=self.configuration.model_instance.client, tools=self.configuration.tools, extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, verbose=True @@ -104,7 +89,6 @@ def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]: self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)] agent = 
StructuredMultiDatasetRouterAgent.from_llm_and_tools( model_instance=self.configuration.model_instance, - llm=self.configuration.model_instance.client, tools=self.configuration.tools, output_parser=StructuredChatOutputParser(), verbose=True diff --git a/api/core/chain/llm_chain.py b/api/core/chain/llm_chain.py new file mode 100644 index 00000000000000..2a5b4b61a948e2 --- /dev/null +++ b/api/core/chain/llm_chain.py @@ -0,0 +1,36 @@ +from typing import List, Dict, Any, Optional + +from langchain import LLMChain as LCLLMChain +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.schema import LLMResult, Generation +from langchain.schema.language_model import BaseLanguageModel + +from core.model_providers.models.entity.message import to_prompt_messages +from core.model_providers.models.llm.base import BaseLLM +from core.third_party.langchain.llms.fake import FakeLLM + + +class LLMChain(LCLLMChain): + model_instance: BaseLLM + """The language model instance to use.""" + llm: BaseLanguageModel = FakeLLM(response="") + + def generate( + self, + input_list: List[Dict[str, Any]], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> LLMResult: + """Generate LLM result from inputs.""" + prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) + messages = prompts[0].to_messages() + prompt_messages = to_prompt_messages(messages) + result = self.model_instance.run( + messages=prompt_messages, + stop=stop + ) + + generations = [ + [Generation(text=result.content)] + ] + + return LLMResult(generations=generations) diff --git a/api/core/model_providers/models/entity/message.py b/api/core/model_providers/models/entity/message.py index 921bdcf1933aac..c37e88fac9ee4e 100644 --- a/api/core/model_providers/models/entity/message.py +++ b/api/core/model_providers/models/entity/message.py @@ -1,6 +1,6 @@ import enum -from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage +from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage from pydantic import BaseModel @@ -9,6 +9,7 @@ class LLMRunResult(BaseModel): prompt_tokens: int completion_tokens: int source: list = None + function_call: dict = None class MessageType(enum.Enum): @@ -20,6 +21,7 @@ class MessageType(enum.Enum): class PromptMessage(BaseModel): type: MessageType = MessageType.HUMAN content: str = '' + function_call: dict = None def to_lc_messages(messages: list[PromptMessage]): @@ -28,7 +30,10 @@ def to_lc_messages(messages: list[PromptMessage]): if message.type == MessageType.HUMAN: lc_messages.append(HumanMessage(content=message.content)) elif message.type == MessageType.ASSISTANT: - lc_messages.append(AIMessage(content=message.content)) + additional_kwargs = {} + if message.function_call: + additional_kwargs['function_call'] = message.function_call + lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs)) elif message.type == MessageType.SYSTEM: lc_messages.append(SystemMessage(content=message.content)) @@ -41,9 +46,19 @@ def to_prompt_messages(messages: list[BaseMessage]): if isinstance(message, HumanMessage): prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) elif isinstance(message, AIMessage): - prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT)) + message_kwargs = { + 'content': message.content, + 'type': MessageType.ASSISTANT + } + + if 'function_call' in message.additional_kwargs: + message_kwargs['function_call'] 
= message.additional_kwargs['function_call'] + + prompt_messages.append(PromptMessage(**message_kwargs)) elif isinstance(message, SystemMessage): prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) + elif isinstance(message, FunctionMessage): + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) return prompt_messages diff --git a/api/core/model_providers/models/llm/azure_openai_model.py b/api/core/model_providers/models/llm/azure_openai_model.py index d97330ae3b27cc..1ef5ac31d471f6 100644 --- a/api/core/model_providers/models/llm/azure_openai_model.py +++ b/api/core/model_providers/models/llm/azure_openai_model.py @@ -81,7 +81,20 @@ def _run(self, messages: List[PromptMessage], :return: """ prompts = self._get_prompt_from_messages(messages) - return self._client.generate([prompts], stop, callbacks) + generate_kwargs = { + 'stop': stop, + 'callbacks': callbacks + } + + if isinstance(prompts, str): + generate_kwargs['prompts'] = [prompts] + else: + generate_kwargs['messages'] = [prompts] + + if 'functions' in kwargs: + generate_kwargs['functions'] = kwargs['functions'] + + return self._client.generate(**generate_kwargs) @property def base_model_name(self) -> str: diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index c7cec88ff02c59..7224bf714191db 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -13,7 +13,8 @@ from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler from core.helper import moderation from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages +from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \ + to_lc_messages from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.providers.base import BaseModelProvider from core.prompt.prompt_builder import PromptBuilder @@ -157,8 +158,11 @@ def run(self, messages: List[PromptMessage], except Exception as ex: raise self.handle_exceptions(ex) + function_call = None if isinstance(result.generations[0][0], ChatGeneration): completion_content = result.generations[0][0].message.content + if 'function_call' in result.generations[0][0].message.additional_kwargs: + function_call = result.generations[0][0].message.additional_kwargs.get('function_call') else: completion_content = result.generations[0][0].text @@ -191,7 +195,8 @@ def run(self, messages: List[PromptMessage], return LLMRunResult( content=completion_content, prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens + completion_tokens=completion_tokens, + function_call=function_call ) @abstractmethod @@ -442,16 +447,7 @@ def _get_prompt_from_messages(self, messages: List[PromptMessage], if len(messages) == 0: return [] - chat_messages = [] - for message in messages: - if message.type == MessageType.HUMAN: - chat_messages.append(HumanMessage(content=message.content)) - elif message.type == MessageType.ASSISTANT: - chat_messages.append(AIMessage(content=message.content)) - elif message.type == MessageType.SYSTEM: - chat_messages.append(SystemMessage(content=message.content)) - - return chat_messages + return to_lc_messages(messages) def _to_model_kwargs_input(self, 
model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict: """ diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py index d65efc96418dcc..9c2cd6428e943f 100644 --- a/api/core/model_providers/models/llm/openai_model.py +++ b/api/core/model_providers/models/llm/openai_model.py @@ -106,7 +106,21 @@ def _run(self, messages: List[PromptMessage], raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.") prompts = self._get_prompt_from_messages(messages) - return self._client.generate([prompts], stop, callbacks) + + generate_kwargs = { + 'stop': stop, + 'callbacks': callbacks + } + + if isinstance(prompts, str): + generate_kwargs['prompts'] = [prompts] + else: + generate_kwargs['messages'] = [prompts] + + if 'functions' in kwargs: + generate_kwargs['functions'] = kwargs['functions'] + + return self._client.generate(**generate_kwargs) def get_num_tokens(self, messages: List[PromptMessage]) -> int: """ diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 5110650ec32a17..36d04ada39a4fe 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -1,7 +1,6 @@ import math from typing import Optional -from flask import current_app from langchain import WikipediaAPIWrapper from langchain.callbacks.manager import Callbacks from langchain.memory.chat_memory import BaseChatMemory @@ -27,7 +26,6 @@ from extensions.ext_database import db from models.dataset import Dataset, DatasetProcessRule from models.model import AppModelConfig -from models.provider import ProviderType class OrchestratorRuleParser: @@ -77,7 +75,7 @@ def to_agent_executor(self, conversation_message_task: ConversationMessageTask, # only OpenAI chat model (include Azure) support function call, use ReACT instead if agent_model_instance.model_mode != ModelMode.CHAT \ or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']: - if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: + if planning_strategy == PlanningStrategy.FUNCTION_CALL: planning_strategy = PlanningStrategy.REACT elif planning_strategy == PlanningStrategy.ROUTER: planning_strategy = PlanningStrategy.REACT_ROUTER @@ -207,7 +205,10 @@ def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, tool = self.to_current_datetime_tool() if tool: - tool.callbacks.extend(callbacks) + if tool.callbacks is not None: + tool.callbacks.extend(callbacks) + else: + tool.callbacks = callbacks tools.append(tool) return tools @@ -269,10 +270,9 @@ def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool summary_model_instance = None tool = WebReaderTool( - llm=summary_model_instance.client if summary_model_instance else None, + model_instance=summary_model_instance if summary_model_instance else None, max_chunk_length=4000, - continue_reading=True, - callbacks=[DifyStdOutCallbackHandler()] + continue_reading=True ) return tool @@ -290,16 +290,13 @@ def to_google_search_tool(self) -> Optional[BaseTool]: "is not up to date. 
" "Input should be a search query.", func=OptimizedSerpAPIWrapper(**func_kwargs).run, - args_schema=OptimizedSerpAPIInput, - callbacks=[DifyStdOutCallbackHandler()] + args_schema=OptimizedSerpAPIInput ) return tool def to_current_datetime_tool(self) -> Optional[BaseTool]: - tool = DatetimeTool( - callbacks=[DifyStdOutCallbackHandler()] - ) + tool = DatetimeTool() return tool @@ -310,8 +307,7 @@ class WikipediaInput(BaseModel): return WikipediaQueryRun( name="wikipedia", api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000), - args_schema=WikipediaInput, - callbacks=[DifyStdOutCallbackHandler()] + args_schema=WikipediaInput ) @classmethod diff --git a/api/core/tool/web_reader_tool.py b/api/core/tool/web_reader_tool.py index c2e68bc7a8331b..64f3ed85155870 100644 --- a/api/core/tool/web_reader_tool.py +++ b/api/core/tool/web_reader_tool.py @@ -11,8 +11,8 @@ import requests from bs4 import BeautifulSoup, NavigableString, Comment, CData -from langchain.base_language import BaseLanguageModel -from langchain.chains.summarize import load_summarize_chain +from langchain.chains import RefineDocumentsChain +from langchain.chains.summarize import refine_prompts from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.tools.base import BaseTool @@ -20,8 +20,10 @@ from pydantic import BaseModel, Field from regex import regex +from core.chain.llm_chain import LLMChain from core.data_loader import file_extractor from core.data_loader.file_extractor import FileExtractor +from core.model_providers.models.llm.base import BaseLLM FULL_TEMPLATE = """ TITLE: {title} @@ -65,7 +67,7 @@ class WebReaderTool(BaseTool): summary_chunk_overlap: int = 0 summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] continue_reading: bool = True - llm: BaseLanguageModel = None + model_instance: BaseLLM = None def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: try: @@ -78,7 +80,7 @@ def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: except Exception as e: return f'Read this website failed, caused by: {str(e)}.' - if summary and self.llm: + if summary and self.model_instance: character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( chunk_size=self.summary_chunk_tokens, chunk_overlap=self.summary_chunk_overlap, @@ -95,10 +97,9 @@ def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: if len(docs) > 5: docs = docs[:5] - chain = load_summarize_chain(self.llm, chain_type="refine", callbacks=self.callbacks) + chain = self.get_summary_chain() try: page_contents = chain.run(docs) - # todo use cache except Exception as e: return f'Read this website failed, caused by: {str(e)}.' else: @@ -114,6 +115,23 @@ def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: async def _arun(self, url: str) -> str: raise NotImplementedError + def get_summary_chain(self) -> RefineDocumentsChain: + initial_chain = LLMChain( + model_instance=self.model_instance, + prompt=refine_prompts.PROMPT + ) + refine_chain = LLMChain( + model_instance=self.model_instance, + prompt=refine_prompts.REFINE_PROMPT + ) + return RefineDocumentsChain( + initial_llm_chain=initial_chain, + refine_llm_chain=refine_chain, + document_variable_name="text", + initial_response_name="existing_answer", + callbacks=self.callbacks + ) + def page_result(text: str, cursor: int, max_length: int) -> str: """Page through `text` and return a substring of `max_length` characters starting from `cursor`."""