From 89c51e030d4eb2d81dba7f2977dad2fd0edefb06 Mon Sep 17 00:00:00 2001 From: StyleZhang Date: Tue, 2 Jan 2024 17:27:43 +0800 Subject: [PATCH] fix: spark tongyi zhipu --- .../model_providers/spark/llm/llm.py | 6 ++--- .../model_providers/spark/spark.yaml | 8 +++--- .../model_providers/tongyi/llm/llm.py | 3 +++ .../model_providers/zhipuai/_common.py | 2 +- .../model_providers/zhipuai/zhipuai.yaml | 2 +- .../model_runtime/spark/test_llm.py | 26 +++++++++---------- .../model_runtime/spark/test_provider.py | 6 ++--- .../model_runtime/tongyi/test_llm.py | 10 +++---- .../model_runtime/tongyi/test_provider.py | 2 +- .../model_runtime/zhipuai/test_llm.py | 10 +++---- .../model_runtime/zhipuai/test_provider.py | 2 +- .../zhipuai/test_text_embedding.py | 8 +++--- 12 files changed, 44 insertions(+), 41 deletions(-) diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 6192af45b2a9c0..56f13cdbc33217 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -205,9 +205,9 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "app_id": credentials['spark_app_id'], - "api_secret": credentials['spark_api_secret'], - "api_key": credentials['spark_api_key'], + "app_id": credentials['app_id'], + "api_secret": credentials['api_secret'], + "api_key": credentials['api_key'], } return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/spark/spark.yaml b/api/core/model_runtime/model_providers/spark/spark.yaml index 1b6ee7c36193f3..a4f497b9f7d510 100644 --- a/api/core/model_runtime/model_providers/spark/spark.yaml +++ b/api/core/model_runtime/model_providers/spark/spark.yaml @@ -20,7 +20,7 @@ configurate_methods: - predefined-model provider_credential_schema: credential_form_schemas: - - variable: spark_app_id + - variable: app_id label: en_US: APPID type: text-input @@ -28,15 +28,15 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 APPID en_US: Enter your APPID - - variable: spark_api_secret + - variable: api_secret label: en_US: APISecret - type: text-input + type: secret-input required: true placeholder: zh_Hans: 在此输入您的 APISecret en_US: Enter your APISecret - - variable: spark_api_key + - variable: api_key label: en_US: APIKey type: secret-input diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index abd5471075bb44..cce740b53b7378 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -110,6 +110,8 @@ def _generate(self, model: str, credentials: dict, dashscope.api_key = credentials_kwargs['api_key'] + print(credentials_kwargs, 'credentials_kwargs') + client = EnhanceTongyi( model_name=model, streaming=stream, @@ -219,6 +221,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :param credentials: :return: """ + print(credentials, 'credentials') credentials_kwargs = { "api_key": credentials['dashscope_api_key'], } diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 9f18cd42f61922..19bf82b0c7de9d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -11,7 +11,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['zhipuai_api_key'], + "api_key": credentials['api_key'], } return credentials_kwargs diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai.yaml b/api/core/model_runtime/model_providers/zhipuai/zhipuai.yaml index 33a59bb97be7df..cbee3717a329ed 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai.yaml @@ -21,7 +21,7 @@ configurate_methods: - predefined-model provider_credential_schema: credential_form_schemas: - - variable: zhipuai_api_key + - variable: api_key label: en_US: APIKey type: text-input diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 3b0a8bd631f75b..2e3d775000eb7a 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -17,16 +17,16 @@ def test_validate_credentials(): model.validate_credentials( model='spark-1.5', credentials={ - 'spark_app_id': 'invalid_key' + 'app_id': 'invalid_key' } ) model.validate_credentials( model='spark-1.5', credentials={ - 'spark_app_id': os.environ.get('SPARK_APP_ID'), - 'spark_api_secret': os.environ.get('SPARK_API_SECRET'), - 'spark_api_key': os.environ.get('SPARK_API_KEY') + 'app_id': os.environ.get('SPARK_APP_ID'), + 'api_secret': os.environ.get('SPARK_API_SECRET'), + 'api_key': os.environ.get('SPARK_API_KEY') } ) @@ -37,9 +37,9 @@ def test_invoke_model(): response = model.invoke( model='spark-1.5', credentials={ - 'spark_app_id': os.environ.get('SPARK_APP_ID'), - 'spark_api_secret': os.environ.get('SPARK_API_SECRET'), - 'spark_api_key': os.environ.get('SPARK_API_KEY') + 'app_id': os.environ.get('SPARK_APP_ID'), + 'api_secret': os.environ.get('SPARK_API_SECRET'), + 'api_key': os.environ.get('SPARK_API_KEY') }, prompt_messages=[ UserPromptMessage( @@ -65,9 +65,9 @@ def test_invoke_stream_model(): response = model.invoke( model='spark-1.5', credentials={ - 'spark_app_id': os.environ.get('SPARK_APP_ID'), - 'spark_api_secret': os.environ.get('SPARK_API_SECRET'), - 'spark_api_key': os.environ.get('SPARK_API_KEY') + 'app_id': os.environ.get('SPARK_APP_ID'), + 'api_secret': os.environ.get('SPARK_API_SECRET'), + 'api_key': os.environ.get('SPARK_API_KEY') }, prompt_messages=[ UserPromptMessage( @@ -97,9 +97,9 @@ def test_get_num_tokens(): num_tokens = model.get_num_tokens( model='spark-1.5', credentials={ - 'spark_app_id': os.environ.get('SPARK_APP_ID'), - 'spark_api_secret': os.environ.get('SPARK_API_SECRET'), - 'spark_api_key': os.environ.get('SPARK_API_KEY') + 'app_id': os.environ.get('SPARK_APP_ID'), + 'api_secret': os.environ.get('SPARK_API_SECRET'), + 'api_key': os.environ.get('SPARK_API_KEY') }, prompt_messages=[ SystemPromptMessage( diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 7960713fb93f47..8e22815a86fc84 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -16,8 +16,8 @@ def test_validate_provider_credentials(): provider.validate_provider_credentials( credentials={ - 'spark_app_id': os.environ.get('SPARK_APP_ID'), - 'spark_api_secret': os.environ.get('SPARK_API_SECRET'), - 'spark_api_key': os.environ.get('SPARK_API_KEY') + 'app_id': os.environ.get('SPARK_APP_ID'), + 'api_secret': os.environ.get('SPARK_API_SECRET'), + 'api_key': os.environ.get('SPARK_API_KEY') } ) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index d7aa0b70c16220..65e57f700131e0 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -17,14 +17,14 @@ def test_validate_credentials(): model.validate_credentials( model='qwen-turbo', credentials={ - 'tongyi_api_key': 'invalid_key' + 'dashscope_api_key': 'invalid_key' } ) model.validate_credentials( model='qwen-turbo', credentials={ - 'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') + 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') } ) @@ -35,7 +35,7 @@ def test_invoke_model(): response = model.invoke( model='qwen-turbo', credentials={ - 'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') + 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') }, prompt_messages=[ UserPromptMessage( @@ -61,7 +61,7 @@ def test_invoke_stream_model(): response = model.invoke( model='qwen-turbo', credentials={ - 'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') + 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') }, prompt_messages=[ UserPromptMessage( @@ -92,7 +92,7 @@ def test_get_num_tokens(): num_tokens = model.get_num_tokens( model='qwen-turbo', credentials={ - 'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') + 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') }, prompt_messages=[ SystemPromptMessage( diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 918e5fd8766e12..6145c1dc37d00b 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -16,6 +16,6 @@ def test_validate_provider_credentials(): provider.validate_provider_credentials( credentials={ - 'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') + 'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY') } ) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 8f1f7c7a5d03d1..adcaa51b35e519 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -17,14 +17,14 @@ def test_validate_credentials(): model.validate_credentials( model='chatglm_turbo', credentials={ - 'zhipuai_api_key': 'invalid_key' + 'api_key': 'invalid_key' } ) model.validate_credentials( model='chatglm_turbo', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') } ) @@ -35,7 +35,7 @@ def test_invoke_model(): response = model.invoke( model='chatglm_turbo', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') }, prompt_messages=[ UserPromptMessage( @@ -61,7 +61,7 @@ def test_invoke_stream_model(): response = model.invoke( model='chatglm_turbo', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') }, prompt_messages=[ UserPromptMessage( @@ -91,7 +91,7 @@ def test_get_num_tokens(): num_tokens = model.get_num_tokens( model='chatglm_turbo', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') }, prompt_messages=[ SystemPromptMessage( diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index f29fab85f558b0..032e15e846af55 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -15,6 +15,6 @@ def test_validate_provider_credentials(): provider.validate_provider_credentials( credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') } ) diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 35dac24ba4a7fb..15a9307a328630 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -14,14 +14,14 @@ def test_validate_credentials(): model.validate_credentials( model='text_embedding', credentials={ - 'zhipuai_api_key': 'invalid_key' + 'api_key': 'invalid_key' } ) model.validate_credentials( model='text_embedding', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') } ) @@ -32,7 +32,7 @@ def test_invoke_model(): result = model.invoke( model='text_embedding', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') }, texts=[ "hello", @@ -52,7 +52,7 @@ def test_get_num_tokens(): num_tokens = model.get_num_tokens( model='text_embedding', credentials={ - 'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY') + 'api_key': os.environ.get('ZHIPUAI_API_KEY') }, texts=[ "hello",