Skip to content

Commit

Permalink
Merge branch 'feat/huggingface-embedding-support' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
GarfieldDai committed Sep 19, 2023
2 parents 01c571b + b8db580 commit abb96e1
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 20 deletions.
40 changes: 29 additions & 11 deletions api/core/model_providers/providers/huggingface_hub_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType


Expand Down Expand Up @@ -88,19 +90,15 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
if 'task_type' not in credentials:
raise CredentialsValidateFailedError('Task Type must be provided.')

if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
'text-generation, summarization.')
'text-generation, summarization, feature-extraction.')

try:
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)

llm("ping")
if credentials['task_type'] == 'feature-extraction':
cls.check_embedding_valid(credentials, model_name)
else:
cls.check_llm_valid(credentials)
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
else:
Expand All @@ -112,13 +110,33 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
if 'inference' in model_info.cardData and not model_info.cardData['inference']:
raise ValueError(f'Inference API has been turned off for this model {model_name}.')

VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
if model_info.pipeline_tag not in VALID_TASKS:
raise ValueError(f"Model {model_name} is not a valid task, "
f"must be one of {VALID_TASKS}.")
except Exception as e:
raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")

@classmethod
def check_llm_valid(cls, credentials: dict):
llm = HuggingFaceEndpointLLM(
endpoint_url=credentials['huggingfacehub_endpoint_url'],
task=credentials['task_type'],
model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
huggingfacehub_api_token=credentials['huggingfacehub_api_token']
)

llm("ping")

@classmethod
def check_embedding_valid(cls, credentials: dict, model_name: str):
embedding_model = HuggingfaceHubEmbeddings(
model=model_name,
**credentials
)

embedding_model.embed_query("ping")

@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Any, Dict, List, Optional, Union
import json

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from huggingface_hub import InferenceClient

HOSTED_INFERENCE_API = 'hosted_inference_api'
INFERENCE_ENDPOINTS = 'inference_endpoints'


class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
client: Any
model: str

task_type: Optional[str] = None
huggingfacehub_api_type: Optional[str] = None
huggingfacehub_api_token: Optional[str] = None
huggingfacehub_endpoint_url: Optional[str] = None

class Config:
extra = Extra.forbid

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
values['huggingfacehub_api_token'] = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)

values['client'] = InferenceClient(values['huggingfacehub_api_token'])

return values

def embeddings(self, inputs: Union[str, List[str]]) -> str:
model = ''

if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
model = self.model
else:
model = self.huggingfacehub_endpoint_url

output = self.client.post(
json={
"inputs": inputs,
"options": {
"wait_for_model": False
}
}, model=model)

return json.loads(output.decode())

def embed_documents(self, texts: List[str]) -> List[List[float]]:
output = self.embeddings(texts)

if isinstance(output, list):
return output

return [list(map(float, e)) for e in output]

def embed_query(self, text: str) -> List[float]:
output = self.embeddings(text)

if isinstance(output, list):
return output

return list(map(float, output))
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import json
import os
from unittest.mock import patch, MagicMock

from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
from models.provider import Provider, ProviderType, ProviderModel

DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'

def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='huggingface_hub',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)


def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
valid_api_key = os.environ['HUGGINGFACE_API_KEY']
endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL']
model_provider = HuggingfaceHubProvider(provider=get_mock_provider())

credentials = {
'huggingfacehub_api_type': huggingfacehub_api_type,
'huggingfacehub_api_token': valid_api_key,
'task_type': 'feature-extraction'
}

if huggingfacehub_api_type == 'inference_endpoints':
credentials['huggingfacehub_endpoint_url'] = endpoint_url

mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='huggingface_hub',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps(credentials),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query',
return_value=mock_query)

return HuggingfaceEmbedding(
model_provider=model_provider,
name=model_name
)


def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
DEFAULT_MODEL_NAME,
'hosted_inference_api',
mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) == 384


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
DEFAULT_MODEL_NAME,
'hosted_inference_api',
mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 384


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
'',
'inference_endpoints',
mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) == 384


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(
'',
'inference_endpoints',
mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) == 384
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,25 @@ const config: ProviderConfig = {
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
if (v?.model_type === 'embeddings') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'model_type',
]
}
else {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
Expand Down Expand Up @@ -173,7 +184,7 @@ const config: ProviderConfig = {
},
},
{
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
type: 'radio',
key: 'task_type',
required: true,
Expand Down

0 comments on commit abb96e1

Please sign in to comment.