BF16 support for Quant-LLM kernel (#1147)
* Add FP6 benchmark option to use BF16
* Change dequant bit-shifting logic for BF16
* Modify dequant + tensor core ops for BF16
* Template progress
* Modify fpx quant logic to include BF16
* Add tests for FP6 BF16
* Use type punning for large-exponent multiplication (sketched below)
* Fix some TODOs
* Remove the option to add the exponent bias directly to the exponent bits. That approach is (much) slower than multiplying by 2^bias after the fact, which is why it is not usable.
* Reformat
* Cleanup
* Fix alignment
* Remove templated input type whenever possible
* Remove templated input type whenever possible 2
* Remove templated input type whenever possible 3
* Less hacky way to construct a float with a large exponent
* rtol=1e-2 instead of 1e-3 for the bfloat16 test
* Guards for SM75
* Remove redundant `__CUDA_ARCH__` guards in host code. Any check for `__CUDA_ARCH__` in `fp6_linear.cu` will always fail because `__CUDA_ARCH__` is undefined there: all of the functions in `fp6_linear.cu` are host functions (sketched below).
* Fix consistency in checking for `CUDA_ARCH` versions
* Update docs
* Make the float bias a constexpr
* Update docs more
* Fix SM75 support
* Compile guard for sm < 75
* Check for CUDA synchronous errors after the kernel launch. If this is not done, the kernel may still run but fail silently, leading to unexpected behavior (sketched below).
* Updated compile guard
* Fix problematic usage of `__CUDA_ARCH__`. There are currently several ways of using `__CUDA_ARCH__` that lead to undefined behavior. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch for details of how `__CUDA_ARCH__` should not be used.
* Fix incorrect CUDA error handling
* Make the kernel fail for SM75 + bfloat16 inputs

1 parent f99b667 · commit 8c07d22
15 changed files with 258 additions and 153 deletions.
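The `__CUDA_ARCH__` items come down to one rule: the macro is only defined while compiling device code, so a guard on it inside a host function (as in `fp6_linear.cu`) never takes the intended branch. A minimal sketch of the distinction; the function names are illustrative and not from the actual sources:

```cuda
#include <cuda_runtime.h>

// Device code: __CUDA_ARCH__ is defined here, so it can guard features that
// only exist on newer architectures (e.g. bf16 tensor-core MMA needs sm_80+).
__device__ float device_helper(float x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return x;          // architecture-specific fast path would go here
#else
    return x * 1.0f;   // fallback path for older architectures
#endif
}

// Host code: __CUDA_ARCH__ is never defined, so "#if __CUDA_ARCH__ >= 750"
// would silently evaluate as false. Query the device at runtime instead.
bool device_is_at_least_sm75() {
    int device = 0, major = 0, minor = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    return major > 7 || (major == 7 && minor >= 5);
}
```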
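The synchronous-error item refers to the standard pattern of calling `cudaGetLastError()` immediately after a kernel launch. A minimal sketch under assumed names; the placeholder kernel and launcher below are not the actual Quant-LLM code:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Placeholder kernel; the real Quant-LLM GEMM kernel has a different signature.
__global__ void example_kernel(float* out) {
    out[threadIdx.x] = static_cast<float>(threadIdx.x);
}

void launch_and_check(float* d_out, cudaStream_t stream) {
    example_kernel<<<1, 32, 0, stream>>>(d_out);

    // cudaGetLastError() surfaces synchronous launch errors such as an invalid
    // configuration or a kernel image that cannot run on this GPU. Without
    // this check the launch can fail silently and later reads of d_out return
    // stale or garbage data.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(err));
    }
}
```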
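The type-punning and exponent-bias items describe scaling the dequantized value by 2^bias, with the power-of-two scale built directly from its IEEE-754 bit pattern rather than by patching exponent bits in place. The sketch below only illustrates that general technique; the helper names and the bias constant are assumptions, not the kernel's actual code:

```cuda
#include <cstdint>
#include <cstring>

// Build the float 2^k directly from its IEEE-754 bit pattern:
// sign = 0, biased exponent = k + 127, mantissa = 0.
constexpr uint32_t pow2_bits(int k) {
    return static_cast<uint32_t>(k + 127) << 23;
}

inline float pow2_float(int k) {
    const uint32_t bits = pow2_bits(k);
    float value;
    std::memcpy(&value, &bits, sizeof(value));  // well-defined type punning on the host
    return value;                               // device code could use __uint_as_float(bits)
}

// Illustrative rescale: after shifting the FP6 bits into a wider float layout,
// the exponent-bias difference is applied with a single multiply by a
// power-of-two scale instead of editing the exponent field in place.
// The bias argument here is a placeholder, not the kernel's actual constant.
inline float apply_exponent_bias(float dequantized, int bias_exponent) {
    return dequantized * pow2_float(bias_exponent);
}
```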
Diff of the FP6-LLM kernel README:

```diff
@@ -1,7 +1,7 @@
 # FP6-LLM kernel
 
-This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN).
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).
 
 On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.
 
-See https://github.com/pytorch/ao/pull/223 for some benchmark results.
+See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results.
```