[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server #10635

Merged: 14 commits merged into vllm-project:main on Nov 27, 2024

Conversation

tomeras91
Contributor

@tomeras91 tomeras91 commented Nov 25, 2024

Currently, tokenization of requests in the OpenAI compatible server is done synchronously on the process running the server. Usually, this is not a problem, but when dealing with long sequence lengths (256K), tokenization can take a few hundred milliseconds. Combining long requests with high loads causes the server to be non-responsive as it spends a lot of time tokenizing requests synchronously and sequentially, blocking the asyncio event loop.

This PR aims to unblock the event loop by using a thread pool to tokenize requests. It follows #3512, which also introduced parallel async tokenization but was closed, so the feature was never delivered.
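For context, a minimal sketch of the general pattern (not the actual vLLM code; the tokenizer, helper names, and executor setup here are illustrative):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

# Illustrative: a dedicated executor for CPU-bound tokenization work.
_tokenizer_executor = ThreadPoolExecutor(max_workers=2)

def tokenize_sync(tokenizer, prompt: str) -> list:
    # Blocking, CPU-bound; can take hundreds of milliseconds for ~256K-token prompts.
    return tokenizer.encode(prompt)

async def tokenize_async(tokenizer, prompt: str) -> list:
    loop = asyncio.get_running_loop()
    # Hand the blocking call to a worker thread; the event loop keeps
    # serving other requests (including /health) in the meantime.
    return await loop.run_in_executor(
        _tokenizer_executor, partial(tokenize_sync, tokenizer, prompt))
```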


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the frontend label Nov 25, 2024
@tomeras91 tomeras91 changed the title [Frontend] don't block GIL in tokenization (preprocess) in OpenAI compatible server [Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server Nov 25, 2024
@robertgshaw2-neuralmagic
Collaborator

robertgshaw2-neuralmagic commented Nov 25, 2024

LGTM - as a sanity check, can you run a quick ShareGPT benchmark on a small model? If you need instructions for this, I can share commands.

…ver by using threadpool for tokenization

Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
@tomeras91 tomeras91 force-pushed the async-tokenization-in-oai-server branch from 5162f66 to dd01b53 on November 25, 2024 13:07
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
@tomeras91
Contributor Author

> LGTM - as a sanity check, can you run a quick ShareGPT benchmark on a small model? If you need instructions for this, I can share commands.

Sure. Here are the results

Model: Qwen/Qwen2.5-1.5B-Instruct
Hardware: single H100

| branch | throughput (requests/s) | mean TTFT (msec) |
|---|---|---|
| main (commit ed46f14) | 51.33 | 6645.88 |
| pr (commit 980fff8) | 51.62 | 6251.77 |

So the results are pretty much identical. The PR may be slightly better, but the differences are too small to be meaningful.

Other than that, I think the CI failures are unrelated to the changes in this PR.

Member

@njhill njhill left a comment


Thanks @tomeras91.

Another thought is that it's likely optimal to only dispatch to another thread if the length of text / number of tokens is above some threshold, otherwise it may be slightly detrimental.

But like @robertgshaw2-neuralmagic said, we should make sure to benchmark this anyhow.
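A rough sketch of the length-threshold idea mentioned above (the cutoff value and names are made up and would need benchmarking; it reuses the tokenize_async sketch from earlier):

```python
# Illustrative: only pay the thread-hop overhead for long prompts.
LONG_PROMPT_THRESHOLD_CHARS = 10_000  # made-up value; tune via benchmarks

async def maybe_async_tokenize(tokenizer, prompt: str) -> list:
    if len(prompt) < LONG_PROMPT_THRESHOLD_CHARS:
        # Short prompts: tokenizing inline is cheaper than dispatching to a thread.
        return tokenizer.encode(prompt)
    return await tokenize_async(tokenizer, prompt)
```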

Comment on lines 145 to 148
```python
self._tokenize_prompt_inputs_async = make_async(
    self._tokenize_prompt_inputs)
self._tokenize_prompt_input_or_inputs_async = make_async(
    self._tokenize_prompt_input_or_inputs)
```
Member

I don't think these will work as intended. The methods return a generator which will be done asynchronously but the actual work done to generate the outputs will still be done on the asyncio event loop while iterating.

We'll probably need to think of another way to arrange this, possibly we can change these methods to just return lists rather than generators.
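A standalone illustration of the pitfall described above (made-up names; not the vLLM methods themselves):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor()

def tokenize_gen(tokenizer, prompts):
    # Generator: nothing is tokenized until the caller iterates.
    for p in prompts:
        yield tokenizer.encode(p)

async def offloaded_but_still_blocking(tokenizer, prompts):
    loop = asyncio.get_running_loop()
    # Only creating the generator object runs in the worker thread (cheap);
    # every encode() call still runs on the event loop during iteration below.
    gen = await loop.run_in_executor(executor, tokenize_gen, tokenizer, prompts)
    return list(gen)

async def offloaded_for_real(tokenizer, prompts):
    loop = asyncio.get_running_loop()
    # Returning a list forces all encode() calls to happen inside the worker thread.
    return await loop.run_in_executor(
        executor, lambda: [tokenizer.encode(p) for p in prompts])
```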

Contributor Author

That's a good catch. It really didn't work for /v1/completions, which uses these methods. It did work for /v1/chat/completions, which uses _tokenize_prompt_input; that method doesn't return a generator and actually does the tokenization work.
Anyway, fixed by making _tokenize_prompt_input_or_inputs return a List as you suggested.

Will post updated benchmarks and something that shows this works shortly

@njhill
Member

njhill commented Nov 25, 2024

@tomeras91 presumably you'll also test with your >200k token workload to ensure that it helps?

…in self._tokenize_prompt_input

Signed-off-by: Tomer Asida <[email protected]>
…c will actually make execution run in thread and not just generator creation

Signed-off-by: Tomer Asida <[email protected]>
@tomeras91
Contributor Author

tomeras91 commented Nov 25, 2024

Updated benchmarks

Now both on /v1/completions and on /v1/chat/completions

Model: Qwen/Qwen2.5-1.5B-Instruct
Hardware: single H100
Serve run command: vllm serve Qwen/Qwen2.5-1.5B-Instruct

Endpoint: v1/completions
(benchmark run command: python3 benchmark_serving.py --model $MODEL --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json)

| branch | throughput (requests/s) | mean TTFT (msec) | mean TPOT (msec) | Total generated tokens |
|---|---|---|---|---|
| main (commit ed46f14) | 52.88 | 6406.30 | 23.09 | 189265 |
| pr (commit b61a04f) | 52.92 | 5985.69 | 21.29 | 191136 |

Conclusion: Pretty much the same (very slightly better performance in PR even though more tokens were generated)

Endpoint: v1/chat/completions
(benchmark run command: python3 benchmark_serving.py --model $MODEL --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --endpoint /v1/chat/completions --backend openai-chat)

| branch | throughput (requests/s) | mean TTFT (msec) | mean TPOT (msec) | Total generated tokens |
|---|---|---|---|---|
| main (commit ed46f14) | 57.63 | 5946.73 | 23.73 | 171968 |
| pr (commit b61a04f) | 55.52 | 5972.38 | 23.50 | 172531 |

Conclusion: Pretty much the same (slightly worse performance in PR, maybe due to more generated tokens)

Anyway, I would expect this PR to affect TTFT much more than it affects TPOT, but it seems like a very insignificant change regardless. If at all, the effect is more significant for the completions endpoint, and is an improvement.

RE something to show this actually solves the problem:

Here is the code I used:

```python
import requests
import threading
from threading import Thread
import time

def send_request(url, body):
    # print(f"Sending from thread {threading.get_ident()}")
    start = time.time()
    res = requests.post(url, json=body)
    end = time.time()
    # print(f"got res in thread {threading.get_ident()} after {1000*(end-start):.2f} ms with status {res.status_code}. content: {res.json() if res.status_code!=200 else 'success'}")
    return res

model = "Qwen/Qwen2.5-1.5B-Instruct"

chat_url = "http://localhost:8000/v1/chat/completions"
chat_body = {'model': model,
             'messages': [{'role': 'user', 'content': ' '.join(['A']*300_000)}],
            }

completions_url = "http://localhost:8000/v1/completions"
completions_body = {'model': model,
                    'prompt': ' '.join(['A']*300_000),
                   }
completions_body_mult = {'model': model,
                         'prompt': [' '.join(['A']*300_000)]*2,
                        }

for _ in range(60):
    # Thread(target=send_request, kwargs={'url': completions_url, 'body': completions_body}).start()
    Thread(target=send_request, kwargs={'url': completions_url, 'body': completions_body_mult}).start()
    # Thread(target=send_request, kwargs={'url': chat_url, 'body': chat_body}).start()
start = time.time()
res = requests.get("http://localhost:8000/health")
end = time.time()
print(f"######################### healthcheck took {end-start} secs #########################")
```

The idea is to send many long-context requests in close succession (using multiple threads), and then check how long /health takes to respond. If the event loop is blocked, we would expect the healthcheck to take a long time to respond. As you can see, I checked both v1/completions and v1/chat/completions, and also multiple inputs for the completions endpoint. Note that these requests are longer than the supported max model length (32K), so the model does no actual work and the requests are rejected right after tokenization.

The response time of /health when the server is not dealing with any requests was about 2.3 msecs

| branch | endpoint | health response time (msec) |
|---|---|---|
| main (commit ed46f14) | v1/completions | 17848.1 |
| main (commit ed46f14) | v1/completions multiple inputs | 14285.4 |
| main (commit ed46f14) | v1/chat/completions | 17159.3 |
| pr (commit b61a04f) | v1/completions | 34.6 |
| pr (commit b61a04f) | v1/completions multiple inputs | 41.0 |
| pr (commit b61a04f) | v1/chat/completions | 16.4 |

The response time is much shorter with the PR. We do see that the healthcheck takes a bit longer to respond under this long-context request load, which can be attributed to the fact that tokenizing many such long requests still takes some CPU.

@njhill
Member

njhill commented Nov 26, 2024

Thanks @tomeras91, it looks great now. One minor concern I realized is that the HF tokenizers technically aren't thread-safe. I think in practice this is only the case if things like truncation and/or padding are being done, but that can be true for us (if truncate parameter is specified in the request).

Despite this it might still not be a problem given that I don't think the GIL is really released. But I'm not sure whether we should put an explicit lock around the calls to the tokenizer or create a tokenizer per thread. The former would likely be less invasive and could be done by wrapping/replacing the __call__ method when the tokenizer is created. It would probably be best to re-run the perf test after doing that.
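A minimal sketch of the lock-wrapping idea (illustrative only; the class below is not the PR's actual implementation):

```python
import threading

class ThreadsafeTokenizerWrapper:
    """Serialize access to a tokenizer whose truncation/padding state is
    mutable and therefore unsafe to share across threads."""

    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
        self._lock = threading.Lock()

    def __call__(self, *args, **kwargs):
        with self._lock:
            return self._tokenizer(*args, **kwargs)

    def encode(self, *args, **kwargs):
        with self._lock:
            return self._tokenizer.encode(*args, **kwargs)

    def __getattr__(self, name):
        # Delegate everything else (decode, vocab_size, ...) to the wrapped tokenizer.
        return getattr(self._tokenizer, name)
```

Note that holding a threading.Lock as instance state makes such a wrapper unpicklable, which becomes relevant further down this thread.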

@tomeras91
Contributor Author

tomeras91 commented Nov 26, 2024

> Thanks @tomeras91, it looks great now. One minor concern I realized is that the HF tokenizers technically aren't thread-safe. I think in practice this is only the case if things like truncation and/or padding are being done, but that can be true for us (if truncate parameter is specified in the request).
>
> Despite this it might still not be a problem given that I don't think the GIL is really released. But I'm not sure whether we should put an explicit lock around the calls to the tokenizer or create a tokenizer per thread. The former would likely be less invasive and could be done by wrapping/replacing the __call__ method when the tokenizer is created. It would probably be best to re-run the perf test after doing that.

@njhill - That's another great catch. Without an explicit lock (commit b61a04f), sending multiple requests in parallel, some with truncate_prompt_tokens and some without, causes error code 500 failures. I guess this is an issue similar to huggingface/tokenizers#537, involving the tokenizer's mutable truncation state, exactly as you pointed out. main (commit ed46f14) doesn't fail in that situation, so that's indeed a regression.

Fix - explicit lock

As you suggested, I fixed this by wrapping the tokenizer's __call__ with an explicit threading lock. This required a few more changes in the code - I'm not sure the choices I made are the best and would like some feedback - commit e4cb992. Basically, I added another flag to get_tokenizer called use_threadsafe and propagated it all the way to init_tokenizer_from_configs. When this flag is set to True, the tokenizer is wrapped as a ThreadsafeTokenizer, similar to the existing CachedTokenizer. The tokenizer used for encoding requests in the OpenAI compatible server is the one held in the MQLLMEngineClient, so now this client initializes a threadsafe tokenizer. I again sent multiple requests, some with truncation and some without, and validated that I got no 500 errors (I used code similar to what I posted in #10635 (comment), but edited the request body in each thread to include 'truncate_prompt_tokens': 1000 if random.random()>0.5 else None).

Benchmarks

I ran the same sort of benchmarks as in #10635 (comment)

Model: Qwen/Qwen2.5-1.5B-Instruct
Hardware: single H100
Serve run command: vllm serve Qwen/Qwen2.5-1.5B-Instruct

Endpoint: v1/completions
(benchmark run command: python3 benchmark_serving.py --model $MODEL --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json)

| branch | throughput (requests/s) | mean TTFT (msec) | mean TPOT (msec) | Total generated tokens |
|---|---|---|---|---|
| main (commit ed46f14) | 52.88 | 6406.30 | 23.09 | 189265 |
| pr (commit b61a04f) | 52.92 | 5985.69 | 21.29 | 191136 |
| pr threadsafe (commit e59cc81) | 52.75 | 6090.83 | 24.06 | 190028 |

Endpoint: v1/chat/completions
(benchmark run command: python3 benchmark_serving.py --model $MODEL --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --endpoint /v1/chat/completions --backend openai-chat)

| branch | throughput (requests/s) | mean TTFT (msec) | mean TPOT (msec) | Total generated tokens |
|---|---|---|---|---|
| main (commit ed46f14) | 57.63 | 5946.73 | 23.73 | 171968 |
| pr (commit b61a04f) | 55.52 | 5972.38 | 23.50 | 172531 |
| pr threadsafe (commit e59cc81) | 55.83 | 5782.09 | 25.14 | 171961 |

Conclusion: I think the differences are still within the randomness of the benchmarking script, although we do see a slight increase in TPOT.

/health response times:
The response time of /health when the server is not dealing with any requests was about 2.3 msecs

| branch | endpoint | health response time (msec) |
|---|---|---|
| main (commit ed46f14) | v1/completions | 17848.1 |
| main (commit ed46f14) | v1/completions multiple inputs | 14285.4 |
| main (commit ed46f14) | v1/chat/completions | 17159.3 |
| pr (commit b61a04f) | v1/completions | 34.6 |
| pr (commit b61a04f) | v1/completions multiple inputs | 41.0 |
| pr (commit b61a04f) | v1/chat/completions | 16.4 |
| pr threadsafe (commit e59cc81) | v1/completions | 2.7 |
| pr threadsafe (commit e59cc81) | v1/completions multiple inputs | 3.0 |
| pr threadsafe (commit e59cc81) | v1/chat/completions | 4.8 |

So the healthcheck is even faster than with the earlier, buggy non-threadsafe version.

@njhill
Member

njhill commented Nov 26, 2024

Thanks @tomeras91 this looks great. Unfortunately it looks like there's one more problem (seen via failing CI test)...

The tokenizer gets included in the logits processors used for guided decoding. We currently initialize these in the front-end process and they are pickled and sent to the back-end... which is a bit crazy but needed to avoid their creation blocking the GIL in the main proc (unpickling them is faster). There are plans to change this, but in the meantime this change breaks things because the lock can't be pickled.

Ideas to fix this:

  1. Use (potentially separately constructed) non-threadsafe tokenizer for these logits processors
  2. Implement this in your threadsafe tokenizer class https://docs.python.org/3/library/pickle.html#handling-stateful-objects
  3. Another idea I thought of which might be simpler overall - use a custom ThreadPoolExecutor with max_workers=1 for the make_async calls. This would avoid needing the lock provided that all the tokenize calls for the tokenizer are done async

I think I like (2) the least ... just makes this workaround more complex/convoluted.
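For reference, the pickling failure is easy to reproduce: any object holding a threading.Lock cannot be pickled, which is why a lock-wrapped tokenizer baked into a logits processor breaks the pickle-and-ship step (toy example below, not vLLM code):

```python
import pickle
import threading

class HoldsALock:
    def __init__(self):
        self.lock = threading.Lock()

try:
    pickle.dumps(HoldsALock())
except TypeError as err:
    # TypeError: cannot pickle '_thread.lock' object
    print(err)
```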

@tomeras91
Contributor Author

tomeras91 commented Nov 26, 2024

Yeah I saw the failing CI but didn't get to fixing it yet.
I thought of fixing it by extracting the underlying non-threadsafe tokenizer before sending it to outlines. I saw there's already an "adapt tokenizer" function, so I can add a bit more logic there.

I think I'll try your idea (3) as well. It makes a lot of sense: because of the required lock, there's no real gain in running the tokenizer in multiple threads, and a single worker removes the need for the lock.

I also want to lock in some of these behaviors (short healthcheck response time under heavy long-context load, no failures with a mix of truncated and non-truncated requests) with unit tests. I still don't know how hard that will be, but given there are many tests that use an OpenAI server, I think it shouldn't be too hard.

@tomeras91
Contributor Author

Yep. Making the tokenization async with a ThreadPoolExecutor(max_workers=1) works perfectly 🙂
I'll update the PR and add new benchmarks and hopefully some unit tests.
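A sketch of option (3), assuming a make_async-style helper that accepts an explicit executor (the exact vLLM signature may differ):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

# A single worker thread: all tokenize calls are serialized on that thread,
# so no lock is needed, and the asyncio event loop is never blocked.
_TOKENIZER_EXECUTOR = ThreadPoolExecutor(max_workers=1)

def make_async(func, executor=None):
    """make_async-style helper: turn a blocking function into an awaitable one."""
    def _wrapper(*args, **kwargs):
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(executor, partial(func, *args, **kwargs))
    return _wrapper

# Hypothetical usage, mirroring the preprocessing methods discussed above:
# self._tokenize_prompt_input_async = make_async(
#     self._tokenize_prompt_input, executor=_TOKENIZER_EXECUTOR)
```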

… No need for threadsafe tokenizer anymore since all tokenization happens on the same thread

Signed-off-by: Tomer Asida <[email protected]>
…an be sent concurrently and (2) that /health response time is short under high tokenization load

Signed-off-by: Tomer Asida <[email protected]>
@tomeras91 tomeras91 force-pushed the async-tokenization-in-oai-server branch from 806d0ee to b35a063 on November 27, 2024 12:00
Signed-off-by: Tomer Asida <[email protected]>
Contributor Author

@tomeras91 tomeras91 left a comment


@njhill - Now we use a threadpool with max_workers=1.

Also added tests. I put them in a new file because I couldn't find a better place for them. If you want me to move them elsewhere, that's not a problem, of course.

async def test_with_and_without_truncate(
Contributor Author

This fails if we use a threadpool with max_workers>1 for tokenization because the tokenizer is not thread-safe. It passes both on main and on the PR.

async def test_healthcheck_response_time(
Contributor Author

This fails on main and passes on the PR. It validates that the fix actually works and that the server event loop isn't blocked on tokenization.
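A rough sketch of what such a test can look like (illustrative names and thresholds, not the PR's actual test code; it assumes pytest-asyncio and aiohttp plus a server already running locally):

```python
import asyncio
import time

import aiohttp
import pytest

BASE_URL = "http://localhost:8000"      # assumed local server address
MODEL = "Qwen/Qwen2.5-1.5B-Instruct"    # model used in the benchmarks above

async def _post_completion(session: aiohttp.ClientSession, prompt: str) -> int:
    async with session.post(f"{BASE_URL}/v1/completions",
                            json={"model": MODEL, "prompt": prompt}) as resp:
        await resp.read()
        return resp.status

@pytest.mark.asyncio
async def test_health_not_blocked_by_tokenization():
    long_prompt = " ".join(["A"] * 300_000)
    async with aiohttp.ClientSession() as session:
        # Fire many long-prompt requests without waiting for them to finish.
        pending = [asyncio.create_task(_post_completion(session, long_prompt))
                   for _ in range(20)]
        await asyncio.sleep(0.5)            # give tokenization time to start
        start = time.monotonic()
        async with session.get(f"{BASE_URL}/health") as resp:
            assert resp.status == 200
        # Generous bound: this should take milliseconds when the loop isn't blocked.
        assert time.monotonic() - start < 1.0
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
```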

Contributor Author

The embedding endpoint is not tested since it requires launching the server again with --task embedding and I wanted to save some time. I think that test would be redundant anyway, since the embedding endpoint uses the same tokenization code paths as completions and chat.

@tomeras91
Contributor Author

Updated (hopefully last) Benchmarks

I ran the same sort of benchmarks as in #10635 (comment)

Model: Qwen/Qwen2.5-1.5B-Instruct
Hardware: single H100
Serve run command: vllm serve Qwen/Qwen2.5-1.5B-Instruct

Endpoint: v1/completions
(benchmark run command: python3 benchmark_serving.py --model $MODEL --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json)

| branch | throughput (requests/s) | mean TTFT (msec) | mean TPOT (msec) | Total generated tokens |
|---|---|---|---|---|
| main (commit ed46f14) | 52.88 | 6406.30 | 23.09 | 189265 |
| pr (commit b61a04f) | 52.92 | 5985.69 | 21.29 | 191136 |
| pr threadsafe (commit e59cc81) | 52.75 | 6090.83 | 24.06 | 190028 |
| pr single thread threadpool (commit ff1d6a9) | 52.66 | 6092.80 | 23.57 | 190356 |

Endpoint: v1/chat/completions
(benchmark run command: python3 benchmark_serving.py --model $MODEL --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --endpoint /v1/chat/completions --backend openai-chat)

| branch | throughput (requests/s) | mean TTFT (msec) | mean TPOT (msec) | Total generated tokens |
|---|---|---|---|---|
| main (commit ed46f14) | 57.63 | 5946.73 | 23.73 | 171968 |
| pr (commit b61a04f) | 55.52 | 5972.38 | 23.50 | 172531 |
| pr threadsafe (commit e59cc81) | 55.83 | 5782.09 | 25.14 | 171961 |
| pr single thread threadpool (commit ff1d6a9) | 55.68 | 5851.25 | 24.94 | 171941 |

Same conclusions - there's no hit to performance.

/health response times are under ~5 ms. I also added a unit test to make sure the response time under high tokenization load is not too different from the response time under no load.

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 27, 2024
Member

@njhill njhill left a comment


Thanks @tomeras91!

@njhill njhill merged commit 395b1c7 into vllm-project:main Nov 27, 2024
57 checks passed
@tomeras91 tomeras91 deleted the async-tokenization-in-oai-server branch November 28, 2024 15:00
Labels: frontend, ready