Skip to content

Commit

Permalink
Merge branch 'feat/huggingface-embedding-support' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
GarfieldDai committed Sep 21, 2023
2 parents c7e5211 + 2c10a1e commit e001b2d
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
34 changes: 34 additions & 0 deletions api/core/model_providers/providers/huggingface_hub_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Type
import requests

from huggingface_hub import HfApi

Expand All @@ -14,6 +15,8 @@
from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
from models.provider import ProviderType

HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'


class HuggingfaceHubProvider(BaseModelProvider):
@property
Expand Down Expand Up @@ -132,13 +135,44 @@ def check_llm_valid(cls, credentials: dict):

@classmethod
def check_embedding_valid(cls, credentials: dict, model_name: str):

cls.check_endpoint_url_model_repository_name(credentials, model_name)

embedding_model = HuggingfaceHubEmbeddings(
model=model_name,
**credentials
)

embedding_model.embed_query("ping")

@classmethod
def check_endpoint_url_model_repository_name(cls, credentials: dict, model_name: str):
try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers = {
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
'Content-Type': 'application/json'
}

response =requests.get(url=url, headers=headers)

if response.status_code != 200:
raise ValueError('User Name or Organization Name is invalid.')

model_repository_name = ''

for item in response.json().get("items", []):
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
model_repository_name = item.get("model", {}).get("repository")
break

if model_repository_name != model_name:
raise ValueError(f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')

except Exception as e:
raise ValueError(str(e))


@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class HuggingfaceHubEmbeddings(BaseModel, Embeddings):
client: Any
model: str

huggingface_namespace: Optional[str] = None
task_type: Optional[str] = None
huggingfacehub_api_type: Optional[str] = None
huggingfacehub_api_token: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ const config: ProviderConfig = {
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
if (v.model_type === 'embeddings') {
return [
'huggingfacehub_api_token',
'huggingface_namespace',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
]
}
return [
'huggingfacehub_api_token',
'model_name',
Expand All @@ -68,14 +77,27 @@ const config: ProviderConfig = {
]
}
if (v?.huggingfacehub_api_type === 'inference_endpoints') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
if (v.model_type === 'embeddings') {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'huggingface_namespace',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
else {
filteredKeys = [
'huggingfacehub_api_type',
'huggingfacehub_api_token',
'model_name',
'huggingfacehub_endpoint_url',
'task_type',
'model_type',
]
}
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
Expand Down Expand Up @@ -146,6 +168,20 @@ const config: ProviderConfig = {
'zh-Hans': '在此输入您的 Hugging Face Hub API Token',
},
},
{
hidden: (value?: FormValue) => !(value?.huggingfacehub_api_type === 'inference_endpoints' && value?.model_type === 'embeddings'),
type: 'text',
key: 'huggingface_namespace',
required: true,
label: {
'en': 'User Name / Organization Name',
'zh-Hans': '用户名 / 组织名称',
},
placeholder: {
'en': 'Enter your User Name / Organization Name here',
'zh-Hans': '在此输入您的用户名 / 组织名称',
},
},
{
type: 'text',
key: 'model_name',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ const ModelModal: FC<ModelModalProps> = ({
return (
<Portal>
<div className='fixed inset-0 flex items-center justify-center bg-black/[.25]'>
<div className='w-[640px] max-h-screen bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='w-[640px] max-h-[calc(100vh-120px)] bg-white shadow-xl rounded-2xl overflow-y-auto'>
<div className='px-8 pt-8'>
<div className='flex justify-between items-center mb-2'>
<div className='text-xl font-semibold text-gray-900'>{renderTitlePrefix()}</div>
Expand Down

0 comments on commit e001b2d

Please sign in to comment.