From 9af8a59da0d5147ed024e4dd3f94c4eb6d7c9304 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Thu, 28 Dec 2023 19:20:38 +0800
Subject: [PATCH 1/2] fix: missing system prompt, missing arguments

---
 .../model_providers/zhipuai/llm/llm.py | 26 ++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index 60ed972a614fd5..4540413897b2eb 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -108,6 +108,13 @@ def _generate(self, model: str, credentials_kwargs: dict,
             api_key=credentials_kwargs['api_key']
         )
 
+        if len(prompt_messages) == 0:
+            raise ValueError('At least one message is required')
+
+        if prompt_messages[0].role.value == 'system':
+            if not prompt_messages[0].content:
+                prompt_messages = prompt_messages[1:]
+
         params = {
             'model': model,
             'prompt': [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
@@ -116,12 +123,14 @@ def _generate(self, model: str, credentials_kwargs: dict,
 
         if stream:
             response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, response, prompt_messages)
+            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
 
         response = client.invoke(**params)
-        return self._handle_generate_response(model, response, prompt_messages)
+        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, response: Dict[str, Any],
+    def _handle_generate_response(self, model: str,
+                                  credentials: dict,
+                                  response: Dict[str, Any],
                                 prompt_messages: list[PromptMessage]) -> LLMResult:
         """
             Handle llm response
@@ -144,18 +153,21 @@ def _handle_generate_response(self, model: str, response: Dict[str, Any],
             token_usage['completion_tokens'] = token_usage['total_tokens']
 
         # transform usage
-        usage = self._calc_response_usage(model, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
 
         # transform response
         result = LLMResult(
             model=model,
+            prompt_messages=prompt_messages,
             message=AssistantPromptMessage(content=text),
             usage=usage,
         )
 
         return result
 
-    def _handle_generate_stream_response(self, model: str, responses: list[Generator],
+    def _handle_generate_stream_response(self, model: str,
+                                         credentials: dict,
+                                         responses: list[Generator],
                                 prompt_messages: list[PromptMessage]) -> Generator:
         """
             Handle llm stream response
@@ -168,6 +180,7 @@ def _handle_generate_stream_response(self, model: str, responses: list[Generator
         for index, event in enumerate(responses):
             if event.event == "add":
                 yield LLMResultChunk(
+                    prompt_messages=prompt_messages,
                     model=model,
                     delta=LLMResultChunkDelta(
                         index=index,
@@ -187,10 +200,11 @@ def _handle_generate_stream_response(self, model: str, responses: list[Generator
                 if 'completion_tokens' not in token_usage:
                     token_usage['completion_tokens'] = token_usage['total_tokens']
 
-                usage = self._calc_response_usage(model, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+                usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
 
                 yield LLMResultChunk(
                     model=model,
+                    prompt_messages=prompt_messages,
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=event.data),

From f76a00f80cbf1b76a17c028b2ea7b14e92ae8ef3 Mon Sep 17 00:00:00 2001
From: Yeuoly
Date: Thu, 28 Dec 2023 19:33:53 +0800
Subject: [PATCH 2/2] fix: baichuan max usage

---
 .../model_providers/baichuan/llm/baichuan2-53b.yaml   | 2 +-
 .../model_providers/baichuan/llm/baichuan2-turbo.yaml | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml
index f26b1937460381..57a433a058851a 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-53b.yaml
@@ -24,7 +24,7 @@ parameter_rules:
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 4000
+    default: 1000
     min: 1
     max: 4000
   - name: presence_penalty
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml
index b26458a7f568d9..004776df67314d 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan2-turbo.yaml
@@ -24,9 +24,9 @@ parameter_rules:
   - name: max_tokens
     use_template: max_tokens
     required: true
-    default: 192000
+    default: 8000
     min: 1
-    max: 8000
+    max: 192000
   - name: presence_penalty
     use_template: presence_penalty
   - name: frequency_penalty
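
Note: the sketch below is not part of either patch; it is a minimal, self-contained illustration of the guard PATCH 1/2 adds to _generate. The Message dataclass and strip_empty_system_prompt helper are hypothetical stand-ins for Dify's PromptMessage handling, shown only so the new behaviour can be exercised in isolation: an empty prompt list is rejected, and a leading system message with no content is dropped before the prompt is sent to the ZhipuAI client.

from dataclasses import dataclass


@dataclass
class Message:
    # Hypothetical stand-in for PromptMessage: role is the plain string value
    # ('system', 'user', 'assistant'); content is the message text.
    role: str
    content: str


def strip_empty_system_prompt(prompt_messages: list[Message]) -> list[Message]:
    # Mirror of the check PATCH 1/2 inserts before params['prompt'] is built.
    if len(prompt_messages) == 0:
        raise ValueError('At least one message is required')

    # Drop a leading system message that carries no content.
    if prompt_messages[0].role == 'system' and not prompt_messages[0].content:
        return prompt_messages[1:]

    return prompt_messages


if __name__ == '__main__':
    messages = [
        Message(role='system', content=''),
        Message(role='user', content='Hello'),
    ]
    # The empty system message is removed; only the user message remains.
    print([m.role for m in strip_empty_system_prompt(messages)])  # ['user']

The rest of PATCH 1/2 simply threads credentials (credentials_kwargs) and prompt_messages through to _calc_response_usage, LLMResult and LLMResultChunk, which already require those arguments elsewhere in the model runtime.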