-
-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server #10635
[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server #10635
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
LGTM - as a sanity check, can you run a quick ShareGPT benchmark on a small model? If you need instructions for this, I can share commands. |
…ver by using threadpool for tokenization Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
5162f66
to
dd01b53
Compare
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
Sure. Here are the results Model: Qwen/Qwen2.5-1.5B-Instruct
So results are pretty much identical. Maybe a bit better with the PR but I guess the diffs are too small to be meaningful. Other than that, I think the CI failures are not related to the changes in this PR |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @tomeras91.
Another thought is that it's likely optimal to only dispatch to another thread if the length of text / number of tokens is above some threshold, otherwise it may be slightly detrimental.
But like @robertgshaw2-neuralmagic said, we should make sure to benchmark this anyhow.
self._tokenize_prompt_inputs_async = make_async( | ||
self._tokenize_prompt_inputs) | ||
self._tokenize_prompt_input_or_inputs_async = make_async( | ||
self._tokenize_prompt_input_or_inputs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think these will work as intended. The methods return a generator which will be done asynchronously but the actual work done to generate the outputs will still be done on the asyncio event loop while iterating.
We'll probably need to think of another way to arrange this, possibly we can change these methods to just return lists rather than generators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's a good catch. It really didn't work for /v1/completions which uses these methods. It did work for /v1/chat/completions that uses _tokenize_prompt_input
which doesn't return a generator and actually does the tokenization work.
Anyway, fixed by making _tokenize_prompt_input_or_inputs
return a List
as you suggested.
Will post updated benchmarks and something that shows this works shortly
@tomeras91 presumably you'll also test with your >200k token workload to ensure that it helps? |
…in self._tokenize_prompt_input Signed-off-by: Tomer Asida <[email protected]>
…c will actually make execution run in thread and not just generator creation Signed-off-by: Tomer Asida <[email protected]>
Updated benchmarksNow both on Model: Qwen/Qwen2.5-1.5B-Instruct Endpoint:
Conclusion: Pretty much the same (very slightly better performance in PR even though more tokens were generated) Endpoint:
Conclusion: Pretty much the same (slightly worse performance in PR, maybe due to more generated tokens) Anyway, I would expect this PR to affect TTFT much more than it affects TPOT, but it seems like a very insignificant change regardless. If at all, the effect is more significant for the completions endpoint, and is an improvement. RE something to show this actually solves the problem:Here is the code I used: import requests
import threading
from threading import Thread
import time
def send_request(url, body):
# print(f"Sending from thread {threading.get_ident()}")
start = time.time()
res = requests.post(url, json=body)
end = time.time()
# print(f"got res in thread {threading.get_ident()} after {1000*(end-start):.2f} ms with status {res.status_code}. content: {res.json() if res.status_code!=200 else 'success'}")
return res
model = "Qwen/Qwen2.5-1.5B-Instruct"
chat_url = "http://localhost:8000/v1/chat/completions"
chat_body = {'model': model,
'messages': [{'role': 'user', 'content': ' '.join(['A']*300_000)}],
}
completions_url = "http://localhost:8000/v1/completions"
completions_body = {'model': model,
'prompt': ' '.join(['A']*300_000),
}
completions_body_mult = {'model': model,
'prompt': [' '.join(['A']*300_000)]*2,
}
for _ in range(60):
# Thread(target=send_request, kwargs={'url': completions_url, 'body': completions_body}).start()
Thread(target=send_request, kwargs={'url': completions_url, 'body': completions_body_mult}).start()
# Thread(target=send_request, kwargs={'url': chat_url, 'body': chat_body}).start()
start = time.time()
res = requests.get("http://localhost:8000/health")
end = time.time()
print(f"######################### healthcheck took {end-start} secs #########################") The idea is to send many long context requests in close succession (using multiple threads), and then check how long The response time of
The response time is much shorter in the PR. We do see that the healthcheck takes a bit longer to respond under this long context requests load, which can be attributed to the fact that tokenizing many such long requests does take a bit CPU |
Thanks @tomeras91, it looks great now. One minor concern I realized is that the HF tokenizers technically aren't thread-safe. I think in practice this is only the case if things like truncation and/or padding are being done, but that can be true for us (if truncate parameter is specified in the request). Despite this it might still not be a problem given that I don't think the GIL is really released. But I'm not sure whether we should put an explicit lock around the calls to the tokenizer or create a tokenizer per thread. The former would likely be less invasive and could be done by wrapping/replacing the |
Signed-off-by: Tomer Asida <[email protected]>
Signed-off-by: Tomer Asida <[email protected]>
@njhill - That's another great catch. Without putting an explicit lock (commit b61a04f), sending multiple requests in parallel, some with Fix - explicit lockAs you suggested, I fixed this by wrapping the tokenizer's BenchmarksI ran the same sort of benchmarks as in #10635 (comment) Model: Qwen/Qwen2.5-1.5B-Instruct Endpoint:
Endpoint:
Conclusion: I think the differences are still within the randomness of the benchmarking script, although we do see a slight increase in TPOT.
So even better than before in the buggy non-threadsafe version |
Thanks @tomeras91 this looks great. Unfortunately it looks like there's one more problem (seen via failing CI test)... The tokenizer gets included in logitsprocessors used for guided decoding. We currently initialize these in the front-end process and they are pickled and sent to the back-end .. which is a bit crazy but needed to avoid their creation blocking the GIL in the main proc (unpickling them is faster). There are plans to change this but in the meantime this change breaks things because the lock can't be picked. Ideas to fix this:
I think I like (2) the least ... just makes this workaround more complex/convoluted. |
Yeah I saw the failing CI but didn't get to fixing it yet. I think I'll try your idea (3) as well. It makes a lot of sense since, because of the required lock, there's no real gain in running the tokenizer in multiple threads and it will remove the need for the lock. I also want to lock some of these behaviors (short healthcheck response time under heavy long context load, no failing with mix of truncated and not truncated requests) in unittests. Still don't know how hard will that be but given there are many tests that use an OpenAI server I think it shouldn't be too hard |
yep. Making the tokenization async with a |
… No need for threadsafe tokenizer anymore since all tokenization happens on the same thread Signed-off-by: Tomer Asida <[email protected]>
…an be sent concurrently and (2) that /health response time is short under high tokenization load Signed-off-by: Tomer Asida <[email protected]>
806d0ee
to
b35a063
Compare
Signed-off-by: Tomer Asida <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@njhill - Now we use a threadpool with max_workers=1.
Also added tests. I put them in a new file because I couldn't find a better place where they belong. If you want me to move them elsewhere, that's not a problem of course
}), | ||
], | ||
) | ||
async def test_with_and_without_truncate( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails if we use a threadpool with max_workers>1 for tokenization because the tokenizer is not threadsafe. It passes both on main
and on the PR
}), | ||
], | ||
) | ||
async def test_healthcheck_response_time( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This fails on main
and passes on the PR. It validates the fix actually works and the server event loop isn't blocked on tokenization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The embedding endpoint is not tested since it requires launching the server again with --task embedding
and I wanted to save some time. I think that test will be redundant anyway, since the embedding endpoint uses the same tokenization code paths as completions and chat
Updated (hopefully last) BenchmarksI ran the same sort of benchmarks as in #10635 (comment) Model: Qwen/Qwen2.5-1.5B-Instruct Endpoint:
Endpoint:
same conclusions - there's no hit to performance
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @tomeras91!
Currently, tokenization of requests in the OpenAI compatible server is done synchronously on the process running the server. Usually, this is not a problem, but when dealing with long sequence lengths (256K), tokenization can take a few hundred milliseconds. Combining long requests with high loads causes the server to be non-responsive as it spends a lot of time tokenizing requests synchronously and sequentially, blocking the asyncio event loop.
This PR aims to unblock the event loop by using a thread pool to tokenize requests. This following #3512 which also introduced parallel async tokenization but was closed and this feature was never delivered.