Skip to content

Commit

Permalink
Merge branch 'fix/run-start' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
zxhlyh committed Oct 19, 2023
2 parents 5378e40 + 211b379 commit 85fbd5c
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 88 deletions.
19 changes: 14 additions & 5 deletions api/core/model_providers/models/llm/wenxin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin


class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT

def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
Expand All @@ -38,7 +37,13 @@ def _run(self, messages: List[PromptMessage],
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)

generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}

if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']

return self._client.generate(**generate_kwargs)

def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
Expand All @@ -48,7 +53,7 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)

def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
Expand All @@ -58,3 +63,7 @@ def _set_model_kwargs(self, model_kwargs: ModelKwargs):

def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")

@property
def support_streaming(self):
return True
1 change: 1 addition & 0 deletions api/core/model_providers/models/llm/zhipuai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class ZhipuAIModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
Expand Down
18 changes: 13 additions & 5 deletions api/core/model_providers/providers/wenxin_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
Expand All @@ -23,20 +25,25 @@ def provider_name(self):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
}
]
else:
Expand Down Expand Up @@ -68,11 +75,12 @@ def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> M
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}

if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
Expand Down Expand Up @@ -111,7 +119,7 @@ def is_provider_credentials_valid_or_raise(cls, credentials: dict):
**credential_kwargs
)

llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

Expand Down
6 changes: 6 additions & 0 deletions api/core/model_providers/rules/wenxin.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
Expand Down
Loading

0 comments on commit 85fbd5c

Please sign in to comment.