-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'feat/huggingface-embedding-support' into deploy/dev
- Loading branch information
Showing 4 changed files with 218 additions and 20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from typing import Any, Dict, List, Optional, Union | ||
import json | ||
|
||
from pydantic import BaseModel, Extra, root_validator | ||
|
||
from langchain.embeddings.base import Embeddings | ||
from langchain.utils import get_from_dict_or_env | ||
from huggingface_hub import InferenceClient | ||
|
||
HOSTED_INFERENCE_API = 'hosted_inference_api' | ||
INFERENCE_ENDPOINTS = 'inference_endpoints' | ||
|
||
|
||
class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
    """LangChain ``Embeddings`` implementation backed by ``huggingface_hub.InferenceClient``.

    Requests are sent either to the model named by ``model`` (when
    ``huggingfacehub_api_type`` equals ``HOSTED_INFERENCE_API``) or, for any
    other API type, to ``huggingfacehub_endpoint_url``.
    """

    # InferenceClient instance; populated by ``validate_environment``.
    client: Any
    # Model identifier used as the request target for the hosted inference API.
    model: str

    # Not read by this class; accepted so callers can pass it alongside the
    # other credentials (``Extra.forbid`` would otherwise reject it).
    task_type: Optional[str] = None
    # Compared against HOSTED_INFERENCE_API to choose the request target.
    huggingfacehub_api_type: Optional[str] = None
    # API token; falls back to the HUGGINGFACEHUB_API_TOKEN environment variable.
    huggingfacehub_api_token: Optional[str] = None
    # Request target used whenever the API type is not HOSTED_INFERENCE_API.
    huggingfacehub_endpoint_url: Optional[str] = None

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API token and build the inference client.

        The token is taken from ``values`` or, failing that, from the
        ``HUGGINGFACEHUB_API_TOKEN`` environment variable.
        """
        values['huggingfacehub_api_token'] = get_from_dict_or_env(
            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
        )

        # The first positional parameter of InferenceClient is ``model``, so the
        # token must be passed by keyword or it is never used for authentication.
        values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])

        return values

    def embeddings(self, inputs: Union[str, List[str]]) -> Any:
        """Post ``inputs`` to the configured target and return the decoded JSON body.

        Args:
            inputs: A single text or a list of texts to embed.

        Returns:
            Whatever ``json.loads`` yields for the response payload.
        """
        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
            model = self.model
        else:
            model = self.huggingfacehub_endpoint_url

        output = self.client.post(
            json={
                "inputs": inputs,
                "options": {
                    "wait_for_model": False
                }
            }, model=model)

        return json.loads(output.decode())

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts, returning one vector per text.

        An empty batch returns ``[]`` without issuing a request.
        """
        if not texts:
            # Nothing to embed; avoid a pointless round trip to the API.
            return []

        output = self.embeddings(texts)

        if isinstance(output, list):
            return output

        return [list(map(float, e)) for e in output]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        output = self.embeddings(text)

        if isinstance(output, list):
            return output

        return list(map(float, output))
101 changes: 101 additions & 0 deletions
101
api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import json | ||
import os | ||
from unittest.mock import patch, MagicMock | ||
|
||
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType | ||
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding | ||
from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider | ||
from models.provider import Provider, ProviderType, ProviderModel | ||
|
||
DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2' | ||
|
||
def get_mock_provider():
    """Build the ``Provider`` record used by these tests for ``huggingface_hub``."""
    provider_fields = {
        'id': 'provider_id',
        'tenant_id': 'tenant_id',
        'provider_name': 'huggingface_hub',
        'provider_type': ProviderType.CUSTOM.value,
        'encrypted_config': '',
        'is_valid': True,
    }
    return Provider(**provider_fields)
|
||
|
||
def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
    """Create a ``HuggingfaceEmbedding`` whose stored credentials come from the environment.

    The database query for the provider model is patched so that it returns a
    ``ProviderModel`` carrying the credentials as JSON in ``encrypted_config``.

    Args:
        model_name: Model name stored on the ``ProviderModel`` and passed to
            ``HuggingfaceEmbedding``.
        huggingfacehub_api_type: ``'hosted_inference_api'`` or
            ``'inference_endpoints'``.
        mocker: The pytest-mock fixture used to patch the DB session query.

    Raises:
        KeyError: If ``HUGGINGFACE_API_KEY`` is unset, or if
            ``HUGGINGFACE_ENDPOINT_URL`` is unset while
            ``huggingfacehub_api_type`` is ``'inference_endpoints'``.
    """
    credentials = {
        'huggingfacehub_api_type': huggingfacehub_api_type,
        'huggingfacehub_api_token': os.environ['HUGGINGFACE_API_KEY'],
        'task_type': 'feature-extraction'
    }

    if huggingfacehub_api_type == 'inference_endpoints':
        # Read the endpoint URL only when it is needed, so the hosted-API tests
        # can run with just HUGGINGFACE_API_KEY configured.
        credentials['huggingfacehub_endpoint_url'] = os.environ['HUGGINGFACE_ENDPOINT_URL']

    model_provider = HuggingfaceHubProvider(provider=get_mock_provider())

    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='huggingface_hub',
        model_name=model_name,
        model_type=ModelType.EMBEDDINGS.value,
        encrypted_config=json.dumps(credentials),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query',
                 return_value=mock_query)

    return HuggingfaceEmbedding(
        model_provider=model_provider,
        name=model_name
    )
|
||
|
||
def decrypt_side_effect(tenant_id, encrypted_api_key):
    """Stand-in for ``decrypt_token``: hand the stored value back unchanged."""
    passthrough_value = encrypted_api_key
    return passthrough_value
|
||
|
||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
    """Embed two documents via the hosted inference API and check the result shape."""
    model = get_mock_embedding_model(
        model_name=DEFAULT_MODEL_NAME,
        huggingfacehub_api_type='hosted_inference_api',
        mocker=mocker,
    )

    vectors = model.client.embed_documents(['test', 'test1'])

    # One vector per input document, each with the expected 384 components.
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 384
|
||
|
||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
    """Embed a single query via the hosted inference API and check the vector length."""
    model = get_mock_embedding_model(
        model_name=DEFAULT_MODEL_NAME,
        huggingfacehub_api_type='hosted_inference_api',
        mocker=mocker,
    )

    vector = model.client.embed_query('test')

    assert isinstance(vector, list)
    assert len(vector) == 384
|
||
|
||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
    """Embed two documents via an inference endpoint URL and check the result shape."""
    model = get_mock_embedding_model(
        model_name='',
        huggingfacehub_api_type='inference_endpoints',
        mocker=mocker,
    )

    vectors = model.client.embed_documents(['test', 'test1'])

    # One vector per input document, each with the expected 384 components.
    assert isinstance(vectors, list)
    assert len(vectors) == 2
    assert len(vectors[0]) == 384
|
||
|
||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
    """Embed a single query via an inference endpoint URL and check the vector length."""
    model = get_mock_embedding_model(
        model_name='',
        huggingfacehub_api_type='inference_endpoints',
        mocker=mocker,
    )

    vector = model.client.embed_query('test')

    assert isinstance(vector, list)
    assert len(vector) == 384
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters