Skip to content

Commit

Permalink
fix: replace os.path.join with yarl (#2690)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly authored Mar 5, 2024
1 parent 552f319 commit 9573379
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 43 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from os import path
from threading import Lock
from time import time

from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class XinferenceModelExtraParameter:
Expand Down Expand Up @@ -55,7 +55,10 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
get xinference model extra parameter like model_format and model_handle_type
"""

url = path.join(server_url, 'v1/models', model_uid)
if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
raise RuntimeError('model_uid is empty')

url = str(URL(server_url) / 'v1' / 'models' / model_uid)

# this method is surrounded by a lock, and default requests may hang forever, so we just set an Adapter with max_retries=3
session = Session()
Expand All @@ -66,7 +69,6 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
response = session.get(url, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')

if response.status_code != 200:
raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')

Expand Down
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ pydub~=0.25.1
gmpy2~=2.1.5
numexpr~=2.9.0
duckduckgo-search==4.4.3
arxiv==2.1.0
arxiv==2.1.0
yarl~=1.9.4
80 changes: 41 additions & 39 deletions api/tests/integration_tests/model_runtime/__mock/xinference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,68 +32,70 @@ def get(self: Session, url: str, **kwargs):
response = Response()
if 'v1/models/' in url:
# get model uid
model_uid = url.split('/')[-1]
model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
response.status_code = 404
response._content = b'{}'
return response

# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
response.status_code = 404
response._content = b'{}'
return response

if model_uid in ['generate', 'chat']:
response.status_code = 200
response._content = b'''{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "chatglm3-6b",
"model_lang": [
"en"
],
"model_ability": [
"generate",
"chat"
],
"model_description": "latest chatglm3",
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantization": "none",
"model_hub": "huggingface",
"revision": null,
"context_length": 2048,
"replica": 1
}'''
return response

elif model_uid == 'embedding':
response.status_code = 200
response._content = b'''{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
"0",
"1"
],
"model_name": "bge",
"model_lang": [
"en"
],
"revision": null,
"max_tokens": 512
}'''
return response

elif 'v1/cluster/auth' in url:
response.status_code = 200
response._content = b'''{
"auth": true
}'''
"auth": true
}'''
return response

def _check_cluster_authenticated(self):
Expand Down

0 comments on commit 9573379

Please sign in to comment.