From 0c5892bcb6a028c22a10a4db1678d450583f9e12 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 4 Jan 2024 10:39:21 +0800 Subject: [PATCH] fix: zhipuai chatglm turbo prompts must user, assistant in sequence (#1899) --- .../model_providers/zhipuai/llm/llm.py | 34 ++++++++++++++++--- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 4a6cd7d101d0ac..6624a41e83f6bf 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -8,8 +8,9 @@ Union ) -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, AssistantPromptMessage, \ - SystemPromptMessage +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, UserPromptMessage, \ + AssistantPromptMessage, \ + SystemPromptMessage, PromptMessageRole from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \ LLMResultChunkDelta from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -111,16 +112,39 @@ def _generate(self, model: str, credentials_kwargs: dict, if len(prompt_messages) == 0: raise ValueError('At least one message is required') - if prompt_messages[0].role.value == 'system': + if prompt_messages[0].role == PromptMessageRole.SYSTEM: if not prompt_messages[0].content: prompt_messages = prompt_messages[1:] + # resolve zhipuai model not support system message and user message, assistant message must be in sequence + new_prompt_messages = [] + for prompt_message in prompt_messages: + copy_prompt_message = prompt_message.copy() + if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if not isinstance(copy_prompt_message.content, str): + # not support image message + continue + + if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER: + new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content + else: + if copy_prompt_message.role == PromptMessageRole.USER: + new_prompt_messages.append(copy_prompt_message) + else: + new_prompt_message = UserPromptMessage(content=copy_prompt_message.content) + new_prompt_messages.append(new_prompt_message) + else: + if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.ASSISTANT: + new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content + else: + new_prompt_messages.append(copy_prompt_message) + params = { 'model': model, 'prompt': [{ - 'role': prompt_message.role.value if prompt_message.role.value != 'system' else 'user', + 'role': prompt_message.role.value, 'content': prompt_message.content - } for prompt_message in prompt_messages], + } for prompt_message in new_prompt_messages], **model_parameters }