Skip to content

Commit

Permalink
Merge branch 'feat/model-runtime' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
GarfieldDai committed Jan 2, 2024
2 parents 841b2f4 + 281bd30 commit 768e0d7
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 3 deletions.
2 changes: 1 addition & 1 deletion api/core/entities/provider_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def custom_model_credentials_validate(self, model_type: ModelType, model: str, c

model_schema = (
model_provider_factory.get_provider_instance(self.provider.provider)
.get_model_instance(model_type).get_customizable_model_schema(
.get_model_instance(model_type)._get_customizable_model_schema(
model=model,
credentials=credentials
)
Expand Down
48 changes: 46 additions & 2 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import PriceInfo, AIModelEntity, PriceType, PriceConfig, \
DefaultParameterName, FetchFrom, ModelType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer

Expand Down Expand Up @@ -243,9 +244,52 @@ def get_customizable_model_schema_from_credentials(self, model: str, credentials
return model_instance
except ValidationError as e:
logging.exception(f"Invalid model schema for {model}")
return self.get_customizable_model_schema(model, credentials)
return self._get_customizable_model_schema(model, credentials)

return self.get_customizable_model_schema(model, credentials)
return self._get_customizable_model_schema(model, credentials)

def _get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Get customizable model schema and fill in the template
"""
schema = self.get_customizable_model_schema(model, credentials)

if not schema:
return None

# fill in the template
new_parameter_rules = []
for parameter_rule in schema.parameter_rules:
if parameter_rule.use_template:
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
if not parameter_rule.max:
parameter_rule.max = default_parameter_rule['max']
if not parameter_rule.min:
parameter_rule.min = default_parameter_rule['min']
if not parameter_rule.precision:
parameter_rule.default = default_parameter_rule['default']
if not parameter_rule.precision:
parameter_rule.precision = default_parameter_rule['precision']
if not parameter_rule.required:
parameter_rule.required = default_parameter_rule['required']
if not parameter_rule.help:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
if not parameter_rule.help.en_US:
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
if not parameter_rule.help.zh_Hans:
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
except ValueError:
pass

new_parameter_rules.append(parameter_rule)

schema.parameter_rules = new_parameter_rules

return schema

def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional

import numpy as np
import requests
from huggingface_hub import InferenceClient, HfApi

from core.model_runtime.entities.common_entities import I18nObject
Expand All @@ -13,6 +14,9 @@
from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub


HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/'


class HuggingfaceHubTextEmbeddingModel(_CommonHuggingfaceHub, TextEmbeddingModel):

def _invoke(self, model: str, credentials: dict, texts: list[str],
Expand Down Expand Up @@ -72,7 +76,10 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
if credentials['task_type'] != 'feature-extraction':
raise CredentialsValidateFailedError('Huggingface Hub Task Type is invalid.')

self._check_endpoint_url_model_repository_name(credentials, model)

model = credentials['huggingfacehub_endpoint_url']

elif credentials['huggingfacehub_api_type'] == 'hosted_inference_api':
self._check_hosted_model_task_type(credentials['huggingfacehub_api_token'],
model)
Expand Down Expand Up @@ -154,3 +161,31 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em
)

return usage

@staticmethod
def _check_endpoint_url_model_repository_name(credentials: dict, model_name: str):
try:
url = f'{HUGGINGFACE_ENDPOINT_API}{credentials["huggingface_namespace"]}'
headers = {
'Authorization': f'Bearer {credentials["huggingfacehub_api_token"]}',
'Content-Type': 'application/json'
}

response = requests.get(url=url, headers=headers)

if response.status_code != 200:
raise ValueError('User Name or Organization Name is invalid.')

model_repository_name = ''

for item in response.json().get("items", []):
if item.get("status", {}).get("url") == credentials['huggingfacehub_endpoint_url']:
model_repository_name = item.get("model", {}).get("repository")
break

if model_repository_name != model_name:
raise ValueError(
f'Model Name {model_name} is invalid. Please check it on the inference endpoints console.')

except Exception as e:
raise ValueError(str(e))
1 change: 1 addition & 0 deletions api/core/model_runtime/model_providers/localai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
type=ParameterType.INT,
use_template='max_tokens',
min=1,
max=2048,
default=512,
label=I18nObject(
zh_Hans='最大生成长度',
Expand Down

0 comments on commit 768e0d7

Please sign in to comment.