Skip to content

Commit

Permalink
Revert "feat: hugging face supports embeddings."
Browse files Browse the repository at this point in the history
This reverts commit 757ee4d.
  • Loading branch information
GarfieldDai committed Sep 19, 2023
1 parent 9a00c08 commit 86ddce2
Show file tree
Hide file tree
Showing 6 changed files with 4 additions and 174 deletions.
24 changes: 0 additions & 24 deletions api/core/model_providers/models/embedding/huggingface_embedding.py

This file was deleted.

11 changes: 4 additions & 7 deletions api/core/model_providers/providers/huggingface_hub_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType


Expand All @@ -34,8 +33,6 @@ def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = HuggingfaceHubModel
elif model_type == ModelType.EMBEDDINGS:
model_class = HuggingfaceEmbedding
else:
raise NotImplementedError

Expand Down Expand Up @@ -66,7 +63,7 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
:param model_type:
:param credentials:
"""
if model_type not in [ModelType.TEXT_GENERATION, ModelType.EMBEDDINGS]:
if model_type != ModelType.TEXT_GENERATION:
raise NotImplementedError

if 'huggingfacehub_api_type' not in credentials \
Expand All @@ -91,9 +88,9 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')

if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'sentence-similarity'):
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization, sentence-similarity.')
'text-generation, summarization.')

try:
llm = HuggingFaceEndpointLLM(
Expand All @@ -115,7 +112,7 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')

VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "sentence-similarity")
VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
Expand Down

This file was deleted.

This file was deleted.

1 change: 0 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,3 @@ pandas==1.5.3
xinference==0.4.2
safetensors==0.3.2
zhipuai==1.0.7
sentence-transformers==2.2.2

This file was deleted.

0 comments on commit 86ddce2

Please sign in to comment.