-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: xinference rerank model support (#1615)
- Loading branch information
Showing
9 changed files
with
215 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
58 changes: 58 additions & 0 deletions
58
api/core/model_providers/models/reranking/xinference_reranking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import logging | ||
from typing import Optional, List | ||
|
||
from langchain.schema import Document | ||
from xinference_client.client.restful.restful_client import Client | ||
|
||
from core.model_providers.error import LLMBadRequestError | ||
from core.model_providers.models.reranking.base import BaseReranking | ||
from core.model_providers.providers.base import BaseModelProvider | ||
|
||
|
||
class XinferenceReranking(BaseReranking): | ||
|
||
def __init__(self, model_provider: BaseModelProvider, name: str): | ||
self.credentials = model_provider.get_model_credentials( | ||
model_name=name, | ||
model_type=self.type | ||
) | ||
|
||
client = Client(self.credentials['server_url']) | ||
|
||
super().__init__(model_provider, client, name) | ||
|
||
def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: | ||
docs = [] | ||
doc_id = [] | ||
for document in documents: | ||
if document.metadata['doc_id'] not in doc_id: | ||
doc_id.append(document.metadata['doc_id']) | ||
docs.append(document.page_content) | ||
|
||
model = self.client.get_model(self.credentials['model_uid']) | ||
response = model.rerank(query=query, documents=docs, top_n=top_k) | ||
rerank_documents = [] | ||
|
||
for idx, result in enumerate(response['results']): | ||
# format document | ||
index = result['index'] | ||
rerank_document = Document( | ||
page_content=result['document'], | ||
metadata={ | ||
"doc_id": documents[index].metadata['doc_id'], | ||
"doc_hash": documents[index].metadata['doc_hash'], | ||
"document_id": documents[index].metadata['document_id'], | ||
"dataset_id": documents[index].metadata['dataset_id'], | ||
'score': result['relevance_score'] | ||
} | ||
) | ||
# score threshold check | ||
if score_threshold is not None: | ||
if result.relevance_score >= score_threshold: | ||
rerank_documents.append(rerank_document) | ||
else: | ||
rerank_documents.append(rerank_document) | ||
return rerank_documents | ||
|
||
def handle_exceptions(self, ex: Exception) -> Exception: | ||
return LLMBadRequestError(f"Xinference rerank: {str(ex)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
61 changes: 61 additions & 0 deletions
61
api/tests/integration_tests/models/reranking/test_cohere_reranking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import json | ||
import os | ||
from unittest.mock import patch | ||
|
||
from langchain.schema import Document | ||
|
||
from core.model_providers.models.reranking.cohere_reranking import CohereReranking | ||
from core.model_providers.providers.cohere_provider import CohereProvider | ||
from models.provider import Provider, ProviderType | ||
|
||
|
||
def get_mock_provider(valid_api_key): | ||
return Provider( | ||
id='provider_id', | ||
tenant_id='tenant_id', | ||
provider_name='cohere', | ||
provider_type=ProviderType.CUSTOM.value, | ||
encrypted_config=json.dumps({'api_key': valid_api_key}), | ||
is_valid=True, | ||
) | ||
|
||
|
||
def get_mock_model(): | ||
valid_api_key = os.environ['COHERE_API_KEY'] | ||
provider = CohereProvider(provider=get_mock_provider(valid_api_key)) | ||
return CohereReranking( | ||
model_provider=provider, | ||
name='rerank-english-v2.0' | ||
) | ||
|
||
|
||
def decrypt_side_effect(tenant_id, encrypted_api_key): | ||
return encrypted_api_key | ||
|
||
|
||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | ||
def test_run(mock_decrypt): | ||
model = get_mock_model() | ||
|
||
docs = [] | ||
docs.append(Document( | ||
page_content='bye', | ||
metadata={ | ||
"doc_id": 'a', | ||
"doc_hash": 'doc_hash', | ||
"document_id": 'document_id', | ||
"dataset_id": 'dataset_id', | ||
} | ||
)) | ||
docs.append(Document( | ||
page_content='hello', | ||
metadata={ | ||
"doc_id": 'b', | ||
"doc_hash": 'doc_hash', | ||
"document_id": 'document_id', | ||
"dataset_id": 'dataset_id', | ||
} | ||
)) | ||
rst = model.rerank('hello', docs, None, 2) | ||
|
||
assert rst[0].page_content == 'hello' |
78 changes: 78 additions & 0 deletions
78
api/tests/integration_tests/models/reranking/test_xinference_reranking.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import json | ||
import os | ||
from unittest.mock import patch, MagicMock | ||
|
||
from langchain.schema import Document | ||
|
||
from core.model_providers.models.entity.model_params import ModelType | ||
from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking | ||
from core.model_providers.providers.xinference_provider import XinferenceProvider | ||
from models.provider import Provider, ProviderType, ProviderModel | ||
|
||
|
||
def get_mock_provider(valid_server_url, valid_model_uid): | ||
return Provider( | ||
id='provider_id', | ||
tenant_id='tenant_id', | ||
provider_name='xinference', | ||
provider_type=ProviderType.CUSTOM.value, | ||
encrypted_config=json.dumps({'server_url': valid_server_url, 'model_uid': valid_model_uid}), | ||
is_valid=True, | ||
) | ||
|
||
|
||
def get_mock_model(mocker): | ||
valid_server_url = os.environ['XINFERENCE_SERVER_URL'] | ||
valid_model_uid = os.environ['XINFERENCE_MODEL_UID'] | ||
model_name = 'bge-reranker-base' | ||
provider = XinferenceProvider(provider=get_mock_provider(valid_server_url, valid_model_uid)) | ||
|
||
mock_query = MagicMock() | ||
mock_query.filter.return_value.first.return_value = ProviderModel( | ||
provider_name='xinference', | ||
model_name=model_name, | ||
model_type=ModelType.RERANKING.value, | ||
encrypted_config=json.dumps({ | ||
'server_url': valid_server_url, | ||
'model_uid': valid_model_uid | ||
}), | ||
is_valid=True, | ||
) | ||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) | ||
|
||
return XinferenceReranking( | ||
model_provider=provider, | ||
name=model_name | ||
) | ||
|
||
|
||
def decrypt_side_effect(tenant_id, encrypted_api_key): | ||
return encrypted_api_key | ||
|
||
|
||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) | ||
def test_run(mock_decrypt, mocker): | ||
model = get_mock_model(mocker) | ||
|
||
docs = [] | ||
docs.append(Document( | ||
page_content='bye', | ||
metadata={ | ||
"doc_id": 'a', | ||
"doc_hash": 'doc_hash', | ||
"document_id": 'document_id', | ||
"dataset_id": 'dataset_id', | ||
} | ||
)) | ||
docs.append(Document( | ||
page_content='hello', | ||
metadata={ | ||
"doc_id": 'b', | ||
"doc_hash": 'doc_hash', | ||
"document_id": 'document_id', | ||
"dataset_id": 'dataset_id', | ||
} | ||
)) | ||
rst = model.rerank('hello', docs, None, 2) | ||
|
||
assert rst[0].page_content == 'hello' |