From 9d3ba98d07d207d5ce060145e5abdb948743a5e2 Mon Sep 17 00:00:00 2001
From: StyleZhang
Date: Tue, 19 Sep 2023 11:02:54 +0800
Subject: [PATCH 1/2] fix: frontend huggingface embedding hide task type

---
 .../model-page/configs/huggingface_hub.tsx    | 29 +++++++++++++------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
index d7b9f64f931cfd..6c21a7cd73e8ae 100644
--- a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
+++ b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx
@@ -68,14 +68,25 @@ const config: ProviderConfig = {
       ]
     }
     if (v?.huggingfacehub_api_type === 'inference_endpoints') {
-      filteredKeys = [
-        'huggingfacehub_api_type',
-        'huggingfacehub_api_token',
-        'model_name',
-        'huggingfacehub_endpoint_url',
-        'task_type',
-        'model_type',
-      ]
+      if (v?.model_type === 'embeddings') {
+        filteredKeys = [
+          'huggingfacehub_api_type',
+          'huggingfacehub_api_token',
+          'model_name',
+          'huggingfacehub_endpoint_url',
+          'model_type',
+        ]
+      }
+      else {
+        filteredKeys = [
+          'huggingfacehub_api_type',
+          'huggingfacehub_api_token',
+          'model_name',
+          'huggingfacehub_endpoint_url',
+          'task_type',
+          'model_type',
+        ]
+      }
     }
     return filteredKeys.reduce((prev: FormValue, next: string) => {
       prev[next] = v?.[next] || ''
@@ -173,7 +184,7 @@ const config: ProviderConfig = {
         },
       },
       {
-        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api',
+        hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api' || value?.model_type === 'embeddings',
         type: 'radio',
         key: 'task_type',
         required: true,

From b8db580833d52bf01580e1206f0012f810580894 Mon Sep 17 00:00:00 2001
From: Garfield Dai
Date: Tue, 19 Sep 2023 19:22:16 +0800
Subject: [PATCH 2/2] feat: hugging face supports embeddings.
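
Embedding support now goes through the Hugging Face Inference API
instead of a local sentence-transformers model: credential validation
branches on task_type ('feature-extraction' validates an embedding
call, the other task types an LLM call), and HuggingfaceHubEmbeddings
wraps huggingface_hub's InferenceClient for both the hosted Inference
API and dedicated Inference Endpoints.

A minimal sketch of the request the new client issues (the token below
is a placeholder, and the model value may instead be an Inference
Endpoint URL):

    import json

    from huggingface_hub import InferenceClient

    client = InferenceClient(token='hf_...')  # placeholder token
    raw = client.post(
        json={'inputs': ['ping'], 'options': {'wait_for_model': False}},
        model='obrizum/all-MiniLM-L6-v2',  # or an endpoint URL
    )
    vectors = json.loads(raw.decode())  # one 384-dim vector per input
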
---
 .../providers/huggingface_hub_provider.py     | 39 +++++++++++-----
 .../huggingface_endpoint_embedding.py         | 27 -----------
 .../embeddings/huggingface_hub_embedding.py   | 45 ++++++++++++++++---
 api/requirements.txt                          |  3 +-
 .../test_huggingface_hub_embedding.py         | 35 ++++++++++++---
 5 files changed, 96 insertions(+), 53 deletions(-)
 delete mode 100644 api/core/third_party/langchain/embeddings/huggingface_endpoint_embedding.py

diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py
index a67ca81891ecb5..d1845c56c96816 100644
--- a/api/core/model_providers/providers/huggingface_hub_provider.py
+++ b/api/core/model_providers/providers/huggingface_hub_provider.py
@@ -10,6 +10,7 @@

 from core.model_providers.models.base import BaseProviderModel
 from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
+from core.third_party.langchain.embeddings.huggingface_hub_embedding import HuggingfaceHubEmbeddings
 from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
 from models.provider import ProviderType

@@ -91,19 +92,15 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
             if 'task_type' not in credentials:
                 raise CredentialsValidateFailedError('Task Type must be provided.')

-            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'sentence-similarity'):
+            if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization", 'feature-extraction'):
                 raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
-                                                     'text-generation, summarization, sentence-similarity.')
+                                                     'text-generation, summarization, feature-extraction.')

             try:
-                llm = HuggingFaceEndpointLLM(
-                    endpoint_url=credentials['huggingfacehub_endpoint_url'],
-                    task=credentials['task_type'],
-                    model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
-                    huggingfacehub_api_token=credentials['huggingfacehub_api_token']
-                )
-
-                llm("ping")
+                if credentials['task_type'] == 'feature-extraction':
+                    cls.check_embedding_valid(credentials, model_name)
+                else:
+                    cls.check_llm_valid(credentials)
             except Exception as e:
                 raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")
         else:
@@ -115,13 +112,33 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT
                 if 'inference' in model_info.cardData and not model_info.cardData['inference']:
                     raise ValueError(f'Inference API has been turned off for this model {model_name}.')

-                VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "sentence-similarity")
+                VALID_TASKS = ("text2text-generation", "text-generation", "summarization", "feature-extraction")
                 if model_info.pipeline_tag not in VALID_TASKS:
                     raise ValueError(f"Model {model_name} is not a valid task, "
                                      f"must be one of {VALID_TASKS}.")
             except Exception as e:
                 raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}")

+    @classmethod
+    def check_llm_valid(cls, credentials: dict):
+        llm = HuggingFaceEndpointLLM(
+            endpoint_url=credentials['huggingfacehub_endpoint_url'],
+            task=credentials['task_type'],
+            model_kwargs={"temperature": 0.5, "max_new_tokens": 200},
+            huggingfacehub_api_token=credentials['huggingfacehub_api_token']
+        )
+
+        llm("ping")
+
+    @classmethod
+    def check_embedding_valid(cls, credentials: dict, model_name: str):
+        embedding_model = HuggingfaceHubEmbeddings(
+            model=model_name,
+            **credentials
+        )
+
+        embedding_model.embed_query("ping")
+
     @classmethod
     def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                   credentials: dict) -> dict:
diff --git a/api/core/third_party/langchain/embeddings/huggingface_endpoint_embedding.py b/api/core/third_party/langchain/embeddings/huggingface_endpoint_embedding.py
deleted file mode 100644
index d8f2d30bc0184d..00000000000000
--- a/api/core/third_party/langchain/embeddings/huggingface_endpoint_embedding.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from typing import Any, Dict, List, Optional
-
-from pydantic import BaseModel, Extra, root_validator
-
-from langchain.embeddings.base import Embeddings
-
-
-class HuggingfaceEndpointEmbeddings(BaseModel, Embeddings):
-    client: Any
-    model: str
-
-    task_type: Optional[str] = None
-    huggingfacehub_api_type: Optional[str] = None
-    huggingfacehub_api_token: Optional[str] = None
-
-    class Config:
-        extra = Extra.forbid
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        pass
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        pass
-
-    def embed_query(self, text: str) -> List[float]:
-        pass
diff --git a/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py b/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
index 5abb106dd09b80..8e828bbec09c96 100644
--- a/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
+++ b/api/core/third_party/langchain/embeddings/huggingface_hub_embedding.py
@@ -1,10 +1,14 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
+import json

 from pydantic import BaseModel, Extra, root_validator

 from langchain.embeddings.base import Embeddings
 from langchain.utils import get_from_dict_or_env
-from langchain.embeddings import HuggingFaceEmbeddings
+from huggingface_hub import InferenceClient
+
+HOSTED_INFERENCE_API = 'hosted_inference_api'
+INFERENCE_ENDPOINTS = 'inference_endpoints'


 class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
@@ -14,6 +18,7 @@ class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
     task_type: Optional[str] = None
     huggingfacehub_api_type: Optional[str] = None
     huggingfacehub_api_token: Optional[str] = None
+    huggingfacehub_endpoint_url: Optional[str] = None

     class Config:
         extra = Extra.forbid
@@ -24,14 +29,40 @@ def validate_environment(cls, values: Dict) -> Dict:
             values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
         )

-        values['client'] = HuggingFaceEmbeddings(
-            model_name=values['model']
-        )
+        values['client'] = InferenceClient(token=values['huggingfacehub_api_token'])

         return values

+    def embeddings(self, inputs: Union[str, List[str]]) -> Any:
+        model = ''
+
+        if self.huggingfacehub_api_type == HOSTED_INFERENCE_API:
+            model = self.model
+        else:
+            model = self.huggingfacehub_endpoint_url
+
+        output = self.client.post(
+            json={
+                "inputs": inputs,
+                "options": {
+                    "wait_for_model": False
+                }
+            }, model=model)
+
+        return json.loads(output.decode())
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return self.client.embed_documents(texts)
+        output = self.embeddings(texts)
+
+        if isinstance(output, list):
+            return output
+
+        return [list(map(float, e)) for e in output]

     def embed_query(self, text: str) -> List[float]:
-        return self.client.embed_query(text)
+        output = self.embeddings(text)
+
+        if isinstance(output, list):
+            return output
+
+        return list(map(float, output))
diff --git a/api/requirements.txt b/api/requirements.txt
index 0571c3e21c2b0f..d616d509723c56 100644
--- a/api/requirements.txt
+++ b/api/requirements.txt
@@ -51,5 +51,4 @@ stripe~=5.5.0
 pandas==1.5.3
 xinference==0.4.2
 safetensors==0.3.2
-zhipuai==1.0.7
-sentence-transformers==2.2.2
+zhipuai==1.0.7
\ No newline at end of file
diff --git a/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py b/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
index d07837fac72162..e841e926e7be3e 100644
--- a/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
+++ b/api/tests/integration_tests/models/embedding/test_huggingface_hub_embedding.py
@@ -7,7 +7,7 @@
 from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
 from models.provider import Provider, ProviderType, ProviderModel

-DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'

 def get_mock_provider():
     return Provider(
@@ -28,7 +28,7 @@ def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
     credentials = {
         'huggingfacehub_api_type': huggingfacehub_api_type,
         'huggingfacehub_api_token': valid_api_key,
-        'task_type': 'sentence-similarity'
+        'task_type': 'feature-extraction'
     }

     if huggingfacehub_api_type == 'inference_endpoints':
@@ -56,7 +56,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key):


 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_embed_documents(mock_decrypt, mocker):
+def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
     embedding_model = get_mock_embedding_model(
         DEFAULT_MODEL_NAME,
         'hosted_inference_api',
@@ -64,15 +64,38 @@
         mocker)
     rst = embedding_model.client.embed_documents(['test', 'test1'])
     assert isinstance(rst, list)
     assert len(rst) == 2
-    assert len(rst[0]) == 768
+    assert len(rst[0]) == 384

 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
-def test_embed_query(mock_decrypt, mocker):
+def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
     embedding_model = get_mock_embedding_model(
         DEFAULT_MODEL_NAME,
         'hosted_inference_api',
         mocker)
     rst = embedding_model.client.embed_query('test')
     assert isinstance(rst, list)
-    assert len(rst) == 768
+    assert len(rst) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    rst = embedding_model.client.embed_documents(['test', 'test1'])
+    assert isinstance(rst, list)
+    assert len(rst) == 2
+    assert len(rst[0]) == 384
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
+    embedding_model = get_mock_embedding_model(
+        '',
+        'inference_endpoints',
+        mocker)
+    rst = embedding_model.client.embed_query('test')
+    assert isinstance(rst, list)
+    assert len(rst) == 384
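
For reference, a minimal usage sketch of the new embedding class,
mirroring the hosted_inference_api tests above; the token is a
placeholder, and the inference_endpoints variant instead passes an
empty model name plus huggingfacehub_endpoint_url:

    from core.third_party.langchain.embeddings.huggingface_hub_embedding import (
        HuggingfaceHubEmbeddings,
    )

    embeddings = HuggingfaceHubEmbeddings(
        model='obrizum/all-MiniLM-L6-v2',
        huggingfacehub_api_type='hosted_inference_api',
        huggingfacehub_api_token='hf_...',  # placeholder token
        task_type='feature-extraction',
    )

    print(len(embeddings.embed_query('test')))  # 384 dims for this model
    print(len(embeddings.embed_documents(['test', 'test1'])))  # 2 vectors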