diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index 60ed972a614fd5..4540413897b2eb 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -108,6 +108,13 @@ def _generate(self, model: str, credentials_kwargs: dict,
             api_key=credentials_kwargs['api_key']
         )
 
+        if len(prompt_messages) == 0:
+            raise ValueError('At least one message is required')
+
+        if prompt_messages[0].role.value == 'system':
+            if not prompt_messages[0].content:
+                prompt_messages = prompt_messages[1:]
+
         params = {
             'model': model,
             'prompt': [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
@@ -116,12 +123,14 @@ def _generate(self, model: str, credentials_kwargs: dict,
 
         if stream:
             response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, response, prompt_messages)
+            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
 
         response = client.invoke(**params)
-        return self._handle_generate_response(model, response, prompt_messages)
+        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, response: Dict[str, Any],
+    def _handle_generate_response(self, model: str,
+                                  credentials: dict,
+                                  response: Dict[str, Any],
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm response
@@ -144,18 +153,21 @@ def _handle_generate_response(self, model: str, response: Dict[str, Any],
             token_usage['completion_tokens'] = token_usage['total_tokens']
 
         # transform usage
-        usage = self._calc_response_usage(model, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
 
         # transform response
         result = LLMResult(
             model=model,
+            prompt_messages=prompt_messages,
             message=AssistantPromptMessage(content=text),
             usage=usage,
         )
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, responses: list[Generator],
+    def _handle_generate_stream_response(self, model: str,
+                                         credentials: dict,
+                                         responses: list[Generator],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -168,6 +180,7 @@ def _handle_generate_stream_response(self, model: str, responses: list[Generator
         for index, event in enumerate(responses):
            if event.event == "add":
                 yield LLMResultChunk(
+                    prompt_messages=prompt_messages,
                     model=model,
                     delta=LLMResultChunkDelta(
                         index=index,
@@ -187,10 +200,11 @@ def _handle_generate_stream_response(self, model: str, responses: list[Generator
                 if 'completion_tokens' not in token_usage:
                     token_usage['completion_tokens'] = token_usage['total_tokens']
 
-                usage = self._calc_response_usage(model, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+                usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
 
                 yield LLMResultChunk(
                     model=model,
+                    prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=event.data),
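For review, here is a minimal standalone sketch of the two behavioural changes in this diff: the new guard in `_generate` that rejects an empty prompt list and strips a leading system message with empty content, and the threading of the provider credentials through to the usage calculation. `Message`, `normalize_prompt`, and `calc_usage` are stand-ins invented for this sketch only, not the project's real classes; the actual types and signatures (`PromptMessage`, `LLMResult`, `_calc_response_usage`, etc.) are exactly as shown in the diff above. The remaining hunks simply attach `prompt_messages` to the returned `LLMResult` / `LLMResultChunk` objects and are not repeated here.

```python
from dataclasses import dataclass


# Hypothetical stand-in for PromptMessage, used only in this sketch.
@dataclass
class Message:
    role: str
    content: str


def normalize_prompt(prompt_messages: list[Message]) -> list[Message]:
    # Mirrors the new guard in _generate: require at least one message and
    # drop a leading system message whose content is empty.
    if len(prompt_messages) == 0:
        raise ValueError('At least one message is required')
    if prompt_messages[0].role == 'system' and not prompt_messages[0].content:
        prompt_messages = prompt_messages[1:]
    return prompt_messages


def calc_usage(model: str, credentials: dict, prompt_tokens: int, completion_tokens: int) -> dict:
    # Placeholder for _calc_response_usage: the relevant change is that the
    # credentials dict now travels alongside the token counts.
    return {
        'model': model,
        'prompt_tokens': prompt_tokens,
        'completion_tokens': completion_tokens,
        'total_tokens': prompt_tokens + completion_tokens,
    }


if __name__ == '__main__':
    messages = [Message(role='system', content=''), Message(role='user', content='Hello')]
    print(normalize_prompt(messages))  # leading empty system message is dropped
    print(calc_usage('example-model', {'api_key': 'example-key'}, prompt_tokens=12, completion_tokens=34))
```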