From 3efaa713da3f9e84e3c2a6ca124d22c514fcfa2d Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 13 Oct 2023 15:46:09 +0800 Subject: [PATCH] feat: use xinference client instead of xinference (#1339) --- .../models/embedding/xinference_embedding.py | 3 +- .../providers/xinference_provider.py | 2 +- .../embeddings/xinference_embedding.py | 43 +++++++++++++++-- .../langchain/llms/xinference_llm.py | 47 +++++++++++++++++-- api/requirements.txt | 2 +- 5 files changed, 83 insertions(+), 14 deletions(-) diff --git a/api/core/model_providers/models/embedding/xinference_embedding.py b/api/core/model_providers/models/embedding/xinference_embedding.py index 4cecd511e5caad..81f9756a166bff 100644 --- a/api/core/model_providers/models/embedding/xinference_embedding.py +++ b/api/core/model_providers/models/embedding/xinference_embedding.py @@ -1,8 +1,7 @@ -from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings - from core.model_providers.error import LLMBadRequestError from core.model_providers.providers.base import BaseModelProvider from core.model_providers.models.embedding.base import BaseEmbedding +from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings class XinferenceEmbedding(BaseEmbedding): diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py index fff0119eaf09d6..af1f050b87a8db 100644 --- a/api/core/model_providers/providers/xinference_provider.py +++ b/api/core/model_providers/providers/xinference_provider.py @@ -2,7 +2,6 @@ from typing import Type import requests -from langchain.embeddings import XinferenceEmbeddings from core.helper import encrypter from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding @@ -11,6 +10,7 @@ from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.models.base import BaseProviderModel +from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings from core.third_party.langchain.llms.xinference_llm import XinferenceLLM from models.provider import ProviderType diff --git a/api/core/third_party/langchain/embeddings/xinference_embedding.py b/api/core/third_party/langchain/embeddings/xinference_embedding.py index 371e240e9f623f..da433822c40a0e 100644 --- a/api/core/third_party/langchain/embeddings/xinference_embedding.py +++ b/api/core/third_party/langchain/embeddings/xinference_embedding.py @@ -1,21 +1,54 @@ -from typing import List +from typing import List, Optional, Any import numpy as np -from langchain.embeddings import XinferenceEmbeddings +from langchain.embeddings.base import Embeddings +from xinference_client.client.restful.restful_client import Client -class XinferenceEmbedding(XinferenceEmbeddings): +class XinferenceEmbeddings(Embeddings): + client: Any + server_url: Optional[str] + """URL of the xinference server""" + model_uid: Optional[str] + """UID of the launched model""" + + def __init__( + self, server_url: Optional[str] = None, model_uid: Optional[str] = None + ): + + super().__init__() + + if server_url is None: + raise ValueError("Please provide server URL") + + if model_uid is None: + raise ValueError("Please provide the model UID") + + self.server_url = server_url + + self.model_uid = model_uid + + self.client = Client(server_url) def embed_documents(self, texts: List[str]) -> List[List[float]]: - vectors = super().embed_documents(texts) + model = self.client.get_model(self.model_uid) + embeddings = [ + model.create_embedding(text)["data"][0]["embedding"] for text in texts + ] + vectors = [list(map(float, e)) for e in embeddings] normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors] return normalized_vectors def embed_query(self, text: str) -> List[float]: - vector = super().embed_query(text) + model = self.client.get_model(self.model_uid) + + embedding_res = model.create_embedding(text) + + embedding = embedding_res["data"][0]["embedding"] + vector = list(map(float, embedding)) normalized_vector = (vector / np.linalg.norm(vector)).tolist() return normalized_vector diff --git a/api/core/third_party/langchain/llms/xinference_llm.py b/api/core/third_party/langchain/llms/xinference_llm.py index c65688adc7a9d8..33ab9ababde294 100644 --- a/api/core/third_party/langchain/llms/xinference_llm.py +++ b/api/core/third_party/langchain/llms/xinference_llm.py @@ -1,16 +1,53 @@ -from typing import Optional, List, Any, Union, Generator +from typing import Optional, List, Any, Union, Generator, Mapping from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms import Xinference +from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from xinference.client import ( +from xinference_client.client.restful.restful_client import ( RESTfulChatglmCppChatModelHandle, RESTfulChatModelHandle, - RESTfulGenerateModelHandle, + RESTfulGenerateModelHandle, Client, ) -class XinferenceLLM(Xinference): +class XinferenceLLM(LLM): + client: Any + server_url: Optional[str] + """URL of the xinference server""" + model_uid: Optional[str] + """UID of the launched model""" + + def __init__( + self, server_url: Optional[str] = None, model_uid: Optional[str] = None + ): + super().__init__( + **{ + "server_url": server_url, + "model_uid": model_uid, + } + ) + + if self.server_url is None: + raise ValueError("Please provide server URL") + + if self.model_uid is None: + raise ValueError("Please provide the model UID") + + self.client = Client(server_url) + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "xinference" + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return { + **{"server_url": self.server_url}, + **{"model_uid": self.model_uid}, + } + def _call( self, prompt: str, diff --git a/api/requirements.txt b/api/requirements.txt index 5c2383c9e8325f..8c873f6113b778 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -49,7 +49,7 @@ huggingface_hub~=0.16.4 transformers~=4.31.0 stripe~=5.5.0 pandas==1.5.3 -xinference==0.5.2 +xinference-client~=0.1.2 safetensors==0.3.2 zhipuai==1.0.7 werkzeug==2.3.7