BF16 support for Quant-LLM kernel (#1147)
* Add FP6 benchmark option to use BF16

* Change dequant bit-shifting logic for BF16

* Modify dequant + tensor core ops for bf16

* Template progress

* Modify fpx quant logic to include bf16

* Add tests for FP6 BF16

* Use type punning for large exponent multiplication

* Fix some TODOs

* Remove option to add exponent bias directly to the exponent bits

This approach is (much) slower than multiplying by 2^bias after the fact, which is why the option was removed (see the exponent-bias sketch after this list)

* Reformat

* Cleanup

* Fix alignment

* Remove templated input type whenever possible

* Remove templated input type whenever possible 2

* Remove templated input type whenever possible 3

* Less hacky way to construct a float with a large exponent

* Use rtol=1e-2 instead of 1e-3 for the bfloat16 test

* Guards for SM75

* Remove redundant `__CUDA_ARCH__` guards in host code

Any check for `__CUDA_ARCH__` in `fp6_linear.cu` will always fail, because all of the functions in `fp6_linear.cu` are host functions and `__CUDA_ARCH__` is only defined when compiling device code

* Fix consistency in checking for `__CUDA_ARCH__` versions

* Update docs

* Make float bias a constexpr

* Update docs more

* Fix SM75 support

* Compile guard for sm<75

* Check for CUDA synchronous errors after kernel launch

If this is not done, the kernel can fail silently while the program keeps running, leading to unexpected behavior (see the error-checking sketch after this list)

* Updated compile guard

* Fix problematic usage of `__CUDA_ARCH__`

There are currently several ways of using `__CUDA_ARCH__` that lead to undefined behavior. See https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch for details of how `__CUDA_ARCH__` should not be used; a short `__CUDA_ARCH__` sketch follows this list

* Fix incorrect CUDA error handling

* Make the kernel fail for sm75 + bfloat16 inputs
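
As background for the exponent-handling commits above ("Use type punning for large exponent multiplication", "Remove option to add exponent bias directly to the exponent bits", "Less hacky way to construct a float with a large exponent"), here is a minimal sketch of the general idea, not the kernel's actual code. It assumes the usual E3M2 exponent bias of 3; `EXP_BIAS_OFFSET`, `exp2_bias`, and `scale_after_dequant` are illustrative names that do not appear in this PR.

```cuda
#include <cuda_bf16.h>

// FP6 E3M2 stores exponents with bias 3, while FP32/BF16 use bias 127, so FP6
// bits copied straight into the wider exponent field come out 2^(127 - 3) = 2^124
// times too small. Per the commit message above, patching the bias into each
// element's exponent bits is much slower than one multiply by 2^124 afterwards.
constexpr int EXP_BIAS_OFFSET = 124;  // illustrative constant

__device__ __forceinline__ float exp2_bias() {
    // Type punning: build 2^124 directly from its IEEE-754 bit pattern
    // (sign 0, exponent field 127 + 124, mantissa 0) instead of computing it.
    constexpr unsigned int bits = static_cast<unsigned int>(127 + EXP_BIAS_OFFSET) << 23;
    return __uint_as_float(bits);
}

__device__ __forceinline__ __nv_bfloat16 scale_after_dequant(__nv_bfloat16 v) {
    // Apply the correction in fp32, then round back to bf16.
    return __float2bfloat16(__bfloat162float(v) * exp2_bias());
}
```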
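
The "Check for CUDA synchronous errors after kernel launch" commit refers to the standard pattern sketched below; this is a generic illustration rather than the exact code added to `fp6_linear.cu`, and `toy_kernel` is a stand-in for the Quant-LLM kernel.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void toy_kernel(float* out) {  // stand-in for the real kernel
    out[threadIdx.x] = static_cast<float>(threadIdx.x);
}

int main() {
    float* d_out = nullptr;
    cudaMalloc(&d_out, 32 * sizeof(float));

    toy_kernel<<<1, 32>>>(d_out);

    // A kernel launch returns immediately. Synchronous launch errors (e.g. an
    // invalid configuration) are only reported through cudaGetLastError();
    // without this check a failed launch goes unnoticed and the program keeps
    // running with stale output.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        std::fprintf(stderr, "launch failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    // Errors raised while the kernel executes surface later, at the next
    // synchronizing call.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        std::fprintf(stderr, "kernel failed: %s\n", cudaGetErrorString(err));
        return 1;
    }

    cudaFree(d_out);
    return 0;
}
```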
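
For the `__CUDA_ARCH__` commits ("Remove redundant `__CUDA_ARCH__` guards in host code" and "Fix problematic usage of `__CUDA_ARCH__`"), a generic illustration of the pitfall the linked NVIDIA docs describe: the macro is only defined while device code is being compiled, so it cannot guard host-only code. `device_is_sm75_or_newer` and `dequant_impl` are hypothetical names, not functions from this PR.

```cuda
#include <cuda_runtime.h>

// In a host-only file such as fp6_linear.cu, __CUDA_ARCH__ is never defined,
// so a guard like the commented one below silently takes the "else" path:
//
//   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750   // always false in host code
//   ...
//   #endif
//
// Architecture checks in host code have to happen at runtime instead, e.g.:
static bool device_is_sm75_or_newer(int device = 0) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major > 7 || (prop.major == 7 && prop.minor >= 5);
}

// Inside device code the macro is fine, as long as its use does not change
// declarations or signatures between the host and device compilation passes
// (one of the restrictions the CUDA programming guide lists).
__device__ __forceinline__ float dequant_impl(float x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return x * 2.0f;   // e.g. a path that relies on Ampere+ features
#else
    return x + x;      // fallback for older architectures
#endif
}
```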
tobiasvanderwerff authored Nov 4, 2024
1 parent f99b667 commit 8c07d22
Showing 15 changed files with 258 additions and 153 deletions.
37 changes: 25 additions & 12 deletions benchmarks/benchmark_fp6.py
@@ -8,29 +8,42 @@


 def benchmark(m: int, k: int, n: int):
-    float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2))
-    fp16_weight = fp6_weight.dequantize(torch.half)
-
-    fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
-    fp6_output = F.linear(fp16_act, fp6_weight)
+    float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda")
+    float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
+    fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2))
+    fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2))
+    fp16_weight = fp6_weight_fp16.dequantize(torch.float16)
+    bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16)
+
+    fp16_act = torch.randn(m, k, dtype=torch.float16, device="cuda")
+    bf16_act = fp16_act.to(torch.bfloat16)
+    fp6_output_fp16 = F.linear(fp16_act, fp6_weight_fp16)
+    fp6_output_bf16 = F.linear(bf16_act, fp6_weight_bf16)
     fp16_output = F.linear(fp16_act, fp16_weight)
+    bf16_output = F.linear(bf16_act, bf16_weight)

-    fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight)
     fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
+    bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight)
+    fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16)
+    fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16)

     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness
-    correct = (fp6_output - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
+    correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
+    correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2

     return {
         "m": m,
         "k": k,
         "n": n,
-        "fp6_latency (ms)": fp6_time,
-        "fp16_latency (ms)": fp16_time,
-        "speedup (d/s)": fp16_time / fp6_time,
-        "correct": correct,
+        "fp6-fp16 latency (ms)": fp6_time_fp16,
+        "fp16 latency (ms)": fp16_time,
+        "speedup fp16": fp16_time / fp6_time_fp16,
+        "correct fp16": correct_fp16,
+        "fp6-bf16 latency (ms)": fp6_time_bf16,
+        "bf16 latency (ms)": bf16_time,
+        "speedup bf16": bf16_time / fp6_time_bf16,
+        "correct bf16": correct_bf16,
     }


Expand Down
7 changes: 4 additions & 3 deletions test/dtypes/test_floatx.py
@@ -91,16 +91,17 @@ def test_to_copy_device(self, ebits, mbits):
     @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     @parametrize("bias", [False, True])
+    @parametrize("dtype", [torch.half, torch.bfloat16])
     @pytest.mark.skipif(is_fbcode(), reason="broken in fbcode")
-    def test_fpx_weight_only(self, ebits, mbits, bias):
+    def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
         device = "cuda"

-        linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half)
+        linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
         fpx_linear = copy.deepcopy(linear)
         quantize_(fpx_linear, fpx_weight_only(ebits, mbits))

-        x = torch.randn(N, IC, device=device, dtype=torch.half)
+        x = torch.randn(N, IC, device=device, dtype=dtype)
         expected = fpx_linear(x)
         actual = torch.compile(fpx_linear, fullgraph=True)(x)
         # somehow compile now changes the result a bit
21 changes: 12 additions & 9 deletions test/test_ops.py
@@ -33,22 +33,23 @@


 class TestOps(TestCase):
-    def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
+    def _create_floatx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device, dtype):
         # Randomly initialize each byte
         nbits = 1 + ebits + mbits
         floatx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8)
-        scale = torch.rand(OC).half() + 0.5
-        fp16_act = torch.rand(BS, IC).half() + 0.5
+        scale = torch.rand(OC).to(dtype) + 0.5
+        fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
         return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    def test_quant_llm_linear(self, ebits, mbits):
+    @parametrize("dtype", [torch.half, torch.bfloat16])
+    def test_quant_llm_linear(self, ebits, mbits, dtype):
         BS = 2
         OC = 256
         IC = 256
         splitK = 1
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype)

         # smoke test
         torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)
@@ -60,19 +61,21 @@ def test_quant_llm_linear(self, ebits, mbits):
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
-    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
+    @parametrize("dtype", [torch.half, torch.bfloat16])
+    def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK, dtype):
         # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py
-        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda")
+        floatx_weight, scale, fp16_act = self._create_floatx_inputs(ebits, mbits, BS, OC, IC, "cuda", dtype)

         results_floatx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, floatx_weight, scale, splitK)

-        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).half()
+        fp16_weight = from_scaled_tc_floatx(floatx_weight, ebits, mbits, scale).to(dtype)
         results_fp16 = fp16_act @ fp16_weight.T

         error = (results_floatx - results_fp16).abs().mean()
         gt = results_fp16.abs().mean()
         relative_error = error / gt
-        assert relative_error < 1e-3
+        rtol = 1e-2 if dtype == torch.bfloat16 else 1e-3
+        assert relative_error < rtol

 instantiate_parametrized_tests(TestOps)

4 changes: 2 additions & 2 deletions torchao/csrc/cuda/fp6_llm/README.md
@@ -1,7 +1,7 @@
 # FP6-LLM kernel

-This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN).
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).

 On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.

-See https://github.com/pytorch/ao/pull/223 for some benchmark results.
+See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results.
