
Commit

feat: use xinference client instead of xinference (#1339)
takatost authored Oct 13, 2023
1 parent 9822f68 commit 3efaa71
Showing 5 changed files with 83 additions and 14 deletions.
3 changes: 1 addition & 2 deletions api/core/model_providers/models/embedding/xinference_embedding.py
@@ -1,8 +1,7 @@
-from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
-
 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.providers.base import BaseModelProvider
 from core.model_providers.models.embedding.base import BaseEmbedding
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings


 class XinferenceEmbedding(BaseEmbedding):
2 changes: 1 addition & 1 deletion api/core/model_providers/providers/xinference_provider.py
@@ -2,7 +2,6 @@
 from typing import Type

 import requests
-from langchain.embeddings import XinferenceEmbeddings

 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -11,6 +10,7 @@
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
 from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
 from models.provider import ProviderType

43 changes: 38 additions & 5 deletions api/core/third_party/langchain/embeddings/xinference_embedding.py
@@ -1,21 +1,54 @@
-from typing import List
+from typing import List, Optional, Any

 import numpy as np
-from langchain.embeddings import XinferenceEmbeddings
+from langchain.embeddings.base import Embeddings
+from xinference_client.client.restful.restful_client import Client


-class XinferenceEmbedding(XinferenceEmbeddings):
+class XinferenceEmbeddings(Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+
+        super().__init__()
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+
+        self.model_uid = model_uid
+
+        self.client = Client(server_url)
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        vectors = super().embed_documents(texts)
+        model = self.client.get_model(self.model_uid)
+
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        vectors = [list(map(float, e)) for e in embeddings]
         normalized_vectors = [(vector / np.linalg.norm(vector)).tolist() for vector in vectors]

         return normalized_vectors

     def embed_query(self, text: str) -> List[float]:
-        vector = super().embed_query(text)
+        model = self.client.get_model(self.model_uid)
+
+        embedding_res = model.create_embedding(text)
+
+        embedding = embedding_res["data"][0]["embedding"]
+
+        vector = list(map(float, embedding))
         normalized_vector = (vector / np.linalg.norm(vector)).tolist()

         return normalized_vector
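
Note: below is a minimal usage sketch of the rewritten XinferenceEmbeddings wrapper above, assuming a reachable Xinference server; the server URL and model UID are placeholder values, not part of this commit.

from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings

embeddings = XinferenceEmbeddings(
    server_url="http://127.0.0.1:9997",   # hypothetical local Xinference endpoint
    model_uid="my-embedding-model-uid",   # hypothetical UID of a launched embedding model
)

query_vector = embeddings.embed_query("hello")
doc_vectors = embeddings.embed_documents(["first document", "second document"])

# Both methods return L2-normalized float vectors, so each has norm ~1.0.
print(len(query_vector), sum(v * v for v in query_vector))
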
47 changes: 42 additions & 5 deletions api/core/third_party/langchain/llms/xinference_llm.py
@@ -1,16 +1,53 @@
-from typing import Optional, List, Any, Union, Generator
+from typing import Optional, List, Any, Union, Generator, Mapping

 from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms import Xinference
+from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
-from xinference.client import (
+from xinference_client.client.restful.restful_client import (
     RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
-    RESTfulGenerateModelHandle,
+    RESTfulGenerateModelHandle, Client,
 )


-class XinferenceLLM(Xinference):
+class XinferenceLLM(LLM):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None
+    ):
+        super().__init__(
+            **{
+                "server_url": server_url,
+                "model_uid": model_uid,
+            }
+        )
+
+        if self.server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if self.model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.client = Client(server_url)
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "xinference"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            **{"server_url": self.server_url},
+            **{"model_uid": self.model_uid},
+        }
+
     def _call(
         self,
         prompt: str,
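
Note: a minimal construction sketch for the rewritten XinferenceLLM above, assuming a reachable Xinference server; the URL and model UID are placeholders, and the _call method truncated in this diff is not reproduced here.

from core.third_party.langchain.llms.xinference_llm import XinferenceLLM

llm = XinferenceLLM(
    server_url="http://127.0.0.1:9997",  # hypothetical Xinference endpoint
    model_uid="my-llm-model-uid",        # hypothetical UID of a launched model
)

print(llm._llm_type)            # "xinference"
print(llm._identifying_params)  # {'server_url': ..., 'model_uid': ...}

# Text generation still goes through LangChain's standard LLM interface
# (e.g. llm("Hello")), which dispatches to the _call method shown truncated above.
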
2 changes: 1 addition & 1 deletion api/requirements.txt
@@ -49,7 +49,7 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
-xinference==0.5.2
+xinference-client~=0.1.2
 safetensors==0.3.2
 zhipuai==1.0.7
 werkzeug==2.3.7
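
Note: the dependency swap above replaces the full xinference package with the lightweight REST client; a hedged sketch of the corresponding import change, with a placeholder server address:

# Old (xinference==0.5.2):
#   from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
# New (xinference-client~=0.1.2):
from xinference_client.client.restful.restful_client import Client

client = Client("http://127.0.0.1:9997")     # hypothetical Xinference server address
# model = client.get_model("my-model-uid")   # hypothetical UID of a launched model
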
