feat: support ernie-bot-4 and chat mode
takatost committed Oct 18, 2023
1 parent c039f4a commit 905c59c
Showing 4 changed files with 165 additions and 72 deletions.
api/core/model_providers/models/llm/wenxin_model.py (19 changes: 14 additions & 5 deletions)
@@ -6,17 +6,16 @@

from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin


class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT

def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
@@ -38,7 +37,13 @@ def _run(self, messages: List[PromptMessage],
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)

generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}

if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']

return self._client.generate(**generate_kwargs)

def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -48,7 +53,7 @@ def get_num_tokens(self, messages: List[PromptMessage]) -> int:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)

def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ def _set_model_kwargs(self, model_kwargs: ModelKwargs):

def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")

@property
def support_streaming(self):
return True
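
Note: the reworked `_run` goes through the chat interface and only forwards a `functions` argument when the caller actually supplies one. Below is a minimal, self-contained sketch of that forwarding pattern; `FakeChatClient` and `run` are hypothetical stand-ins for illustration, not part of the codebase:

```python
from typing import Any, Dict, List, Optional


class FakeChatClient:
    """Hypothetical stand-in for the Wenxin chat wrapper."""

    def generate(self, messages: List[list], stop: Optional[List[str]] = None,
                 callbacks: Any = None, functions: Optional[List[Dict]] = None) -> dict:
        # A real client would call the ERNIE Bot API here.
        return {"messages": messages, "stop": stop, "functions": functions}


def run(client: FakeChatClient, prompts: list, stop=None, callbacks=None, **kwargs) -> dict:
    generate_kwargs = {"stop": stop, "callbacks": callbacks, "messages": [prompts]}
    # Forward `functions` only when the caller provided it, mirroring the change above.
    if "functions" in kwargs:
        generate_kwargs["functions"] = kwargs["functions"]
    return client.generate(**generate_kwargs)


print(run(FakeChatClient(), prompts=["hello"]))
print(run(FakeChatClient(), prompts=["hello"], functions=[{"name": "get_weather"}]))
```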
api/core/model_providers/providers/wenxin_provider.py (14 changes: 10 additions & 4 deletions)
@@ -23,20 +23,25 @@ def provider_name(self):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
'mode': ModelMode.CHAT.value,
}
]
else:
@@ -68,11 +73,12 @@ def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> M
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}

if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
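
For illustration, the sketch below shows how per-model parameter rules like the ones added above could be applied to user-supplied settings before a request is built. The `Rule` dataclass and `clamp` helper are hypothetical stand-ins for the project's `KwargRule`/`ModelKwargsRules`, not the actual implementation:

```python
from dataclasses import dataclass


@dataclass
class Rule:
    """Hypothetical stand-in for KwargRule: an allowed range plus a default."""
    min: float
    max: float
    default: float


# Mirrors the temperature/top_p ranges from the diff above.
ERNIE_BOT_4_RULES = {
    "temperature": Rule(min=0.01, max=1.0, default=0.95),
    "top_p": Rule(min=0.01, max=1.0, default=0.8),
}


def clamp(rules: dict, user_kwargs: dict) -> dict:
    """Fill in defaults and clamp user values into each rule's allowed range."""
    return {
        name: min(max(user_kwargs.get(name, rule.default), rule.min), rule.max)
        for name, rule in rules.items()
    }


print(clamp(ERNIE_BOT_4_RULES, {"temperature": 1.7}))  # {'temperature': 1.0, 'top_p': 0.8}
```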
api/core/model_providers/rules/wenxin.json (6 changes: 6 additions & 0 deletions)
@@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",
api/core/third_party/langchain/llms/wenxin.py (198 changes: 135 additions & 63 deletions)
@@ -8,12 +8,15 @@
Any,
Dict,
List,
Optional, Iterator,
Optional, Iterator, Tuple,
)

import requests
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema.output import GenerationChunk
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator

from langchain.callbacks.manager import (
@@ -61,6 +64,7 @@ def post(self, request: dict) -> Any:
raise ValueError(f"Wenxin Model name is required")

model_url_map = {
'ernie-bot-4': 'completions_pro',
'ernie-bot': 'completions',
'ernie-bot-turbo': 'eb-instant',
'bloomz-7b': 'bloomz_7b1',
@@ -70,6 +74,7 @@

access_token = self.get_access_token()
api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}"
del request['model']

headers = {"Content-Type": "application/json"}
response = requests.post(api_url,
@@ -86,22 +91,21 @@
f"Wenxin API {json_response['error_code']}"
f" error: {json_response['error_msg']}"
)
return json_response["result"]
return json_response
else:
return response


class Wenxin(LLM):
"""Wrapper around Wenxin large language models.
To use, you should have the environment variable
``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key,
or pass them as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.wenxin import Wenxin
wenxin = Wenxin(model="<model_name>", api_key="my-api-key",
secret_key="my-group-id")
"""
class Wenxin(BaseChatModel):
"""Wrapper around Wenxin large language models."""

@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

@property
def lc_serializable(self) -> bool:
return True

_client: _WenxinEndpointClient = PrivateAttr()
model: str = "ernie-bot"
@@ -161,64 +165,89 @@ def __init__(self, **data: Any):
secret_key=self.secret_key,
)

def _call(
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict

def _create_message_dicts(
self, messages: List[BaseMessage]
) -> Tuple[List[Dict[str, Any]], str]:
dict_messages = []
system = None
for m in messages:
message = self._convert_message_to_dict(m)
if message['role'] == 'system':
if not system:
system = message['content']
else:
system += f"\n{message['content']}"
continue

if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)

return dict_messages, system

def _generate(
self,
prompt: str,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Wenxin's completion endpoint to chat
Args:
prompt: The prompt to pass into the model.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = wenxin("Tell me a joke.")
"""
) -> ChatResult:
if self.streaming:
completion = ""
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}

if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)
completion = self._client.post(request)

if stop is not None:
completion = enforce_stop_tokens(completion, stop)

return completion
response = self._client.post(request)
return self._create_chat_result(response)

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call wenxin completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Wenxin.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = wenxin.stream(prompt)
for token in generator:
yield token
"""
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, system = self._create_message_dicts(messages)
request = self._default_params
request["messages"] = [{"role": "user", "content": prompt}]
request["messages"] = message_dicts
if system:
request["system"] = system
request.update(kwargs)

for token in self._client.post(request).iter_lines():
@@ -228,12 +257,18 @@ def _stream(
if token.startswith('data:'):
completion = json.loads(token[5:])

yield GenerationChunk(text=completion['result'])
if run_manager:
run_manager.on_llm_new_token(completion['result'])
chunk_dict = {
'message': AIMessageChunk(content=completion['result']),
}

if completion['is_end']:
break
token_usage = completion['usage']
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']
chunk_dict['generation_info'] = dict({'token_usage': token_usage})

yield ChatGenerationChunk(**chunk_dict)
if run_manager:
run_manager.on_llm_new_token(completion['result'])
else:
try:
json_response = json.loads(token)
@@ -245,3 +280,40 @@
f" error: {json_response['error_msg']}, "
f"please confirm if the model you have chosen is already paid for."
)

def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
generations = [ChatGeneration(
message=AIMessage(content=response['result']),
)]
token_usage = response.get("usage")
token_usage['completion_tokens'] = token_usage['total_tokens'] - token_usage['prompt_tokens']

llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])

def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
for k, v in token_usage.items():
if k in overall_token_usage:
overall_token_usage[k] += v
else:
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model}
