diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py
index 75fffac7222b7f..60edf754447279 100644
--- a/api/core/model_providers/providers/huggingface_hub_provider.py
+++ b/api/core/model_providers/providers/huggingface_hub_provider.py
@@ -10,6 +10,8 @@
 
 from core.model_providers.models.base import BaseProviderModel
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
 from models.provider import ProviderType
 
 
@@ -88,19 +90,15 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
             if 'task_type' not in credentials:
                 raise CredentialsValidateFailedError('Task Type must be provided.')
 
-            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
+            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
                 raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
-                                                     'text-generation, summarization.')
+                                                     'text-generation, summarization, feature-extraction.')
 
             try:
-                llm = HuggingFaceEndpointLLM(
-                    endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                    task=credentials['task_type'],
-                    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
-                    huggingfacehub_api_token=credentials['huggingfacehub_api_token']
-                )
-
-                llm("ping")
+                if credentials['task_type'] == 'feature-extraction':
+                    cls.check_embedding_valid(credentials, model_name)
+                else:
+                    cls.check_llm_valid(credentials)
             except Exception as e:
                 raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
         else:
@@ -112,13 +110,33 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
                 if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                     raise ValueError(f'Inference API has been turned off for this model {model_name}.')
 
-                VALID_TASKS = ("text2text-generation", "text-generation", "summarization")
+                VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
 
                 if model_info.pipeline_tag not in VALID_TASKS:
                     raise ValueError(f"Model {model_name} is not a valid task, "
                                      f"must be one of {VALID_TASKS}.")
             except Exception as e:
                 raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
 
+    @classmethod
+    def check_llm_valid(cls, credentials: dict):
+        llm = HuggingFaceEndpointLLM(
+            endpoint_url=credentials['huggingfacehub_endpoint_url'],
+            task=credentials['task_type'],
+            model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
+            huggingfacehub_api_token=credentials['huggingfacehub_api_token']
+        )
+
+        llm("ping")
+
+    @classmethod
+    def check_embedding_valid(cls, credentials: dict, model_name: str):
+        embedding_model = HuggingfaceHubEmbeddings(
+            model=model_name,
+            **credentials
+        )
+
+        embedding_model.embed_query("ping")
+
     @classmethod
     def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
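With this change, credential validation for Inference Endpoints branches on the task type: 'feature-extraction' credentials are checked through check_embedding_valid() (a single embed_query("ping") against the endpoint), while the generation and summarization task types still go through check_llm_valid(). A rough sketch of exercising that entry point follows; the token, endpoint URL, and model name are placeholders, and the call performs a real request against the configured endpoint:

    from core.model_providers.models.entity.model_params import ModelType
    from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider

    # Placeholder credentials for an Inference Endpoint that serves an embedding model.
    credentials = {
        'huggingfacehub_api_type': 'inference_endpoints',
        'huggingfacehub_api_token': 'hf_xxx',  # placeholder token
        'huggingfacehub_endpoint_url': 'https://example.endpoints.huggingface.cloud',  # placeholder URL
        'task_type': 'feature-extraction',
    }

    # Raises CredentialsValidateFailedError if the ping fails; 'feature-extraction'
    # is routed to check_embedding_valid(), other task types to check_llm_valid().
    HuggingfaceHubProvider.is_model_credentials_valid_or_raise(
        model_name='my-endpoint-model',  # hypothetical model name
        model_type=ModelType.EMBEDDINGS,
        credentials=credentials,
    )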
diff --git a/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py b/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
new file mode 100644
index 00000000000000..8e828bbec09c96
--- /dev/null
+++ b/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
@@ -0,0 +1,68 @@
+from typing import Any, Dict, List, Optional, Union
+import json
+
+from pydantic import BaseModel, Extra, root_validator
+
+from langchain.embeddings.base import Embeddings
+from langchain.utils import get_from_dict_or_env
+from huggingface_hub import InferenceClient
+
+HOSTED_INFERENCE_API = 'hosted_inference_api'
+INFERENCE_ENDPOINTS = 'inference_endpoints'
+
+
+class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
+    client: Any
+    model: str
+
+    task_type: Optional[str] = None
+    huggingfacehub_api_type: Optional[str] = None
+    huggingfacehub_api_token: Optional[str] = None
+    huggingfacehub_endpoint_url: Optional[str] = None
+
+    class Config:
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        values['huggingfacehub_api_token'] = get_from_dict_or_env(
+            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+        )
+
+        values['client'] = InferenceClient(values['huggingfacehub_api_token'])
+
+        return values
+
+    def embeddings(self, inputs: Union[str, List[str]]) -> str:
+        model = ''
+
+        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
+            model = self.model
+        else:
+            model = self.huggingfacehub_endpoint_url
+
+        output = self.client.post(
+            json={
+                "inputs": inputs,
+                "options": {
+                    "wait_for_model": False
+                }
+            }, model=model)
+
+        return json.loads(output.decode())
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        output = self.embeddings(texts)
+
+        if isinstance(output, list):
+            return output
+
+        return [list(map(float, e)) for e in output]
+
+    def embed_query(self, text: str) -> List[float]:
+        output = self.embeddings(text)
+
+        if isinstance(output, list):
+            return output
+
+        return list(map(float, output))
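The wrapper implements langchain's Embeddings interface, so it can also be used standalone. A minimal sketch, assuming the hosted Inference API path and placeholder values for the model repo and token (for Inference Endpoints, huggingfacehub_api_type='inference_endpoints' plus huggingfacehub_endpoint_url would be set instead):

    from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings

    embeddings = HuggingfaceHubEmbeddings(
        model='sentence-transformers/all-MiniLM-L6-v2',  # placeholder repo id
        huggingfacehub_api_type='hosted_inference_api',
        huggingfacehub_api_token='hf_xxx',  # placeholder token
    )

    # Both calls go through embeddings()/client.post() and parse the JSON response.
    vector = embeddings.embed_query('ping')            # a single vector (list of floats)
    vectors = embeddings.embed_documents(['a', 'b'])   # one vector per input text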
diff --git a/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py b/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
new file mode 100644
index 00000000000000..e841e926e7be3e
--- /dev/null
+++ b/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
@@ -0,0 +1,101 @@
+import json
+import os
+from unittest.mock import patch, MagicMock
+
+from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
+from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
+from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
+from models.provider import Provider, ProviderType, ProviderModel
+
+DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'
+
+def get_mock_provider():
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='huggingface_hub',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='',
+        is_valid=True,
+    )
+
+
+def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
+    valid_api_key = os.environ['HUGGINGFACE_API_KEY']
+    endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL']
+    model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
+
+    credentials = {
+        'huggingfacehub_api_type': huggingfacehub_api_type,
+        'huggingfacehub_api_token': valid_api_key,
+        'task_type': 'feature-extraction'
+    }
+
+    if huggingfacehub_api_type == 'inference_endpoints':
+        credentials['huggingfacehub_endpoint_url'] = endpoint_url
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        provider_name='huggingface_hub',
+        model_name=model_name,
+        model_type=ModelType.EMBEDDINGS.value,
+        encrypted_config=json.dumps(credentials),
+        is_valid=True,
+    )
+    mocker.patch('extensions.ext_database.db.session.query',
+                 return_value=mock_query)
+
+    return HuggingfaceEmbedding(
+        model_provider=model_provider,
+        name=model_name
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        DEFAULT_MODEL_NAME,
+        'hosted_inference_api',
+        mocker)
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        DEFAULT_MODEL_NAME,
+        'hosted_inference_api',
+        mocker)
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 384
diff --git a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
index d7b9f64f931cfd..6c21a7cd73e8ae 100644
--- a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
+++ b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
@@ -68,14 +68,25 @@ const config: ProviderConfig = {
       ]
     }
     if (v?.huggingfacehub_api_type === 'inference_endpoints') {
-      filteredKeys = [
-        'huggingfacehub_api_type',
-        'huggingfacehub_api_token',
-        'model_name',
-        'huggingfacehub_endpoint_url',
-        'task_type',
-        'model_type',
-      ]
+      if (v?.model_type === 'embeddings') {
+        filteredKeys = [
+          'huggingfacehub_api_type',
+          'huggingfacehub_api_token',
+          'model_name',
+          'huggingfacehub_endpoint_url',
+          'model_type',
+        ]
+      }
+      else {
+        filteredKeys = [
+          'huggingfacehub_api_type',
+          'huggingfacehub_api_token',
+          'model_name',
+          'huggingfacehub_endpoint_url',
+          'task_type',
+          'model_type',
+        ]
+      }
     }
     return filteredKeys.reduce((prev: FormValue, next: string) => {
       prev[next] = v?.[next] || ''
@@ -173,7 +184,7 @@ const config: ProviderConfig = {
       },
     },
    {
-      hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
+      hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
      type: 'radio',
      key: 'task_type',
      required: true,
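On the web side, the config change above only hides the task_type radio for embedding models behind Inference Endpoints, since feature-extraction is implied there. The new integration tests, meanwhile, call the live Hugging Face APIs and read HUGGINGFACE_API_KEY and HUGGINGFACE_ENDPOINT_URL from the environment; the 384-dimension assertions correspond to the all-MiniLM-L6-v2 sentence-transformer used as the default model. A hypothetical local runner, with placeholder values and assuming pytest plus pytest-mock are installed:

    import os

    import pytest

    # Real credentials are required: the tests hit the Inference API / Endpoint.
    os.environ.setdefault('HUGGINGFACE_API_KEY', 'hf_xxx')  # placeholder token
    os.environ.setdefault('HUGGINGFACE_ENDPOINT_URL',
                          'https://example.endpoints.huggingface.cloud')  # placeholder URL

    pytest.main([
        'api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py',
        '-v',
    ])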