diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index af0075ea9154fc..45b72910b090cc 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -7,8 +7,10 @@ from core.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, + ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_transform import PromptTransform @@ -17,6 +19,7 @@ class AgentHistoryPromptTransform(PromptTransform): """ History Prompt Transform for Agent App """ + def __init__(self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage], @@ -39,15 +42,28 @@ def get_prompt(self) -> list[PromptMessage]: if not self.memory: return prompt_messages - max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config) + max_token_limit = self._calculate_rest_token( + self.prompt_messages, self.model_config) model_type_instance = self.model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) + # if model does not suppot tool call, filter tool prompt out while calculating tokens count + tool_feature_set = { + ModelFeature.TOOL_CALL, + ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL} + support_tool = tool_feature_set & set( + self.model_config.model_schema.features or []) + # history messages for calculating tokens count + histories = self.history_messages + if not support_tool: + histories = [ + msg for msg in histories if not isinstance(msg, ToolPromptMessage)] + curr_message_tokens = model_type_instance.get_num_tokens( self.memory.model_instance.model, self.memory.model_instance.credentials, - self.history_messages + histories ) if curr_message_tokens <= max_token_limit: return self.history_messages