Skip to content

Commit

Permalink
Merge branch 'feat/model-runtime' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
zxhlyh committed Jan 2, 2024
2 parents e1305ad + 89c51e0 commit a99b1bd
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 41 deletions.
6 changes: 3 additions & 3 deletions api/core/model_runtime/model_providers/spark/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def _to_credential_kwargs(self, credentials: dict) -> dict:
:return:
"""
credentials_kwargs = {
"app_id": credentials['spark_app_id'],
"api_secret": credentials['spark_api_secret'],
"api_key": credentials['spark_api_key'],
"app_id": credentials['app_id'],
"api_secret": credentials['api_secret'],
"api_key": credentials['api_key'],
}

return credentials_kwargs
Expand Down
8 changes: 4 additions & 4 deletions api/core/model_runtime/model_providers/spark/spark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@ configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: spark_app_id
- variable: app_id
label:
en_US: APPID
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 APPID
en_US: Enter your APPID
- variable: spark_api_secret
- variable: api_secret
label:
en_US: APISecret
type: text-input
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 APISecret
en_US: Enter your APISecret
- variable: spark_api_key
- variable: api_key
label:
en_US: APIKey
type: secret-input
Expand Down
3 changes: 3 additions & 0 deletions api/core/model_runtime/model_providers/tongyi/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def _generate(self, model: str, credentials: dict,

dashscope.api_key = credentials_kwargs['api_key']

print(credentials_kwargs, 'credentials_kwargs')

client = EnhanceTongyi(
model_name=model,
streaming=stream,
Expand Down Expand Up @@ -219,6 +221,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict:
:param credentials:
:return:
"""
print(credentials, 'credentials')
credentials_kwargs = {
"api_key": credentials['dashscope_api_key'],
}
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/zhipuai/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict:
:return:
"""
credentials_kwargs = {
"api_key": credentials['zhipuai_api_key'],
"api_key": credentials['api_key'],
}

return credentials_kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: zhipuai_api_key
- variable: api_key
label:
en_US: APIKey
type: text-input
Expand Down
26 changes: 13 additions & 13 deletions api/tests/integration_tests/model_runtime/spark/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ def test_validate_credentials():
model.validate_credentials(
model='spark-1.5',
credentials={
'spark_app_id': 'invalid_key'
'app_id': 'invalid_key'
}
)

model.validate_credentials(
model='spark-1.5',
credentials={
'spark_app_id': os.environ.get('SPARK_APP_ID'),
'spark_api_secret': os.environ.get('SPARK_API_SECRET'),
'spark_api_key': os.environ.get('SPARK_API_KEY')
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
}
)

Expand All @@ -37,9 +37,9 @@ def test_invoke_model():
response = model.invoke(
model='spark-1.5',
credentials={
'spark_app_id': os.environ.get('SPARK_APP_ID'),
'spark_api_secret': os.environ.get('SPARK_API_SECRET'),
'spark_api_key': os.environ.get('SPARK_API_KEY')
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
},
prompt_messages=[
UserPromptMessage(
Expand All @@ -65,9 +65,9 @@ def test_invoke_stream_model():
response = model.invoke(
model='spark-1.5',
credentials={
'spark_app_id': os.environ.get('SPARK_APP_ID'),
'spark_api_secret': os.environ.get('SPARK_API_SECRET'),
'spark_api_key': os.environ.get('SPARK_API_KEY')
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
},
prompt_messages=[
UserPromptMessage(
Expand Down Expand Up @@ -97,9 +97,9 @@ def test_get_num_tokens():
num_tokens = model.get_num_tokens(
model='spark-1.5',
credentials={
'spark_app_id': os.environ.get('SPARK_APP_ID'),
'spark_api_secret': os.environ.get('SPARK_API_SECRET'),
'spark_api_key': os.environ.get('SPARK_API_KEY')
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def test_validate_provider_credentials():

provider.validate_provider_credentials(
credentials={
'spark_app_id': os.environ.get('SPARK_APP_ID'),
'spark_api_secret': os.environ.get('SPARK_API_SECRET'),
'spark_api_key': os.environ.get('SPARK_API_KEY')
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
}
)
10 changes: 5 additions & 5 deletions api/tests/integration_tests/model_runtime/tongyi/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def test_validate_credentials():
model.validate_credentials(
model='qwen-turbo',
credentials={
'tongyi_api_key': 'invalid_key'
'dashscope_api_key': 'invalid_key'
}
)

model.validate_credentials(
model='qwen-turbo',
credentials={
'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
}
)

Expand All @@ -35,7 +35,7 @@ def test_invoke_model():
response = model.invoke(
model='qwen-turbo',
credentials={
'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
Expand All @@ -61,7 +61,7 @@ def test_invoke_stream_model():
response = model.invoke(
model='qwen-turbo',
credentials={
'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_get_num_tokens():
num_tokens = model.get_num_tokens(
model='qwen-turbo',
credentials={
'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def test_validate_provider_credentials():

provider.validate_provider_credentials(
credentials={
'tongyi_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
}
)
10 changes: 5 additions & 5 deletions api/tests/integration_tests/model_runtime/zhipuai/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ def test_validate_credentials():
model.validate_credentials(
model='chatglm_turbo',
credentials={
'zhipuai_api_key': 'invalid_key'
'api_key': 'invalid_key'
}
)

model.validate_credentials(
model='chatglm_turbo',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)

Expand All @@ -35,7 +35,7 @@ def test_invoke_model():
response = model.invoke(
model='chatglm_turbo',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
UserPromptMessage(
Expand All @@ -61,7 +61,7 @@ def test_invoke_stream_model():
response = model.invoke(
model='chatglm_turbo',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
UserPromptMessage(
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_get_num_tokens():
num_tokens = model.get_num_tokens(
model='chatglm_turbo',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ def test_validate_provider_credentials():

provider.validate_provider_credentials(
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ def test_validate_credentials():
model.validate_credentials(
model='text_embedding',
credentials={
'zhipuai_api_key': 'invalid_key'
'api_key': 'invalid_key'
}
)

model.validate_credentials(
model='text_embedding',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)

Expand All @@ -32,7 +32,7 @@ def test_invoke_model():
result = model.invoke(
model='text_embedding',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
texts=[
"hello",
Expand All @@ -52,7 +52,7 @@ def test_get_num_tokens():
num_tokens = model.get_num_tokens(
model='text_embedding',
credentials={
'zhipuai_api_key': os.environ.get('ZHIPUAI_API_KEY')
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
texts=[
"hello",
Expand Down

0 comments on commit a99b1bd

Please sign in to comment.