[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server #10635
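For context on the behavior under test: the usual way to keep an asyncio event loop responsive while a blocking step such as tokenization runs is to hand that step off to a worker thread. The sketch below only illustrates that general pattern; `tokenize_blocking` is a hypothetical stand-in, not the server's actual preprocessing code.

```python
import asyncio


def tokenize_blocking(text: str) -> list[int]:
    # Hypothetical stand-in for a CPU-bound tokenizer call.
    return [ord(c) for c in text]


async def preprocess(text: str) -> list[int]:
    # Run the blocking call in a worker thread so the event loop stays free
    # to serve other requests (e.g. /health) in the meantime.
    return await asyncio.to_thread(tokenize_blocking, text)
```

The test file added in this PR checks that property end to end: long prompts are tokenized while a concurrent `/health` request still returns quickly.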
@@ -0,0 +1,137 @@
import asyncio
import contextlib
import random
import time
from typing import Callable

import openai
import pytest
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
        "--load-format",
        "dummy",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["completion", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 10_000)
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 10_000)
            }]
        }),
    ],
)
async def test_with_and_without_truncate(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    num_requests = 10
    truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
                              (num_requests - num_requests // 2))
    random.shuffle(truncate_prompt_tokens)

    bodies = [{
        **body, "extra_body": {
            'truncate_prompt_tokens': t
        }
    } for t in truncate_prompt_tokens]

    async def get_status_code(**kwargs):
        try:
            await create_func(**kwargs)
            return 200
        except openai.APIStatusError as e:
            return e.status_code

    responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
    assert 500 not in responses


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["single completion", "multiple completions", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 300_000)
        }),
        (lambda x: x.completions.create, {
            "prompt": [" ".join(['A'] * 300_000)] * 2
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 300_000)
            }]
        }),
    ],
)
Review comments attached to the `test_healthcheck_response_time` definition below:
- This fails on …
- The embedding endpoint is not tested since it requires launching the server again with …

async def test_healthcheck_response_time(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    num_requests = 50

    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    def get_response_time(url):
        start_time = time.monotonic()
        res = requests.get(url)
        end_time = time.monotonic()
        assert res.status_code == 200
        return end_time - start_time

    no_load_response_time = get_response_time(server.url_for("health"))
    tasks = [
        asyncio.create_task(create_func(**body)) for _ in range(num_requests)
    ]
    await asyncio.sleep(1)  # give the tasks a chance to start running
    load_response_time = get_response_time(server.url_for("health"))

    with contextlib.suppress(openai.APIStatusError):
        await asyncio.gather(*tasks)

    assert load_response_time < 100 * no_load_response_time
    assert load_response_time < 0.1
Review comment on the assertions above: This fails if we use a threadpool with max_workers > 1 for tokenization because the tokenizer is not thread-safe. It passes both on main and on the PR.
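A minimal sketch of one way to satisfy both constraints at once (don't block the event loop, don't call a non-thread-safe tokenizer concurrently) is to funnel all tokenization through a single-worker executor. The class and names below are illustrative assumptions, not the PR's actual implementation.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


class AsyncTokenizer:
    """Serializes access to a tokenizer that is assumed not to be thread-safe."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        # max_workers=1 means tokenize calls never run concurrently, while the
        # event loop stays free to handle other requests during each call.
        self._executor = ThreadPoolExecutor(max_workers=1)

    async def encode(self, text: str) -> list[int]:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor,
                                          self._tokenizer.encode, text)
```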