feat: support weixin ernie-bot-4 and chat mode #1375

Merged · 2 commits · Oct 18, 2023
19 changes: 14 additions & 5 deletions api/core/model_providers/models/llm/wenxin_model.py
@@ -6,17 +6,16 @@

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin


class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT

def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
@@ -38,7 +37,13 @@ def _run(self, messages: List[PromptMessage],
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)

generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}

if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']

return self._client.generate(**generate_kwargs)

def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -48,7 +53,7 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)

def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ def _set_model_kwargs(self, model_kwargs: ModelKwargs):

def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")

@property
def support_streaming(self):
return True
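
Note on the wenxin_model.py change: `WenxinModel` now runs in chat mode, `_run` assembles the keyword arguments for the chat client and forwards `functions` only when the caller supplies it, and token counting switches to the chat-oriented `get_num_tokens_from_messages`. Below is a minimal sketch of that optional-kwarg forwarding pattern; `EchoClient` and `run` are illustrative stand-ins, not the real Dify classes.

```python
# Sketch of the kwargs-forwarding pattern used by the new WenxinModel._run.
# EchoClient is a placeholder; the real code calls the Wenxin chat wrapper.
from typing import Any, List, Optional


class EchoClient:
    def generate(self, messages: List[Any], stop: Optional[List[str]] = None,
                 callbacks: Any = None, functions: Optional[List[dict]] = None) -> dict:
        # A real client would hit the Wenxin API; here we just report what we received.
        return {"messages": messages, "stop": stop, "functions": functions}


def run(client: EchoClient, prompts: Any, stop: Optional[List[str]] = None,
        callbacks: Any = None, **kwargs: Any) -> dict:
    generate_kwargs = {"stop": stop, "callbacks": callbacks, "messages": [prompts]}
    # Only pass `functions` through when the caller supplied it, mirroring _run.
    if "functions" in kwargs:
        generate_kwargs["functions"] = kwargs["functions"]
    return client.generate(**generate_kwargs)


print(run(EchoClient(), prompts="ping"))
print(run(EchoClient(), prompts="ping", functions=[{"name": "lookup_weather"}]))
```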
18 changes: 13 additions & 5 deletions api/core/model_providers/providers/wenxin_provider.py
@@ -2,6 +2,8 @@
from json import JSONDecodeError
from typing import Type

from langchain.schema import HumanMessage

from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
@@ -23,20 +25,25 @@ def provider_name(self):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
}
]
else:
@@ -68,11 +75,12 @@ def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> M
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}

if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -111,7 +119,7 @@ def is_provider_credentials_valid_or_raise(cls, credentials: dict):
**credential_kwargs
)

llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

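
Note on the wenxin_provider.py change: because `Wenxin` is now a `BaseChatModel`, the credential check can no longer call the model with a bare string and instead sends a single `HumanMessage`. The sketch below mirrors that flow under the assumption that the credentials arrive as `api_key`/`secret_key`; the real `credential_kwargs` and the error class's import path are not shown in this diff, so a local stand-in exception is used.

```python
# Hedged sketch of the updated credential validation: a chat model is invoked
# with a list of messages rather than llm("ping").
from langchain.schema import HumanMessage

from core.third_party.langchain.llms.wenxin import Wenxin


class CredentialsValidateFailedError(Exception):
    """Stand-in for Dify's own exception; the real import path is not part of this diff."""


def validate_wenxin_credentials(api_key: str, secret_key: str) -> None:
    try:
        llm = Wenxin(model="ernie-bot", api_key=api_key, secret_key=secret_key)
        # One cheap round-trip is enough to prove the key pair works.
        llm([HumanMessage(content="ping")])
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))
```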
6 changes: 6 additions & 0 deletions api/core/model_providers/rules/wenxin.json
@@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
198 changes: 135 additions & 63 deletions api/core/third_party/langchain/llms/wenxin.py
@@ -8,12 +8,15 @@
Any,
Dict,
List,
Optional, Iterator,
Optional, Iterator, Tuple,
)

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.output import GenerationChunk
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator

from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ def post(self, request: dict) -> Any:
raise ValueError(f"Wenxin Model name is required")

model_url_map = {
'ernie-bot-4': 'completions_pro',
'ernie-bot': 'completions',
'ernie-bot-turbo': 'eb-instant',
'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@

access_token = self.get_access_token()
api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
del request['model']

headers = {"Content-Type": "application/json"}
response = requests.post(api_url,
@@ -86,22 +91,21 @@
f"Wenxin API {json_response['error_code']}"
f" error: {json_response['error_msg']}"
)
return json_response["result"]
return json_response
else:
return response


class Wenxin(LLM):
"""Wrapper around Wenxin large language models.
To use, you should have the environment variable
``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
or pass them as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.wenxin import Wenxin
wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
secret_key="my-group-id")
"""
class Wenxin(BaseChatModel):
"""Wrapper around Wenxin large language models."""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

@property
def lc_serializable(self) -> bool:
return True

_client: _WenxinEndpointClient = PrivateAttr()
model: str = "ernie-bot"
@@ -161,64 +165,89 @@ def __init__(self, **data: Any):
secret_key=self.secret_key,
)

def _call(
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict

def _create_message_dicts(
self, messages: List[BaseMessage]
) -> Tuple[List[Dict[str, Any]], str]:
dict_messages = []
system = None
for m in messages:
message = self._convert_message_to_dict(m)
if message['role'] == 'system':
if not system:
system = message['content']
else:
system += f"\n{message['content']}"
continue

if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)

return dict_messages, system

def _generate(
self,
prompt: str,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Wenxin's completion endpoint to chat
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = wenxin("Tell me a joke.")
"""
) -> ChatResult:
if self.streaming:
completion = ""
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}

if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)
completion = self._client.post(request)

if stop is not None:
completion = enforce_stop_tokens(completion, stop)

return completion
response = self._client.post(request)
return self._create_chat_result(response)

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call wenxin completion_stream and return the resulting generator.

Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Wenxin.
Example:
.. code-block:: python

prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = wenxin.stream(prompt)
for token in generator:
yield token
"""
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)

for token in self._client.post(request).iter_lines():
@@ -228,12 +257,18 @@ def _stream(
if token.startswith('data:'):
completion = json.loads(token[5:])

yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
chunk_dict = {
'message': AIMessageChunk(content=completion['result']),
}

if completion['is_end']:
break
token_usage = completion['usage']
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
chunk_dict['generation_info'] = dict({'token_usage': token_usage})

yield ChatGenerationChunk(**chunk_dict)
if run_manager:
run_manager.on_llm_new_token(completion['result'])
else:
try:
json_response = json.loads(token)
@@ -245,3 +280,40 @@
f" error: {json_response['error_msg']}, "
f"please confirm if the model you have chosen is already paid for."
)

def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
generations = [ChatGeneration(
message=AIMessage(content=response['result']),
)]
token_usage = response.get("usage")
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']

llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.

Useful for checking if an input will fit in a model's context window.

Args:
messages: The message inputs to tokenize.

Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}
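
Note on streaming: `_stream` now yields `ChatGenerationChunk`s, with token usage attached to the final chunk's `generation_info`, and `_generate` folds the chunks together with `+=`. The following is a rough illustration of consuming such a stream; the chunk contents are hand-built stand-ins, and it assumes the `ChatGenerationChunk` addition semantics of the LangChain version this PR targets.

```python
# Illustrative consumption of the ChatGenerationChunk stream produced by the
# rewritten Wenxin._stream. The chunks are hand-built stand-ins for API output.
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatGenerationChunk

chunks = [
    ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
    ChatGenerationChunk(
        message=AIMessageChunk(content=", world"),
        # The final chunk carries usage; completion_tokens is derived as
        # total_tokens - prompt_tokens, as in the PR.
        generation_info={"token_usage": {"prompt_tokens": 3,
                                         "total_tokens": 8,
                                         "completion_tokens": 5}},
    ),
]

generation = None
token_usage = {}
for chunk in chunks:
    # Same accumulation pattern as _generate: concatenate chunks, remember the last usage seen.
    generation = chunk if generation is None else generation + chunk
    if chunk.generation_info and "token_usage" in chunk.generation_info:
        token_usage = chunk.generation_info["token_usage"]

print(generation.text)  # "Hello, world"
print(token_usage)      # {'prompt_tokens': 3, 'total_tokens': 8, 'completion_tokens': 5}
```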