Skip to content

Commit

Permalink
feat: hugging face supports embeddings.
Browse files Browse the repository at this point in the history
  • Loading branch information
GarfieldDai committed Sep 19, 2023
1 parent 9d3ba98 commit b8db580
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 53 deletions.
39 changes: 28 additions & 11 deletions api/core/model_providers/providers/huggingface_hub_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType

Expand Down Expand Up @@ -91,19 +92,15 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')

if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'sentence-similarity'):
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization, sentence-similarity.')
'text-generation, summarization, feature-extraction.')

try:
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)

llm("ping")
if credentials['task_type'] == 'feature-extraction':
cls.check_embedding_valid(credentials, model_name)
else:
cls.check_llm_valid(credentials)
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
Expand All @@ -115,13 +112,33 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')

VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "sentence-similarity")
VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")

@classmethod
def check_llm_valid(cls, credentials: dict):
    """Probe a text-generation style endpoint with a minimal prompt.

    A ``HuggingFaceEndpointLLM`` is built from the endpoint URL, task type
    and API token stored in ``credentials`` and is invoked once with the
    prompt ``"ping"``. Nothing is returned; any exception raised while
    building or calling the model propagates to the caller.

    :param credentials: mapping that must contain the keys
        ``huggingfacehub_endpoint_url``, ``task_type`` and
        ``huggingfacehub_api_token``.
    """
    probe_kwargs = dict(
        endpoint_url=credentials['huggingfacehub_endpoint_url'],
        task=credentials['task_type'],
        model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
        huggingfacehub_api_token=credentials['huggingfacehub_api_token'],
    )
    endpoint_llm = HuggingFaceEndpointLLM(**probe_kwargs)

    # The reply itself is irrelevant; only a failure (exception) matters.
    endpoint_llm("ping")

@classmethod
def check_embedding_valid(cls, credentials: dict, model_name: str):
    """Probe an embedding model by embedding the text ``"ping"``.

    ``credentials`` is expanded as keyword arguments into
    ``HuggingfaceHubEmbeddings`` together with ``model=model_name``.
    Nothing is returned; any exception raised while building the embedding
    client or running the query propagates to the caller.

    :param credentials: keyword arguments forwarded to
        ``HuggingfaceHubEmbeddings``.
    :param model_name: value passed as the ``model`` argument.
    """
    probe = HuggingfaceHubEmbeddings(model=model_name, **credentials)

    # The vector is discarded; the call only has to succeed.
    probe.embed_query("ping")

@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import json

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.embeddings import HuggingFaceEmbeddings
from huggingface_hub import InferenceClient

HOSTED_INFERENCE_API = 'hosted_inference_api'
INFERENCE_ENDPOINTS = 'inference_endpoints'


class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
Expand All @@ -14,6 +18,7 @@ class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
task_type: Optional[str] = None
huggingfacehub_api_type: Optional[str] = None
huggingfacehub_api_token: Optional[str] = None
huggingfacehub_endpoint_url: Optional[str] = None

class Config:
extra = Extra.forbid
Expand All @@ -24,14 +29,40 @@ def validate_environment(cls, values: Dict) -> Dict:
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)

values['client'] = HuggingFaceEmbeddings(
model_name=values['model']
)
values['client'] = InferenceClient(values['huggingfacehub_api_token'])

return values

def embeddings(self, inputs: Union[str, List[str]]) -> Any:
    """Post ``inputs`` to the inference backend and return the decoded JSON reply.

    The request target depends on ``huggingfacehub_api_type``: for the
    hosted inference API the model identifier (``self.model``) is used,
    otherwise the configured ``huggingfacehub_endpoint_url`` is used.

    :param inputs: a single text or a list of texts, sent as the
        ``"inputs"`` field of the JSON payload.
    :return: whatever ``json.loads`` produces from the response body (the
        parsed object, not a string).
    """
    if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
        target = self.model
    else:
        target = self.huggingfacehub_endpoint_url

    payload = {
        "inputs": inputs,
        "options": {
            "wait_for_model": False
        }
    }
    raw = self.client.post(json=payload, model=target)

    # ``post`` yields bytes; decode before parsing the JSON document.
    return json.loads(raw.decode())

def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed several texts via the remote inference backend.

    :param texts: the texts to embed; passed unchanged to ``self.embeddings``.
    :return: the backend reply as-is when it is already a ``list``;
        otherwise each element of the reply is converted to a list of
        ``float`` values.
    """
    output = self.embeddings(texts)

    if isinstance(output, list):
        return output

    # Non-list reply: coerce every row into a plain list of floats.
    return [list(map(float, e)) for e in output]

def embed_query(self, text: str) -> List[float]:
    """Embed a single text via the remote inference backend.

    :param text: the text to embed; passed unchanged to ``self.embeddings``.
    :return: the backend reply as-is when it is already a ``list``;
        otherwise the reply is converted to a list of ``float`` values.
    """
    output = self.embeddings(text)

    if isinstance(output, list):
        return output

    # Non-list reply: coerce it into a plain list of floats.
    return list(map(float, output))
3 changes: 1 addition & 2 deletions api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,4 @@ stripe~=5.5.0
pandas==1.5.3
xinference==0.4.2
safetensors==0.3.2
zhipuai==1.0.7
sentence-transformers==2.2.2
zhipuai==1.0.7
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel

DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'

def get_mock_provider():
return Provider(
Expand All @@ -28,7 +28,7 @@ def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
credentials = {
'huggingfacehub_api_type': huggingfacehub_api_type,
'huggingfacehub_api_token': valid_api_key,
'task_type': 'sentence-similarity'
'task_type': 'feature-extraction'
}

if huggingfacehub_api_type == 'inference_endpoints':
Expand Down Expand Up @@ -56,23 +56,46 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
    """Embedding two documents through the hosted inference API yields two 384-wide vectors."""
    embedding_model = get_mock_embedding_model(
        DEFAULT_MODEL_NAME,
        'hosted_inference_api',
        mocker)
    rst = embedding_model.client.embed_documents(['test', 'test1'])
    assert isinstance(rst, list)
    assert len(rst) == 2
    # Single dimension check: the vectors are expected to be 384 wide.
    assert len(rst[0]) == 384


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
    """Embedding one query through the hosted inference API yields a 384-wide vector."""
    embedding_model = get_mock_embedding_model(
        DEFAULT_MODEL_NAME,
        'hosted_inference_api',
        mocker)
    rst = embedding_model.client.embed_query('test')
    assert isinstance(rst, list)
    # Single dimension check: the vector is expected to be 384 wide.
    assert len(rst) == 384


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
    """Embedding two documents through a dedicated inference endpoint yields two 384-wide vectors."""
    # No model name is supplied for the inference-endpoint configuration.
    model_under_test = get_mock_embedding_model('', 'inference_endpoints', mocker)
    documents = ['test', 'test1']

    vectors = model_under_test.client.embed_documents(documents)

    assert isinstance(vectors, list)
    assert len(vectors) == 2
    first_vector = vectors[0]
    assert len(first_vector) == 384


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
    """Embedding one query through a dedicated inference endpoint yields a 384-wide vector."""
    # No model name is supplied for the inference-endpoint configuration.
    model_under_test = get_mock_embedding_model('', 'inference_endpoints', mocker)

    vector = model_under_test.client.embed_query('test')

    assert isinstance(vector, list)
    assert len(vector) == 384

0 comments on commit b8db580

Please sign in to comment.