From c98d91e44d75cf03f395eee521e5af9a36a45ad8 Mon Sep 17 00:00:00 2001
From: jiangbo721 <365065261@qq.com>
Date: Wed, 25 Dec 2024 13:29:43 +0800
Subject: [PATCH] fix: o1 model error, use max_completion_tokens instead of
 max_tokens. (#12037)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: 刘江波
---
 .../model_providers/azure_openai/llm/llm.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index c5d7a83a4ee69f..03818741f65875 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -113,7 +113,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         try:
             client = AzureOpenAI(**self._to_credential_kwargs(credentials))
 
-            if "o1" in model:
+            if model.startswith("o1"):
                 client.chat.completions.create(
                     messages=[{"role": "user", "content": "ping"}],
                     model=model,
@@ -311,7 +311,10 @@ def _chat_generate(
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         block_as_stream = False
-        if "o1" in model:
+        if model.startswith("o1"):
+            if "max_tokens" in model_parameters:
+                model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
+                del model_parameters["max_tokens"]
             if stream:
                 block_as_stream = True
                 stream = False
@@ -404,7 +407,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp
             ]
         )
 
-        if "o1" in model:
+        if model.startswith("o1"):
             system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
             if system_message_count > 0:
                 new_prompt_messages = []
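
Note: a minimal, standalone sketch of the remapping that the _chat_generate
hunk above applies, assuming (as this patch does) that o1-series chat models
reject `max_tokens` and expect `max_completion_tokens` instead. The function
name and sample values below are illustrative only, not part of the Dify
codebase.

    def remap_o1_parameters(model: str, model_parameters: dict) -> dict:
        # o1-series models take `max_completion_tokens`; remap a user-supplied
        # `max_tokens` before building the chat.completions.create() call.
        if model.startswith("o1") and "max_tokens" in model_parameters:
            model_parameters["max_completion_tokens"] = model_parameters.pop("max_tokens")
        return model_parameters

    # e.g. {"max_tokens": 512} -> {"max_completion_tokens": 512}
    print(remap_o1_parameters("o1-mini", {"max_tokens": 512}))

Using startswith("o1") rather than a substring check also avoids matching
unrelated deployment names that merely contain "o1".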