Skip to content

Commit

Permalink
feat: fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost committed Oct 18, 2023
1 parent 905c59c commit ae74798
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
4 changes: 3 additions & 1 deletion api/core/model_providers/providers/wenxin_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
Expand Down Expand Up @@ -117,7 +119,7 @@ def is_provider_credentials_valid_or_raise(cls, credentials: dict):
**credential_kwargs
)

llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

Expand Down
5 changes: 2 additions & 3 deletions api/tests/integration_tests/models/llm/test_wenxin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

model = get_mock_model('ernie-bot')
messages = [PromptMessage(content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
messages = [PromptMessage(type=MessageType.USER, content='Human: 1 + 1=? \nAssistant: Integer answer is:')]
rst = model.run(
messages,
stop=['\nHuman:'],
messages
)
assert len(rst.content) > 0
5 changes: 4 additions & 1 deletion api/tests/unit_tests/model_providers/test_wenxin_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from unittest.mock import patch
import json

from langchain.schema import AIMessage, ChatGeneration, ChatResult

from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.wenxin_provider import WenxinProvider
from models.provider import ProviderType, Provider
Expand All @@ -24,7 +26,8 @@ def decrypt_side_effect(tenant_id, encrypted_key):


def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))

MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)

Expand Down

0 comments on commit ae74798

Please sign in to comment.