-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add LindormAI as a custom model provider
- Loading branch information
jiangzhijie
committed
Nov 26, 2024
1 parent
39cdfcf
commit 5fbff14
Showing
14 changed files
with
448 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,3 +41,4 @@ | |
- mixedbread | ||
- nomic | ||
- voyage | ||
- lindormai |
Empty file.
Binary file added
BIN
+112 KB
api/core/model_runtime/model_providers/lindormai/_assets/icon_l_en.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
104 changes: 104 additions & 0 deletions
104
api/core/model_runtime/model_providers/lindormai/_common.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import json | ||
from collections.abc import Mapping | ||
from typing import Any | ||
from urllib.error import HTTPError, URLError | ||
|
||
import requests | ||
|
||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
|
||
|
||
def _check_credentials_fields(credentials: Mapping) -> None: | ||
if "lindormai_endpoint" not in credentials: | ||
raise CredentialsValidateFailedError("LindormAI EndPoint must be provided") | ||
if "lindormai_username" not in credentials: | ||
raise CredentialsValidateFailedError("LindormAI Username must be provided") | ||
if "lindormai_password" not in credentials: | ||
raise CredentialsValidateFailedError("LindormAI Password must be provided") | ||
|
||
|
||
class _CommonLindormAI: | ||
HTTP_HDR_AK_KEY = "x-ld-ak" | ||
HTTP_HDR_SK_KEY = "x-ld-sk" | ||
REST_URL_PATH = "/v1/ai" | ||
REST_URL_MODELS_PATH = REST_URL_PATH + "/models" | ||
INFER_INPUT_KEY = "input" | ||
INFER_PARAMS_KEY = "params" | ||
RSP_DATA_KEY = "data" | ||
RSP_MODELS_KEY = "models" | ||
|
||
@property | ||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | ||
""" | ||
Map model invoke error to unified error | ||
The key is the error type thrown to the caller | ||
The value is the error type thrown by the model, | ||
which needs to be converted into a unified error type for the caller. | ||
:return: Invoke error mapping | ||
""" | ||
return { | ||
InvokeConnectionError: [URLError], | ||
InvokeServerUnavailableError: [HTTPError], | ||
InvokeRateLimitError: [InvokeRateLimitError], | ||
InvokeAuthorizationError: [InvokeAuthorizationError], | ||
InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError, json.JSONDecodeError], | ||
} | ||
|
||
def _post(self, url, data=None, json=None, **kwargs): | ||
response = requests.post(url=url, data=data, json=json, **kwargs) | ||
response.raise_for_status() | ||
return response | ||
|
||
def _get(self, url, params=None, **kwargs): | ||
response = requests.get(url=url, params=params, **kwargs) | ||
response.raise_for_status() | ||
return response | ||
|
||
def _check_model_status(self, model: str, credentials: Mapping) -> None: | ||
""" | ||
Validate model credentials | ||
:param model: model name | ||
:param credentials: model credentials | ||
:return: | ||
""" | ||
try: | ||
endpoint = credentials.get("lindormai_endpoint") | ||
username = credentials.get("lindormai_username") | ||
passwd = credentials.get("lindormai_password") | ||
headers = {_CommonLindormAI.HTTP_HDR_AK_KEY: username, _CommonLindormAI.HTTP_HDR_SK_KEY: passwd} | ||
url = f"{endpoint}{_CommonLindormAI.REST_URL_MODELS_PATH}/{model}/status" | ||
response = self._get(url, headers=headers) | ||
if response.status_code != 200: | ||
raise ValueError("UserName or PassWord is invalid.") | ||
msg = response.json().get("msg", "ERROR:No Response Msg") | ||
if msg != "SUCCESS": | ||
raise ValueError(msg) | ||
data = response.json().get("data", {}) | ||
status = data.get("status", "") | ||
if status != "READY": | ||
raise ValueError("Model is not in READY status") | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) | ||
|
||
def _infer_model(self, model: str, credentials: Mapping, input_data: Any, params: dict) -> dict: | ||
_check_credentials_fields(credentials) | ||
endpoint = credentials.get("lindormai_endpoint") | ||
username = credentials.get("lindormai_username") | ||
passwd = credentials.get("lindormai_password") | ||
headers = {_CommonLindormAI.HTTP_HDR_AK_KEY: username, _CommonLindormAI.HTTP_HDR_SK_KEY: passwd} | ||
url = f"{endpoint}{_CommonLindormAI.REST_URL_MODELS_PATH}/{model}/infer" | ||
infer_dict = {_CommonLindormAI.INFER_INPUT_KEY: input_data, _CommonLindormAI.INFER_PARAMS_KEY: params} | ||
response = self._post(url, json=infer_dict, headers=headers) | ||
response.raise_for_status() | ||
result = response.json() | ||
return result[_CommonLindormAI.RSP_DATA_KEY] if result else None |
10 changes: 10 additions & 0 deletions
10
api/core/model_runtime/model_providers/lindormai/lindormai.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import logging | ||
|
||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LindormAIProvider(ModelProvider): | ||
def validate_provider_credentials(self, credentials: dict) -> None: | ||
pass |
58 changes: 58 additions & 0 deletions
58
api/core/model_runtime/model_providers/lindormai/lindormai.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
provider: lindormai | ||
label: | ||
en_US: Lindorm AI | ||
icon_small: | ||
en_US: icon_s_en.png | ||
icon_large: | ||
en_US: icon_l_en.png | ||
background: "#FAF5FF" | ||
help: | ||
title: | ||
en_US: How to deploy LindormAI | ||
zh_Hans: 如何部署 LindormAI | ||
url: | ||
en_US: https://help.aliyun.com/document_detail/2393245.html?spm=a2c4g.11186623.help-menu-172543.d_2_7.4b0f1513fp0y82&scm=20140722.H_2393245._.OR_help-T_cn#DAS#zh-V_1 | ||
supported_model_types: | ||
# - llm | ||
- text-embedding | ||
- rerank | ||
# - speech2text | ||
# - tts | ||
configurate_methods: | ||
- customizable-model | ||
model_credential_schema: | ||
model: | ||
label: | ||
en_US: Model Name | ||
zh_Hans: 模型名称 | ||
placeholder: | ||
en_US: Enter your model name | ||
zh_Hans: 输入模型名称 | ||
credential_form_schemas: | ||
- variable: lindormai_endpoint | ||
label: | ||
zh_Hans: 服务器ENDPOINT | ||
en_US: Server Endpoint | ||
type: secret-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入Lindorm的AI连接地址,如 http://ld-xxxxxxxxxxxxx-proxy-ai-pub.lindorm.aliyuncs.com:9002 | ||
en_US: Enter the endpoint of you LindormAI, e.g. http://ld-xxxxxxxxxxxxx-proxy-ai-pub.lindorm.aliyuncs.com:9002 | ||
- variable: lindormai_username | ||
label: | ||
zh_Hans: lindorm 用户名 | ||
en_US: Model uid | ||
type: text-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的用户名 | ||
en_US: Enter your lindorm username | ||
- variable: lindormai_password | ||
label: | ||
zh_Hans: 密码 | ||
en_US: password | ||
type: secret-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的密码 | ||
en_US: Enter the password |
Empty file.
81 changes: 81 additions & 0 deletions
81
api/core/model_runtime/model_providers/lindormai/rerank/rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import logging | ||
from collections.abc import Mapping | ||
from typing import Optional | ||
|
||
from core.model_runtime.entities.common_entities import I18nObject | ||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType | ||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel | ||
from core.model_runtime.model_providers.lindormai._common import _check_credentials_fields, _CommonLindormAI | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LindormAIRerankModel(_CommonLindormAI, RerankModel): | ||
def validate_credentials(self, model: str, credentials: Mapping) -> None: | ||
try: | ||
_check_credentials_fields(credentials) | ||
super()._check_model_status(model, credentials) | ||
self._invoke( | ||
model=model, | ||
credentials=dict(credentials), | ||
query="What is the capital of the United States?", | ||
docs=[ | ||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States " | ||
"Census, Carson City had a population of 55,274.", | ||
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " | ||
"are a political division controlled by the United States. Its capital is Saipan.", | ||
], | ||
score_threshold=0.1, | ||
) | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) | ||
|
||
def _invoke( | ||
self, | ||
model: str, | ||
credentials: dict, | ||
query: str, | ||
docs: list[str], | ||
score_threshold: Optional[float] = None, | ||
top_n: Optional[int] = None, | ||
user: Optional[str] = None, | ||
) -> RerankResult: | ||
try: | ||
if len(docs) == 0: | ||
return RerankResult(model=model, docs=[]) | ||
_check_credentials_fields(credentials) | ||
if top_n is None: | ||
top_n = -1 | ||
results = super()._infer_model( | ||
model=model, | ||
credentials=credentials, | ||
input_data={"query": query, "chunks": docs}, | ||
params={"topK": top_n}, | ||
) | ||
rerank_documents = [] | ||
for res in results: | ||
if res["score"] >= score_threshold: | ||
rerank_document = RerankDocument(index=res["index"], text=res["chunk"], score=res["score"]) | ||
rerank_documents.append(rerank_document) | ||
rerank_documents.sort(key=lambda x: x.score, reverse=True) | ||
return RerankResult(model=model, docs=rerank_documents) | ||
except Exception as e: | ||
logger.exception(f"Failed to invoke rerank model, model: {model}") | ||
raise | ||
|
||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | ||
""" | ||
used to define customizable model schema | ||
""" | ||
entity = AIModelEntity( | ||
model=model, | ||
label=I18nObject(en_US=model), | ||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||
model_type=ModelType.RERANK, | ||
model_properties={}, | ||
parameter_rules=[], | ||
) | ||
|
||
return entity |
Empty file.
95 changes: 95 additions & 0 deletions
95
api/core/model_runtime/model_providers/lindormai/text_embedding/text_embedding.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import time | ||
from collections.abc import Mapping | ||
from typing import Optional | ||
|
||
from core.entities.embedding_type import EmbeddingInputType | ||
from core.model_runtime.entities.common_entities import I18nObject | ||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType | ||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel | ||
from core.model_runtime.model_providers.lindormai._common import _check_credentials_fields, _CommonLindormAI | ||
|
||
|
||
class LindormAITextEmbeddingModel(_CommonLindormAI, TextEmbeddingModel): | ||
def validate_credentials(self, model: str, credentials: Mapping) -> None: | ||
try: | ||
_check_credentials_fields(credentials) | ||
super()._check_model_status(model, credentials) | ||
self._invoke(model=model, credentials=dict(credentials), texts=["hello, New York!"]) | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) | ||
|
||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: | ||
""" | ||
Calculate response usage | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param tokens: input tokens | ||
:return: usage | ||
""" | ||
input_price_info = self.get_price( | ||
model=model, credentials=credentials, tokens=tokens, price_type=PriceType.INPUT | ||
) | ||
|
||
usage = EmbeddingUsage( | ||
tokens=tokens, | ||
total_tokens=tokens, | ||
unit_price=input_price_info.unit_price, | ||
price_unit=input_price_info.unit, | ||
total_price=input_price_info.total_amount, | ||
currency=input_price_info.currency, | ||
latency=time.perf_counter() - self.started_at, | ||
) | ||
|
||
return usage | ||
|
||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: | ||
return 0 | ||
|
||
def _invoke( | ||
self, | ||
model: str, | ||
credentials: dict, | ||
texts: list[str], | ||
user: Optional[str] = None, | ||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, | ||
) -> TextEmbeddingResult: | ||
""" | ||
Invoke text embedding model | ||
credentials should be like: | ||
{ | ||
'server_url': 'server url', | ||
'model_uid': 'model uid', | ||
} | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param texts: texts to embed | ||
:param user: unique user id | ||
:param input_type: input type | ||
:return: embeddings result | ||
""" | ||
batch_embeddings = super()._infer_model(model, credentials, texts, {}) | ||
token = self.get_num_tokens(model, credentials, texts) | ||
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=token) | ||
result = TextEmbeddingResult(model=model, embeddings=batch_embeddings, usage=usage) | ||
return result | ||
|
||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: | ||
""" | ||
used to define customizable model schema | ||
""" | ||
|
||
entity = AIModelEntity( | ||
model=model, | ||
label=I18nObject(en_US=model), | ||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, | ||
model_type=ModelType.TEXT_EMBEDDING, | ||
model_properties={}, | ||
parameter_rules=[], | ||
) | ||
|
||
return entity |
Empty file.
Oops, something went wrong.