fix: error in agent mode while using models not supporting tools
sinomoe committed Jun 11, 2024
1 parent 5986841 commit 07a414f
Showing 1 changed file with 18 additions and 2 deletions.
api/core/prompt/agent_history_prompt_transform.py (18 additions, 2 deletions)
@@ -7,8 +7,10 @@
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
     SystemPromptMessage,
+    ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.model_entities import ModelFeature
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_transform import PromptTransform

@@ -17,6 +19,7 @@ class AgentHistoryPromptTransform(PromptTransform):
     """
     History Prompt Transform for Agent App
     """
+
     def __init__(self,
                  model_config: ModelConfigWithCredentialsEntity,
                  prompt_messages: list[PromptMessage],
@@ -39,15 +42,28 @@ def get_prompt(self) -> list[PromptMessage]:
         if not self.memory:
             return prompt_messages

-        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
+        max_token_limit = self._calculate_rest_token(
+            self.prompt_messages, self.model_config)

         model_type_instance = self.model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)

+        # if the model does not support tool calls, filter tool prompts out when counting tokens
+        tool_feature_set = {
+            ModelFeature.TOOL_CALL,
+            ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
+        support_tool = tool_feature_set & set(
+            self.model_config.model_schema.features or [])
+        # history messages used for counting tokens
+        histories = self.history_messages
+        if not support_tool:
+            histories = [
+                msg for msg in histories if not isinstance(msg, ToolPromptMessage)]
+
         curr_message_tokens = model_type_instance.get_num_tokens(
             self.memory.model_instance.model,
             self.memory.model_instance.credentials,
-            self.history_messages
+            histories
         )
         if curr_message_tokens <= max_token_limit:
             return self.history_messages

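For readers outside the dify codebase, the following is a minimal standalone sketch of the filtering idea this commit introduces: when a model's feature list lacks tool-call support, tool messages are dropped from the history before tokens are counted. The ModelFeature values, the PromptMessage/ToolPromptMessage classes, and the filter_history_for_token_count helper below are simplified stand-ins for illustration, not the actual dify entities or API.

# Simplified stand-ins; the real classes live under core.model_runtime.entities.
from dataclasses import dataclass
from enum import Enum


class ModelFeature(Enum):
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"
    STREAM_TOOL_CALL = "stream-tool-call"


@dataclass
class PromptMessage:
    content: str


@dataclass
class ToolPromptMessage(PromptMessage):
    tool_call_id: str = ""


def filter_history_for_token_count(
    histories: list[PromptMessage],
    model_features: list[ModelFeature] | None,
) -> list[PromptMessage]:
    # mirror of the commit's check: keep tool messages only if the model
    # advertises at least one tool-call feature
    tool_feature_set = {
        ModelFeature.TOOL_CALL,
        ModelFeature.MULTI_TOOL_CALL,
        ModelFeature.STREAM_TOOL_CALL,
    }
    if tool_feature_set & set(model_features or []):
        return histories
    return [msg for msg in histories if not isinstance(msg, ToolPromptMessage)]


# Example: a model with no tool-call features keeps only the plain message.
history = [PromptMessage("hi"), ToolPromptMessage("tool result", "call-1")]
print(filter_history_for_token_count(history, []))
# -> [PromptMessage(content='hi')]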