[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server #10635

Merged (14 commits) on Nov 27, 2024
tests/entrypoints/openai/test_async_tokenization.py (137 additions, 0 deletions)
@@ -0,0 +1,137 @@
import asyncio
import contextlib
import random
import time
from typing import Callable

import openai
import pytest
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
"--max-num-seqs",
"128",
"--load-format",
"dummy",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server


@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
ids=["completion", "chat"],
argnames=["create_func_gen", "content_body"],
argvalues=[
(lambda x: x.completions.create, {
"prompt": " ".join(['A'] * 10_000)
}),
(lambda x: x.chat.completions.create, {
"messages": [{
"role": "user",
"content": " ".join(['A'] * 10_000)
}]
}),
],
)
async def test_with_and_without_truncate(
Contributor Author: This test fails if we use a thread pool with max_workers > 1 for tokenization, because the tokenizer is not thread-safe. It passes both on main and on this PR. (A standalone illustration of why a single-worker executor avoids this follows the serving_engine.py diff below.)

server: RemoteOpenAIServer,
client: openai.AsyncOpenAI,
create_func_gen: Callable,
content_body: dict,
):
create_func = create_func_gen(client)
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

num_requests = 10
truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] *
(num_requests - num_requests // 2))
random.shuffle(truncate_prompt_tokens)

bodies = [{
**body, "extra_body": {
'truncate_prompt_tokens': t
}
} for t in truncate_prompt_tokens]

async def get_status_code(**kwargs):
try:
await create_func(**kwargs)
return 200
except openai.APIStatusError as e:
return e.status_code

responses = await asyncio.gather(*[get_status_code(**b) for b in bodies])
assert 500 not in responses


@pytest.mark.asyncio
@pytest.mark.parametrize(
ids=["single completion", "multiple completions", "chat"],
argnames=["create_func_gen", "content_body"],
argvalues=[
(lambda x: x.completions.create, {
"prompt": " ".join(['A'] * 300_000)
}),
(lambda x: x.completions.create, {
"prompt": [" ".join(['A'] * 300_000)] * 2
}),
(lambda x: x.chat.completions.create, {
"messages": [{
"role": "user",
"content": " ".join(['A'] * 300_000)
}]
}),
],
)
async def test_healthcheck_response_time(
Contributor Author: This test fails on main and passes on this PR. It validates that the fix actually works and that the server event loop is not blocked on tokenization. (A standalone timing sketch follows this file's diff.)

Contributor Author: The embedding endpoint is not tested, since that would require launching the server again with --task embedding and I wanted to save some time. Such a test would be redundant anyway, since the embedding endpoint uses the same tokenization code paths as completions and chat.

server: RemoteOpenAIServer,
client: openai.AsyncOpenAI,
create_func_gen: Callable,
content_body: dict,
):
num_requests = 50

create_func = create_func_gen(client)
body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

def get_response_time(url):
start_time = time.monotonic()
res = requests.get(url)
end_time = time.monotonic()
assert res.status_code == 200
return end_time - start_time

no_load_response_time = get_response_time(server.url_for("health"))
tasks = [
asyncio.create_task(create_func(**body)) for _ in range(num_requests)
]
await asyncio.sleep(1) # give the tasks a chance to start running
load_response_time = get_response_time(server.url_for("health"))

with contextlib.suppress(openai.APIStatusError):
await asyncio.gather(*tasks)

assert load_response_time < 100 * no_load_response_time
assert load_response_time < 0.1
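The author's note above this test explains what it validates: the health endpoint must stay fast while very long prompts are being tokenized. The snippet below is a standalone, hedged illustration of that effect in plain asyncio, with no vLLM involved; blocking_work and delayed_start are invented names. It measures how long a concurrently scheduled task waits when blocking work runs directly on the event loop versus in an executor thread.

import asyncio
import time


def blocking_work() -> None:
    time.sleep(1)  # stands in for tokenizing a very long prompt


async def delayed_start(created_at: float) -> float:
    # How long after task creation did the event loop actually run this coroutine?
    return time.monotonic() - created_at


async def main() -> None:
    # Blocking the loop directly: the probe task cannot start until the
    # synchronous call returns, so it observes roughly a one-second delay.
    t0 = time.monotonic()
    probe = asyncio.create_task(delayed_start(t0))
    blocking_work()
    delay = await probe
    print(f"blocking on the loop:  probe started after {delay:.2f}s")

    # Offloading to a thread: the loop stays free and the probe starts at once.
    t0 = time.monotonic()
    probe = asyncio.create_task(delayed_start(t0))
    await asyncio.get_running_loop().run_in_executor(None, blocking_work)
    delay = await probe
    print(f"offloaded to executor: probe started after {delay:.2f}s")


asyncio.run(main())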
vllm/entrypoints/openai/serving_completion.py (1 addition, 1 deletion)
@@ -101,7 +101,7 @@ async def create_completion(

tokenizer = await self.engine_client.get_tokenizer(lora_request)

request_prompts, engine_prompts = self._preprocess_completion(
request_prompts, engine_prompts = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
vllm/entrypoints/openai/serving_embedding.py (8 additions, 7 deletions)
@@ -156,13 +156,14 @@ async def create_embedding(
add_special_tokens=request.add_special_tokens,
)
else:
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
vllm/entrypoints/openai/serving_engine.py (40 additions, 35 deletions)
@@ -1,5 +1,6 @@
import json
import pathlib
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
@@ -46,7 +47,7 @@
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of
from vllm.utils import AtomicCounter, is_list_of, make_async

logger = init_logger(__name__)

@@ -140,6 +141,14 @@ def __init__(
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids

self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

self._tokenize_prompt_input_async = make_async(
self._tokenize_prompt_input, executor=self._tokenizer_executor)
self._tokenize_prompt_input_or_inputs_async = make_async(
self._tokenize_prompt_input_or_inputs,
executor=self._tokenizer_executor)

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
@@ -368,53 +377,49 @@ def _tokenize_prompt_input_or_inputs(
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
) -> List[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.

According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
return [
self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens)
if prompt_input["is_tokens"] is False else
self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens)
for prompt_input in parse_and_batch_prompt(input_or_inputs)
]

def _preprocess_completion(
async def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
request_prompts = [
request_prompt
for request_prompt in self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
]
) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
request_prompts = await self._tokenize_prompt_input_or_inputs_async(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)

engine_prompts = [
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
@@ -493,7 +498,7 @@ async def _preprocess_chat(
request=request)

if isinstance(request_prompt, str):
prompt_inputs = self._tokenize_prompt_input(
prompt_inputs = await self._tokenize_prompt_input_async(
request,
tokenizer,
request_prompt,
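The serving_engine.py changes above create a ThreadPoolExecutor with max_workers=1 and route all tokenization through it via make_async. The author's comment on the first test notes that the tokenizer is not thread-safe, so the single worker is what keeps concurrent requests from interleaving tokenizer calls. Below is a hedged, self-contained illustration of that property; FakeTokenizer is invented for the demo and is not the vLLM tokenizer.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional


class FakeTokenizer:
    """Stand-in for a tokenizer that is not thread-safe: the truncation
    length is stored on the instance between the two steps of __call__."""

    def __init__(self) -> None:
        self.max_length: Optional[int] = None

    def __call__(self, text: str, max_length: Optional[int] = None) -> List[str]:
        self.max_length = max_length  # shared mutable state
        tokens = text.split()
        # If another thread changed self.max_length between the assignment
        # above and this read, the wrong truncation setting would be applied.
        return tokens[:self.max_length] if self.max_length else tokens


tokenizer = FakeTokenizer()
# One worker means the executor can never run two tokenizer calls at once,
# so per-call settings cannot bleed between concurrent requests.
executor = ThreadPoolExecutor(max_workers=1)


async def tokenize(text: str, max_length: Optional[int]) -> List[str]:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, tokenizer, text, max_length)


async def main() -> None:
    results = await asyncio.gather(
        tokenize("A A A A A", 2),
        tokenize("B B B B B", None),
    )
    print(results)  # [['A', 'A'], ['B', 'B', 'B', 'B', 'B']]


asyncio.run(main())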
vllm/entrypoints/openai/serving_score.py (6 additions, 4 deletions)
@@ -15,7 +15,7 @@
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import merge_async_iterators, random_uuid
from vllm.utils import make_async, merge_async_iterators, random_uuid

logger = init_logger(__name__)

@@ -145,9 +145,11 @@ async def create_score(
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens

prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
vllm/entrypoints/openai/serving_tokenization.py (8 additions, 7 deletions)
@@ -81,12 +81,13 @@ async def create_tokenize(
add_special_tokens=request.add_special_tokens,
)
else:
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@@ -134,7 +135,7 @@ async def create_detokenize(
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)

prompt_input = self._tokenize_prompt_input(
prompt_input = await self._tokenize_prompt_input_async(
request,
tokenizer,
request.tokens,
vllm/utils.py (6 additions, 2 deletions)
@@ -1,5 +1,6 @@
import argparse
import asyncio
import concurrent
import contextlib
import datetime
import enum
@@ -351,7 +352,10 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower()


def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
def make_async(
func: Callable[P, T],
executor: Optional[concurrent.futures.Executor] = None
) -> Callable[P, Awaitable[T]]:
"""Take a blocking function, and run it on in an executor thread.

This function prevents the blocking function from blocking the
@@ -362,7 +366,7 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
loop = asyncio.get_event_loop()
p_func = partial(func, *args, **kwargs)
return loop.run_in_executor(executor=None, func=p_func)
return loop.run_in_executor(executor=executor, func=p_func)

return _async_wrapper

Expand Down