Skip to content

Commit

Permalink
Fix : BitNet tests (#34895)
Browse files Browse the repository at this point in the history
* fix_tests_bitnet

* fix format
  • Loading branch information
MekkCyber authored Nov 25, 2024
1 parent 9121ab8 commit 4e6b19c
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions tests/quantization/bitnet_integration/test_bitnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ def test_replace_with_bitlinear(self):

self.assertEqual(nb_linears - 1, nb_bitnet_linear)

def test_quantized_model(self, quantized_model, tokenizer):
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_text = "What are we having for dinner?"
expected_output = "What are we having for dinner? What are we going to do for fun this weekend?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
input_ids = self.tokenizer(input_text, return_tensors="pt").to("cuda")

output = quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
output = self.quantized_model.generate(**input_ids, max_new_tokens=11, do_sample=False)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

def test_packing_unpacking(self):
"""
Expand All @@ -113,9 +113,12 @@ def test_packing_unpacking(self):

from transformers.integrations import pack_weights, unpack_weights

u = torch.randint(0, 255, (1024, 1024), dtype=torch.uint8)
u = torch.randint(0, 255, (256, 256), dtype=torch.uint8)
unpacked_u = unpack_weights(u, dtype=torch.bfloat16)
self.assertEqual(pack_weights(unpacked_u), u)
repacked_u = pack_weights(unpacked_u)
for i in range(u.shape[0]):
for j in range(u.shape[1]):
self.assertEqual(repacked_u[i][j], u[i][j])

def test_activation_quant(self):
"""
Expand All @@ -127,15 +130,14 @@ def test_activation_quant(self):
layer = BitLinear(in_features=4, out_features=2, bias=False, dtype=torch.float32)
layer.to(self.device)

input_tensor = torch.tensor([[1.0, -1.0, -1.0, 1.0], [1.0, -1.0, 1.0, 1.0]], dtype=torch.float32).to(
torch_device
)
input_tensor = torch.tensor([1.0, -1.0, -1.0, 1.0], dtype=torch.float32).to(torch_device)

# Quantize the input tensor
quantized_tensor, scale = layer.activation_quant(input_tensor)

# Verify the output quantized tensor
self.assertEqual(quantized_tensor, input_tensor)
for i in range(input_tensor.shape[0]):
self.assertEqual(quantized_tensor[i] / scale, input_tensor[i])

# Verify the scale tensor
self.assertEqual(scale, 127)
Expand Down

0 comments on commit 4e6b19c

Please sign in to comment.