fix: error in agent mode while using models not supporting tools #5057

Closed · 2 commits
api/core/prompt/agent_history_prompt_transform.py — 18 additions, 2 deletions

```diff
@@ -7,8 +7,10 @@
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_transform import PromptTransform
 
@@ -17,6 +19,7 @@ class AgentHistoryPromptTransform(PromptTransform):
     """
     History Prompt Transform for Agent App
     """
+
     def __init__(self,
                  model_config: ModelConfigWithCredentialsEntity,
                  prompt_messages: list[PromptMessage],
@@ -39,15 +42,28 @@ def get_prompt(self) -> list[PromptMessage]:
         if not self.memory:
             return prompt_messages
 
-        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+        max_token_limit = self._calculate_rest_token(
+            self.prompt_messages, self.model_config)
 
         model_type_instance = self.model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
 
+        # if the model does not support tool calls, filter tool prompts out while calculating the token count
+        tool_feature_set = {
+            ModelFeature.TOOL_CALL,
+            ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
+        support_tool = tool_feature_set & set(
+            self.model_config.model_schema.features or [])
+        # history messages for calculating the token count
+        histories = self.history_messages
+        if not support_tool:
+            histories = [
+                msg for msg in histories if not isinstance(msg, ToolPromptMessage)]
+
         curr_message_tokens = model_type_instance.get_num_tokens(
             self.memory.model_instance.model,
             self.memory.model_instance.credentials,
-            self.history_messages
+            histories
         )
         if curr_message_tokens <= max_token_limit:
             return self.history_messages
```
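In short: the change intersects the model's advertised features with the three tool-call features and, when the intersection is empty, strips `ToolPromptMessage` entries from the history before counting tokens, so `get_num_tokens` never sees tool messages the model cannot encode. A minimal sketch of that logic, using stand-in dataclasses and enum values rather than the real `core.model_runtime` entities:

```python
from dataclasses import dataclass
from enum import Enum


class ModelFeature(Enum):
    # Stand-ins for the real core.model_runtime ModelFeature members
    TOOL_CALL = 'tool-call'
    MULTI_TOOL_CALL = 'multi-tool-call'
    STREAM_TOOL_CALL = 'stream-tool-call'


@dataclass
class PromptMessage:
    content: str


@dataclass
class ToolPromptMessage(PromptMessage):
    tool_call_id: str = ''


TOOL_FEATURES = {ModelFeature.TOOL_CALL,
                 ModelFeature.MULTI_TOOL_CALL,
                 ModelFeature.STREAM_TOOL_CALL}


def filter_histories(histories: list[PromptMessage],
                     features: list[ModelFeature] | None) -> list[PromptMessage]:
    # Same set-intersection check as the PR: any overlap means tool support,
    # so the history is counted as-is.
    if TOOL_FEATURES & set(features or []):
        return histories
    return [m for m in histories if not isinstance(m, ToolPromptMessage)]


messages = [PromptMessage('hi'), ToolPromptMessage('result', tool_call_id='1')]
assert len(filter_histories(messages, [])) == 1                        # tool message dropped
assert len(filter_histories(messages, [ModelFeature.TOOL_CALL])) == 2  # kept
```

Because a set intersection is used, advertising any one of the three tool-call features is enough to keep tool messages in the count.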
The accompanying unit test for `AgentHistoryPromptTransform.get_prompt` is updated to supply a `model_schema` with tool-call support:

```diff
@@ -11,6 +11,7 @@
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
 from models.model import Conversation
@@ -25,13 +26,17 @@ def test_get_prompt():
         SystemPromptMessage(content='System Prompt 1'),
         UserPromptMessage(content='User Prompt 1'),
         AssistantPromptMessage(content='Assistant Thought 1'),
-        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
-        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
+        ToolPromptMessage(content='Tool 1-1',
+                          name='Tool 1-1', tool_call_id='1'),
+        ToolPromptMessage(content='Tool 1-2',
+                          name='Tool 1-2', tool_call_id='2'),
         SystemPromptMessage(content='System Prompt 2'),
         UserPromptMessage(content='User Prompt 2'),
         AssistantPromptMessage(content='Assistant Thought 2'),
-        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
-        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
+        ToolPromptMessage(content='Tool 2-1',
+                          name='Tool 2-1', tool_call_id='3'),
+        ToolPromptMessage(content='Tool 2-2',
+                          name='Tool 2-2', tool_call_id='4'),
         UserPromptMessage(content='User Prompt 3'),
         AssistantPromptMessage(content='Assistant Thought 3'),
     ]
@@ -40,15 +45,20 @@ def test_get_prompt():
     def side_effect_get_num_tokens(*args):
         return len(args[2])
     large_language_model_mock = MagicMock(spec=LargeLanguageModel)
-    large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
+    large_language_model_mock.get_num_tokens = MagicMock(
+        side_effect=side_effect_get_num_tokens)
 
     provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle)
     provider_model_bundle_mock.model_type_instance = large_language_model_mock
 
+    model_entity_mock = MagicMock(spec=AIModelEntity)
+    model_entity_mock.features = [ModelFeature.TOOL_CALL]
+
     model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
     model_config_mock.model = 'openai'
     model_config_mock.credentials = {}
     model_config_mock.provider_model_bundle = provider_model_bundle_mock
+    model_config_mock.model_schema = model_entity_mock
 
     memory = TokenBufferMemory(
         conversation=Conversation(),
```
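The `side_effect` above makes the mocked `get_num_tokens` return the number of messages it receives, so the token-budget arithmetic becomes a simple, deterministic message count. A standalone illustration of that MagicMock pattern (the model name and message strings are placeholders, not from the PR):

```python
from unittest.mock import MagicMock


def fake_num_tokens(*args):
    # args are (model, credentials, messages); pretend each message costs one token
    return len(args[2])


llm = MagicMock()
llm.get_num_tokens = MagicMock(side_effect=fake_num_tokens)

# Three messages yield a "token count" of 3, deterministic for assertions.
assert llm.get_num_tokens('some-model', {}, ['m1', 'm2', 'm3']) == 3
llm.get_num_tokens.assert_called_once()
```

With `model_entity_mock.features = [ModelFeature.TOOL_CALL]`, the updated test exercises the tool-capable path; a model whose `features` list is empty would exercise the new filtering branch.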