Skip to content

Commit

Permalink
fix: add get_customizable_model_schema to xinference
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly committed Dec 28, 2023
1 parent 8635c58 commit 3ed46ce
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
InvokeAuthorizationError, InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.entities.model_entities import FetchFrom, ModelType, AIModelEntity
from core.model_runtime.entities.common_entities import I18nObject

from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle, RESTfulModelHandle, Client
from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle, Client

class XinferenceRerankModel(RerankModel):
"""
Expand Down Expand Up @@ -126,3 +128,20 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
ValueError
]
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.RERANK,
model_properties={ },
parameter_rules=[]
)

return entity
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.model_entities import PriceType, FetchFrom, ModelType, AIModelEntity
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult, EmbeddingUsage
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
Expand Down Expand Up @@ -155,3 +156,20 @@ def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> Em
)

return usage

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.TEXT_EMBEDDING,
model_properties={},
parameter_rules=[]
)

return entity

0 comments on commit 3ed46ce

Please sign in to comment.