Merge branch 'feat/model-runtime' into deploy/dev
GarfieldDai committed Jan 2, 2024
2 parents 669fc83 + 7871f9c commit bdd69fd
Showing 3 changed files with 28 additions and 25 deletions.
@@ -1,4 +1,4 @@
-from huggingface_hub.utils import HfHubHTTPError
+from huggingface_hub.utils import HfHubHTTPError, BadRequestError
 
 from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
 
@@ -9,6 +9,7 @@ class _CommonHuggingfaceHub:
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
         return {
             InvokeBadRequestError: [
-                HfHubHTTPError
+                HfHubHTTPError,
+                BadRequestError
             ]
         }
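
The mapping above is how the runtime translates raw huggingface_hub exceptions into its unified invoke-error hierarchy. A minimal sketch of how such a mapping is typically consumed (the two Invoke* classes are stand-ins for core.model_runtime.errors.invoke, and translate_error is a hypothetical helper, not code from this commit):

from huggingface_hub.utils import BadRequestError, HfHubHTTPError

# Stand-ins for core.model_runtime.errors.invoke.{InvokeError, InvokeBadRequestError}.
class InvokeError(Exception):
    pass

class InvokeBadRequestError(InvokeError):
    pass

ERROR_MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeBadRequestError: [HfHubHTTPError, BadRequestError],
}

def translate_error(ex: Exception) -> InvokeError:
    # Return the unified runtime error for any matching SDK exception,
    # so callers only ever catch InvokeError subclasses.
    for invoke_error, sdk_errors in ERROR_MAPPING.items():
        if isinstance(ex, tuple(sdk_errors)):
            return invoke_error(str(ex))
    return InvokeError(str(ex))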
21 changes: 16 additions & 5 deletions api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -2,6 +2,7 @@

 from huggingface_hub import InferenceClient
 from huggingface_hub.hf_api import HfApi
+from huggingface_hub.utils import BadRequestError
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
@@ -25,6 +26,9 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes
         if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
             model = credentials['huggingfacehub_endpoint_url']
 
+        if 'baichuan' in model.lower():
+            stream = False
+
         response = client.text_generation(
             prompt=prompt_messages[0].content,
             details=True,
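
The two added lines force a non-streaming call whenever the model id contains 'baichuan', presumably because those checkpoints reject streamed text_generation. A minimal sketch of the guard in isolation (effective_stream is a hypothetical name; the model ids are only examples):

def effective_stream(model: str, stream: bool) -> bool:
    # Baichuan models are treated as unable to stream, so a requested
    # streaming call is silently downgraded to a single blocking call.
    if 'baichuan' in model.lower():
        return False
    return stream

assert effective_stream('baichuan-inc/Baichuan2-13B-Chat', True) is False
assert effective_stream('HuggingFaceH4/zephyr-7b-beta', True) is True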
@@ -73,10 +77,14 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
             if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
                 model = credentials['huggingfacehub_endpoint_url']
 
-            client.text_generation(
-                prompt='Who are you?',
-                stream=False,
-                model=model)
+            try:
+                client.text_generation(
+                    prompt='Who are you?',
+                    stream=True,
+                    model=model)
+            except BadRequestError as e:
+                raise CredentialsValidateFailedError('Only available for models running with `text-generation-inference`. '
+                                                     'To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.')
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
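
Switching the validation call to stream=True doubles as a capability probe: streaming is only honoured by endpoints backed by text-generation-inference (TGI), so a BadRequestError here signals an incompatible endpoint rather than bad credentials. A standalone sketch of the same probe under that assumption (check_tgi_endpoint is a hypothetical helper):

from huggingface_hub import InferenceClient
from huggingface_hub.utils import BadRequestError

def check_tgi_endpoint(endpoint_url: str, token: str) -> None:
    client = InferenceClient(token=token)
    try:
        # Assumed behaviour: non-TGI endpoints reject stream=True
        # with a BadRequestError.
        for _ in client.text_generation(prompt='Who are you?', stream=True, model=endpoint_url):
            break  # one chunk is enough to prove streaming works
    except BadRequestError:
        raise ValueError('Endpoint is not running text-generation-inference.')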

@@ -134,11 +142,14 @@ def _handle_generate_stream_response(self,
                                          credentials: dict,
                                          prompt_messages: list[PromptMessage],
                                          response: Generator) -> Generator:
+        index = -1
         for chunk in response:
             # skip special tokens
             if chunk.token.special:
                 continue
 
+            index += 1
+
             assistant_prompt_message = AssistantPromptMessage(
                 content=chunk.token.text
             )
@@ -152,7 +163,7 @@ def _handle_generate_stream_response(self,
                 model=model,
                 prompt_messages=prompt_messages,
                 delta=LLMResultChunkDelta(
-                    index=chunk.token.id,
+                    index=index,
                     message=assistant_prompt_message,
                     usage=usage,
                 ),
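
The index change matters because chunk.token.id is a vocabulary id, not a stream position, so consecutive chunks previously carried arbitrary, non-sequential indices. The new counter numbers only the chunks that are actually emitted, since special tokens are skipped. A small illustration (the Token dataclass mimics the assumed shape of the streamed details):

from dataclasses import dataclass

@dataclass
class Token:
    id: int        # vocabulary id: arbitrary magnitude, not positional
    text: str
    special: bool

stream = [Token(1, '<s>', True), Token(22557, 'Hello', False), Token(28808, '!', False)]

index = -1
for token in stream:
    if token.special:
        continue   # special tokens produce no visible chunk
    index += 1     # emitted-chunk position: 0, 1, ...
    print(index, repr(token.text), 'vocab id =', token.id)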
@@ -106,22 +106,20 @@ def test_inference_endpoints_text_generation_validate_credentials():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='HuggingFaceH4/zephyr-7b-beta',
+            model='openchat/openchat_3.5',
             credentials={
                 'huggingfacehub_api_type': 'inference_endpoints',
                 'huggingfacehub_api_token': 'invalid_key',
-                'model': 'openchat/openchat_3.5',
                 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
                 'task_type': 'text-generation'
             }
         )
 
     model.validate_credentials(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model='openchat/openchat_3.5',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'openchat/openchat_3.5',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text-generation'
         }
@@ -132,11 +130,10 @@ def test_inference_endpoints_text_generation_invoke_model():
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model='openchat/openchat_3.5',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'openchat/openchat_3.5',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text-generation'
         },
@@ -163,11 +160,10 @@ def test_inference_endpoints_text_generation_invoke_stream_model():
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model='openchat/openchat_3.5',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'openchat/openchat_3.5',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text-generation'
         },
@@ -200,22 +196,20 @@ def test_inference_endpoints_text2text_generation_validate_credentials():

     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='HuggingFaceH4/zephyr-7b-beta',
+            model='google/mt5-base',
             credentials={
                 'huggingfacehub_api_type': 'inference_endpoints',
                 'huggingfacehub_api_token': 'invalid_key',
-                'model': 'google/mt5-base',
                 'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
                 'task_type': 'text2text-generation'
             }
         )
 
     model.validate_credentials(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model='google/mt5-base',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'google/mt5-base',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text2text-generation'
         }
@@ -226,11 +220,10 @@ def test_inference_endpoints_text2text_generation_invoke_model():
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model='google/mt5-base',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'google/mt5-base',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text2text-generation'
         },
@@ -257,11 +250,10 @@ def test_inference_endpoints_text2text_generation_invoke_stream_model():
     model = HuggingfaceHubLargeLanguageModel()
 
     response = model.invoke(
-        model='HuggingFaceH4/zephyr-7b-beta',
+        model='google/mt5-base',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'google/mt5-base',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text2text-generation'
         },
@@ -293,11 +285,10 @@ def test_get_num_tokens():
     model = HuggingfaceHubLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='',
+        model='google/mt5-base',
         credentials={
             'huggingfacehub_api_type': 'inference_endpoints',
             'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
-            'model': 'google/mt5-base',
             'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
             'task_type': 'text2text-generation'
         },
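
Throughout the test file, the model id moves out of the credentials dict and into the model argument; for inference endpoints the handler then substitutes the endpoint URL anyway, as shown in _invoke and validate_credentials above. A compact sketch of that resolution rule (resolve_target is a hypothetical name; the URL is a placeholder):

def resolve_target(model: str, credentials: dict) -> str:
    # For dedicated inference endpoints, the endpoint URL, not the hub
    # model id, is what gets passed to InferenceClient.text_generation.
    if credentials['huggingfacehub_api_type'] == 'inference_endpoints':
        return credentials['huggingfacehub_endpoint_url']
    return model

assert resolve_target('openchat/openchat_3.5', {
    'huggingfacehub_api_type': 'inference_endpoints',
    'huggingfacehub_endpoint_url': 'https://example.endpoints.huggingface.cloud',
}) == 'https://example.endpoints.huggingface.cloud'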
