diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index d57e9ca351202b..973eb6c741a0ca 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -40,7 +40,7 @@ class AzureBaseModel(BaseModel): ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.CHAT, + 'mode': LLMMode.CHAT.value, 'context_size': 4096, }, parameter_rules=[ @@ -84,7 +84,7 @@ class AzureBaseModel(BaseModel): ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.CHAT, + 'mode': LLMMode.CHAT.value, 'context_size': 16385, }, parameter_rules=[ @@ -128,7 +128,7 @@ class AzureBaseModel(BaseModel): ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.CHAT, + 'mode': LLMMode.CHAT.value, 'context_size': 8192, }, parameter_rules=[ @@ -202,7 +202,7 @@ class AzureBaseModel(BaseModel): ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.CHAT, + 'mode': LLMMode.CHAT.value, 'context_size': 32768, }, parameter_rules=[ @@ -276,7 +276,7 @@ class AzureBaseModel(BaseModel): ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.CHAT, + 'mode': LLMMode.CHAT.value, 'context_size': 128000, }, parameter_rules=[ @@ -349,7 +349,7 @@ class AzureBaseModel(BaseModel): ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.CHAT, + 'mode': LLMMode.CHAT.value, 'context_size': 128000, }, parameter_rules=[ @@ -419,7 +419,7 @@ class AzureBaseModel(BaseModel): model_type=ModelType.LLM, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - 'mode': LLMMode.COMPLETION, + 'mode': LLMMode.COMPLETION.value, 'context_size': 4096, }, parameter_rules=[ diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 42ac2b56835809..d12e21a6e897a2 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -32,7 +32,7 @@ def _invoke(self, model: str, credentials: dict, ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) - if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: + if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model return self._chat_generate( model=model, @@ -62,7 +62,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr model_mode = self._get_ai_model_entity(credentials['base_model_name'], model).entity.model_properties.get( ModelPropertyKey.MODE) - if model_mode == LLMMode.CHAT: + if model_mode == LLMMode.CHAT.value: # chat model return self._num_tokens_from_messages(credentials, prompt_messages, tools) else: @@ -87,7 +87,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT: + if ai_model_entity.entity.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT.value: # chat model client.chat.completions.create( messages=[{"role": "user", "content": 'ping'}], diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 805a67db19e486..366950ad849784 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -97,7 +97,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ - 'mode': LLMMode.COMPLETION + 'mode': LLMMode.COMPLETION.value }, parameter_rules=self._get_customizable_model_parameter_rules() ) diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 8691833f0fa3c8..556fab977d9bc8 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -91,7 +91,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, model_properties={ - 'mode': model_type + 'mode': model_type.value }, parameter_rules=self._get_customizable_model_parameter_rules(model, credentials) )