From ee3579aeb849ed029cdc13101aed5fb3aa353384 Mon Sep 17 00:00:00 2001
From: 刘江波
Date: Tue, 24 Dec 2024 14:26:59 +0800
Subject: [PATCH] fix: o1 model error, use max_completion_tokens instead of max_tokens.

---
 .../model_providers/azure_openai/llm/llm.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
index c5d7a83a4ee69f..03818741f65875 100644
--- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py
@@ -113,7 +113,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         try:
             client = AzureOpenAI(**self._to_credential_kwargs(credentials))
 
-            if "o1" in model:
+            if model.startswith("o1"):
                 client.chat.completions.create(
                     messages=[{"role": "user", "content": "ping"}],
                     model=model,
@@ -311,7 +311,10 @@ def _chat_generate(
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)
 
         block_as_stream = False
-        if "o1" in model:
+        if model.startswith("o1"):
+            if "max_tokens" in model_parameters:
+                model_parameters["max_completion_tokens"] = model_parameters["max_tokens"]
+                del model_parameters["max_tokens"]
             if stream:
                 block_as_stream = True
                 stream = False
@@ -404,7 +407,7 @@ def _clear_illegal_prompt_messages(self, model: str, prompt_messages: list[Promp
                 ]
             )
 
-        if "o1" in model:
+        if model.startswith("o1"):
             system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
             if system_message_count > 0:
                 new_prompt_messages = []
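
Note: a minimal standalone sketch of the parameter remap this patch adds to
_chat_generate, assuming a hypothetical deployment name and parameter values
(only the startswith check and the rename mirror the patch itself). o1-series
models on Azure OpenAI reject the legacy max_tokens parameter, so any
caller-supplied max_tokens must be renamed to max_completion_tokens before the
request is built:

    # Hypothetical inputs; the remap logic matches the hunk above.
    model = "o1-mini"
    model_parameters = {"temperature": 1.0, "max_tokens": 512}

    if model.startswith("o1"):
        if "max_tokens" in model_parameters:
            # o1 models accept only max_completion_tokens.
            model_parameters["max_completion_tokens"] = model_parameters.pop("max_tokens")

    print(model_parameters)
    # {'temperature': 1.0, 'max_completion_tokens': 512}

The switch from `"o1" in model` to `model.startswith("o1")` also narrows the
check: a substring match would trigger on any deployment name that merely
contains "o1", while a prefix match applies the o1-specific handling only to
actual o1-family deployments.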