
Commit

Merge branch 'feat/model-runtime' into deploy/dev
GarfieldDai committed Jan 1, 2024
2 parents 0e2c50c + f21461c commit 37e2f89
Showing 5 changed files with 109 additions and 82 deletions.
14 changes: 14 additions & 0 deletions api/core/model_runtime/model_providers/huggingface_hub/_common.py
@@ -0,0 +1,14 @@
from huggingface_hub.utils import HfHubHTTPError

from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError


class _CommonHuggingfaceHub:

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
HfHubHTTPError
]
}
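
The new _CommonHuggingfaceHub mixin centralizes the HfHubHTTPError-to-InvokeBadRequestError mapping that the LLM and text-embedding models below previously declared separately. As a minimal sketch of how such a mapping can be consumed, assuming the base model resolves raised SDK exceptions against _invoke_error_mapping; the _wrap_invoke_error helper here is illustrative, not the actual base-class code:

from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub


def _wrap_invoke_error(model: _CommonHuggingfaceHub, error: Exception) -> Exception:
    # Walk the mapping and wrap the raw SDK error in the first matching InvokeError subclass.
    for invoke_error_type, sdk_error_types in model._invoke_error_mapping.items():
        if isinstance(error, tuple(sdk_error_types)):
            return invoke_error_type(str(error))
    # Anything unmapped falls back to the generic InvokeError.
    return InvokeError(str(error))
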
110 changes: 61 additions & 49 deletions api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -2,20 +2,20 @@

from huggingface_hub import InferenceClient
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils import HfHubHTTPError

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMMode
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage, \
UserPromptMessage, SystemPromptMessage
from core.model_runtime.entities.model_entities import ParameterRule, DefaultParameterName, AIModelEntity, ModelType, \
FetchFrom
from core.model_runtime.errors.invoke import InvokeError, InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub


class HuggingfaceHubLargeLanguageModel(LargeLanguageModel):
class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
@@ -34,40 +34,14 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes
**model_parameters)

if stream:
return self._handle_generate_stream_response(model, prompt_messages, response)
return self._handle_generate_stream_response(model, credentials, prompt_messages, response)

return self._handle_generate_response(model, prompt_messages, response)

@staticmethod
def _get_llm_usage():
usage = LLMUsage(
prompt_tokens=0,
prompt_unit_price=0,
prompt_price_unit=0,
prompt_price=0,
completion_tokens=0,
completion_unit_price=0,
completion_price_unit=0,
completion_price=0,
total_tokens=0,
total_price=0,
currency='USD',
latency=0,
)
return usage
return self._handle_generate_response(model, credentials, prompt_messages, response)

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
return 0
prompt = self._convert_messages_to_prompt(prompt_messages)
return self._get_num_tokens_by_gpt2(prompt)

def validate_credentials(self, model: str, credentials: dict) -> None:
try:
@@ -106,14 +80,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
HfHubHTTPError
]
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
entity = AIModelEntity(
model=model,
@@ -163,35 +129,54 @@ def _get_customizable_model_parameter_rules() -> list[ParameterRule]:

return [temperature_rule, top_k_rule, top_p_rule]

def _handle_generate_stream_response(self, model: str,
def _handle_generate_stream_response(self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
response: Generator) -> Generator:
for chunk in response:
# skip special tokens
if chunk.token.special:
continue

assistant_prompt_message = AssistantPromptMessage(
content=chunk.token.text
)

prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=chunk.token.id,
message=AssistantPromptMessage(content=chunk.token.text),
usage=self._get_llm_usage(),
message=assistant_prompt_message,
usage=usage,
),
)

def _handle_generate_response(self, model: str, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
if isinstance(response, str):
content = response
else:
content = response.generated_text

usage = self._get_llm_usage()
assistant_prompt_message = AssistantPromptMessage(
content=content
)

prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])

usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=content),
message=assistant_prompt_message,
usage=usage,
)
return result
@@ -216,3 +201,30 @@ def _get_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str):
raise CredentialsValidateFailedError(f"{str(e)}")

return model_info.pipeline_tag

def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
messages = messages.copy() # don't mutate the original list

text = "".join(
self._convert_one_message_to_text(message)
for message in messages
)

return text.rstrip()

@staticmethod
def _convert_one_message_to_text(message: PromptMessage) -> str:
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
content = message.content

if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
else:
raise ValueError(f"Got unknown type {message}")

return message_text
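
Taken together, the llm.py changes replace the old zero-filled LLMUsage stub with real accounting: prompt messages are flattened through _convert_messages_to_prompt, counted with the base class's _get_num_tokens_by_gpt2, and priced via _calc_response_usage. A short usage sketch of the new conversion helper, assuming the model class can be instantiated directly as the integration tests do; the sample messages and model name are illustrative only:

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage, SystemPromptMessage, UserPromptMessage,
)
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel

llm = HuggingfaceHubLargeLanguageModel()
prompt = llm._convert_messages_to_prompt([
    SystemPromptMessage(content='You are a helpful assistant.'),
    UserPromptMessage(content='Who are you?'),
    AssistantPromptMessage(content='I am a model hosted on Hugging Face Hub.'),
])
# System content passes through unprefixed, user/assistant turns get the
# "\n\nHuman:" / "\n\nAssistant:" prefixes, and the joined text is right-stripped:
# "You are a helpful assistant.\n\nHuman: Who are you?\n\nAssistant: I am a model hosted on Hugging Face Hub."

# get_num_tokens flattens messages the same way and counts with the GPT-2 tokenizer;
# the credentials dict is not consulted for this count.
num_tokens = llm.get_num_tokens('tiiuae/falcon-7b-instruct', {}, [UserPromptMessage(content='Who are you?')])
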
@@ -1,19 +1,19 @@
import json
import time
from typing import Optional

import numpy as np
from huggingface_hub import InferenceClient, HfApi
from huggingface_hub.utils import HfHubHTTPError

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.invoke import InvokeError, InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub


class HuggingfaceHubTextEmbeddingModel(TextEmbeddingModel):
class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):

def _invoke(self, model: str, credentials: dict, texts: list[str],
user: Optional[str] = None) -> TextEmbeddingResult:
@@ -36,15 +36,8 @@ def _invoke(self, model: str, credentials: dict, texts: list[str],

embeddings = json.loads(output.decode())

usage = EmbeddingUsage(
tokens=0,
total_tokens=0,
unit_price=0.0,
price_unit=0.0,
total_price=0.0,
currency='USD',
latency=0.0
)
tokens = self.get_num_tokens(model, credentials, texts)
usage = self._calc_response_usage(model, credentials, tokens)

return TextEmbeddingResult(
embeddings=self._mean_pooling(embeddings),
@@ -53,15 +46,10 @@ def _invoke(self, model: str, credentials: dict, texts: list[str],
)

def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
return 0
num_tokens = 0
for text in texts:
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens

def validate_credentials(self, model: str, credentials: dict) -> None:
try:
@@ -111,14 +99,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
)
return entity

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeBadRequestError: [
HfHubHTTPError
]
}

# https://huggingface.co/docs/api-inference/detailed_parameters#feature-extraction-task
# Returned values are a list of floats, or a list[list[floats]]
# (depending on if you sent a string or a list of string,
@@ -153,3 +133,24 @@ def _check_hosted_model_task_type(huggingfacehub_api_token: str, model_name: str
f"must be one of {valid_tasks}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{str(e)}")

def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens
)

# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at
)

return usage
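
With these changes the embedding model reports usage from actual input size: get_num_tokens sums a GPT-2 token count per input text, and _calc_response_usage prices that total through get_price with PriceType.INPUT. A minimal sketch of the counting step, assuming _get_num_tokens_by_gpt2 behaves like encoding with a standard GPT-2 tokenizer (the transformers usage below is an assumption about that helper, not its actual implementation):

from transformers import GPT2TokenizerFast

# Assumed stand-in for the base class's _get_num_tokens_by_gpt2: count tokens
# against the GPT-2 vocabulary.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')


def get_num_tokens_by_gpt2(text: str) -> int:
    return len(tokenizer.encode(text))


texts = ['hello', 'world']
num_tokens = sum(get_num_tokens_by_gpt2(text) for text in texts)
# Inputs this short typically encode to a single token each, giving 2 here;
# the updated assertions in the tests below reflect the same GPT-2-based counting.
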
@@ -308,4 +308,4 @@ def test_get_num_tokens():
]
)

assert num_tokens == 0
assert num_tokens == 7
@@ -46,7 +46,7 @@ def test_hosted_inference_api_invoke_model():

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 0
assert result.usage.total_tokens == 2


def test_inference_endpoints_validate_credentials():
@@ -117,4 +117,4 @@ def test_get_num_tokens():
]
)

assert num_tokens == 0
assert num_tokens == 2
