diff --git a/README.md b/README.md index 2df27de..050fd16 100644 --- a/README.md +++ b/README.md @@ -77,11 +77,24 @@ deterministic CSRNG. The reference implementation uses AES256 CTR DRBG. I have implemented this in [`ase256_ctr_drbg.py`](src/dilithium_py/drbg/ase256_ctr_drbg.py). However, I have not implemented AES itself, instead I import this from `pycryptodome`. -To install dependencies, run `pip -r install requirements`. +To install dependencies, run `pip install -r requirements.txt`. If you're happy to use system randomness (`os.urandom`) then you don't need this dependency. +#### `xoflib` + +There is an additional optional dependency of +[`xoflib`](https://github.com/GiacomoPope/xoflib) which is a python package with +bindings to many Rust implementations of eXtendable-Output Functions (XOFx). The +creation of this package was inspired by this repository as Dilithium needs a streaming API from the shake XOFs which `hashlib` doesn't support. + +`xoflib` can be installed by running `pip install xoflib` or by installing from requirements as above. + +If you do not wish to install this dependency, then we include a small +[`shake_wrapper`](src/dilithium_py/shake/shake_wrapper.py) to mimic `xoflib` but +with a much higher memory consumption due to the limitations of `hashlib`. + ## Using dilithium-py ### ML DSA @@ -126,12 +139,12 @@ The above example would also work with the other NIST levels Some very rough benchmarks to give an idea about performance: -| 500 Iterations | `ML_DSA_44` | `ML_DSA_65` | `ML_DSA_87` | +| 1000 Iterations | `ML_DSA_44` | `ML_DSA_65` | `ML_DSA_87` | |--------------------------|--------------|--------------|--------------| -| `KeyGen()` Median Time | 6 ms | 10 ms | 16 ms | -| `Sign()` Median Time | 29 ms | 52 ms | 61 ms | -| `Sign()` Average Time | 36 ms | 64 ms | 75 ms | -| `Verify()` Median Time | 8 ms | 12 ms | 18 ms | +| `KeyGen()` Median Time | 6 ms | 10 ms | 14 ms | +| `Sign()` Median Time | 29 ms | 49 ms | 59 ms | +| `Sign()` Average Time | 36 ms | 62 ms | 75 ms | +| `Verify()` Median Time | 8 ms | 11 ms | 17 ms | All times recorded using a Intel Core i7-9750H CPU averaged over 1000 calls. @@ -177,12 +190,12 @@ The above example would also work with the other NIST levels Some very rough benchmarks to give an idea about performance: -| 500 Iterations | `Dilithium2` | `Dilithium3` | `Dilithium5` | +| 1000 Iterations | `Dilithium2` | `Dilithium3` | `Dilithium5` | |--------------------------|---------------|--------------|--------------| -| `KeyGen()` Median Time | 6 ms | 10 ms | 16 ms | +| `KeyGen()` Median Time | 6 ms | 9 ms | 15 ms | | `Sign()` Median Time | 27 ms | 46 ms | 58 ms | | `Sign()` Average Time | 35 ms | 58 ms | 72 ms | -| `Verify()` Median Time | 8 ms | 12 ms | 18 ms | +| `Verify()` Median Time | 7 ms | 11 ms | 18 ms | All times recorded using a Intel Core i7-9750H CPU averaged over 1000 calls. diff --git a/benchmarks/benchmark_dilithium.py b/benchmarks/benchmark_dilithium.py index aff8ecd..6dcb502 100644 --- a/benchmarks/benchmark_dilithium.py +++ b/benchmarks/benchmark_dilithium.py @@ -70,8 +70,8 @@ def benchmark_dilithium(Dilithium, name, count): # I used 1000 calls for the README, but you might want to # shrink this down if you're playing count = 1000 - benchmark_dilithium(Dilithium2, "Dilithium2", count) - benchmark_dilithium(Dilithium3, "Dilithium3", count) - benchmark_dilithium(Dilithium5, "Dilithium5", count) + # benchmark_dilithium(Dilithium2, "Dilithium2", count) + # benchmark_dilithium(Dilithium3, "Dilithium3", count) + # benchmark_dilithium(Dilithium5, "Dilithium5", count) - # profile_dilithium(Dilithium2) + profile_dilithium(Dilithium2) diff --git a/requirements.txt b/requirements.txt index a7902ff..21a0d6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -pycryptodome == 3.14.1 \ No newline at end of file +pycryptodome == 3.14.1 +xoflib diff --git a/src/dilithium_py/dilithium/dilithium.py b/src/dilithium_py/dilithium/dilithium.py index 143af8a..4fa0f1c 100644 --- a/src/dilithium_py/dilithium/dilithium.py +++ b/src/dilithium_py/dilithium/dilithium.py @@ -1,6 +1,10 @@ import os from ..modules.modules import ModuleDilithium -from ..shake.shake_wrapper import Shake256 + +try: + from xoflib import shake256 +except ImportError: + from ..shake.shake_wrapper import shake256 class Dilithium: @@ -56,7 +60,7 @@ def _h(input_bytes, length): """ H: B^* -> B^* """ - return Shake256.digest(input_bytes, length) + return shake256(input_bytes).read(length) def _expand_matrix_from_seed(self, rho): """ diff --git a/src/dilithium_py/ml_dsa/ml_dsa.py b/src/dilithium_py/ml_dsa/ml_dsa.py index ff19a34..ea41f22 100644 --- a/src/dilithium_py/ml_dsa/ml_dsa.py +++ b/src/dilithium_py/ml_dsa/ml_dsa.py @@ -1,6 +1,10 @@ import os from ..modules.modules import ModuleDilithium -from ..shake.shake_wrapper import Shake256 + +try: + from xoflib import shake256 +except ImportError: + from ..shake.shake_wrapper import shake256 class ML_DSA: @@ -57,7 +61,7 @@ def _h(input_bytes, length): """ H: B^* -> B^* """ - return Shake256.digest(input_bytes, length) + return shake256(input_bytes).read(length) def _expand_matrix_from_seed(self, rho): """ diff --git a/src/dilithium_py/polynomials/polynomials.py b/src/dilithium_py/polynomials/polynomials.py index 08e06a2..a8809dd 100644 --- a/src/dilithium_py/polynomials/polynomials.py +++ b/src/dilithium_py/polynomials/polynomials.py @@ -6,9 +6,13 @@ decompose, check_norm_bound, ) -from ..shake.shake_wrapper import Shake128, Shake256 from ..utilities.utils import make_hint, make_hint_optimised, use_hint +try: + from xoflib import shake128, shake256 +except ImportError: + from ..shake.shake_wrapper import shake128, shake256 + class PolynomialRingDilithium(PolynomialRing): def __init__(self): @@ -53,11 +57,11 @@ def rejection_sample(i, xof): return j # Initialise the XOF - Shake256.absorb(seed) + xof = shake256(seed) # Set the first 8 bytes for the sign, and leave the rest for # sampling. - sign_bytes = Shake256.read(8) + sign_bytes = xof.read(8) sign_int = int.from_bytes(sign_bytes, "little") # Set the list of coeffs to be 0 @@ -65,7 +69,7 @@ def rejection_sample(i, xof): # Now set tau values of coeffs to be ±1 for i in range(256 - tau, 256): - j = rejection_sample(i, Shake256) + j = rejection_sample(i, xof) coeffs[i] = coeffs[j] coeffs[j] = 1 - 2 * (sign_int & 1) sign_int >>= 1 @@ -93,8 +97,8 @@ def rejection_sample(xof): # Initialise the XOF seed = rho + bytes([j, i]) - Shake128.absorb(seed) - coeffs = [rejection_sample(Shake128) for _ in range(256)] + xof = shake128(seed) + coeffs = [rejection_sample(xof) for _ in range(256)] return self(coeffs, is_ntt=True) def rejection_bounded_poly(self, rho_prime, i, eta): @@ -116,14 +120,14 @@ def coefficient_from_half_byte(j, eta): # Initialise the XOF seed = rho_prime + int.to_bytes(i, 2, "little") - Shake256.absorb(seed) + xof = shake256(seed) # Sample bytes for all n coeffs i = 0 coeffs = [0 for _ in range(256)] while i < 256: # Consider two values for each byte (top and bottom four bits) - j = Shake256.read(1)[0] + j = xof.read(1)[0] c0 = coefficient_from_half_byte(j % 16, eta) if c0 is not False: @@ -151,7 +155,7 @@ def sample_mask_polynomial(self, rho_prime, i, kappa, gamma_1): # Initialise the XOF seed = rho_prime + int.to_bytes(kappa + i, 2, "little") - xof_bytes = Shake256.digest(seed, total_bytes) + xof_bytes = shake256(seed).read(total_bytes) r = int.from_bytes(xof_bytes, "little") mask = (1 << bit_count) - 1 coeffs = [gamma_1 - ((r >> bit_count * i) & mask) for i in range(self.n)] diff --git a/src/dilithium_py/shake/shake_wrapper.py b/src/dilithium_py/shake/shake_wrapper.py index 889324d..a71bf1b 100644 --- a/src/dilithium_py/shake/shake_wrapper.py +++ b/src/dilithium_py/shake/shake_wrapper.py @@ -17,69 +17,45 @@ class Shake: def __init__(self, algorithm, block_length): self.algorithm = algorithm self.block_length = block_length - self.index = 0 - self.read_blocks = 0 - self.bytes_left = 0 - self.read_data = b"" + self.buf = b"" + self.len_buf = 0 def absorb(self, input_bytes): """ - Initialise the XOF with the seed - and reset other init. + Initialise the XOF with the seed and reset other init. """ - self.read_data = b"" - self.read_blocks = 0 - self.bytes_left = 0 + # Initalize the buffer self.index = 0 - self.xof = self.algorithm(input_bytes) - def digest(self, input_bytes, length): - """ - Sometimes we just want n bytes, so rather than read - them slowly, we can just pull them straight out. - """ - return self.algorithm(input_bytes).digest(length) + # Set the reading method from hashlib digest + self.xof_read = self.algorithm(input_bytes).digest - def get_n_blocks(self, n): - """ - Requests n blocks from Shake and stores them - Ignores any bytes previously read - """ - # Because of hashlib we need to request ALL bytes even - # if we only want 5 more blocks - byte_count = self.block_length * (self.read_blocks + n) - xof_data = self.xof.digest(byte_count) - - # include the extra blocks and remove the read ones - self.read_data = ( - self.read_data[self.index :] + xof_data[-self.block_length * n :] - ) - self.read_blocks += n - self.bytes_left += self.block_length * n - self.index = 0 + # Start by requesting 5 blocks from the XOF + self.buf = self.xof_read(5 * self.block_length) + self.len_buf = 5 * self.block_length def read(self, n): """ - Rad n bytes from the XOF + Read n bytes from the XOF """ # Make sure there are enough bytes to read - if n > self.bytes_left: - # If we don't need many bytes, just get 5 blocks - if (n - self.bytes_left) < 5 * self.block_length: - self.get_n_blocks(5) - # Otherwise get as many as we need - else: - self.get_n_blocks(n // self.block_length + 1) + while self.index + n > self.len_buf: + # double the size of the buffer + self.len_buf *= 2 + self.buf = self.xof_read(self.len_buf) # Read from the buffer data the bytes requested - send = self.read_data[self.index : self.index + n] + send = self.buf[self.index : self.index + n] - # Store that we've read the bytes and shift the index - self.bytes_left -= n + # Shift the index along the buffer self.index += n return send + def __call__(self, input_bytes): + self.absorb(input_bytes) + return self + -Shake128 = Shake(shake_128, 168) -Shake256 = Shake(shake_256, 136) +shake128 = Shake(shake_128, 168) +shake256 = Shake(shake_256, 136) diff --git a/tests/test_shake.py b/tests/test_shake.py index e78b6b7..2154335 100644 --- a/tests/test_shake.py +++ b/tests/test_shake.py @@ -1,5 +1,5 @@ from hashlib import shake_128, shake_256 -from dilithium_py.shake.shake_wrapper import Shake128, Shake256 +from dilithium_py.shake.shake_wrapper import shake128, shake256 from Crypto.Hash.SHAKE128 import SHAKE128_XOF from Crypto.Hash.SHAKE256 import SHAKE256_XOF @@ -22,12 +22,12 @@ def hashlib_test_many_calls(self, Shake, shake_hashlib): self.assertEqual(shake_hashlib(absorb_bytes).digest(l), output) def test_hashlib_shake128(self): - self.hashlib_test_long_calls(Shake128, shake_128) - self.hashlib_test_many_calls(Shake128, shake_128) + self.hashlib_test_long_calls(shake128, shake_128) + self.hashlib_test_many_calls(shake128, shake_128) def test_hashlib_shake256(self): - self.hashlib_test_long_calls(Shake256, shake_256) - self.hashlib_test_many_calls(Shake256, shake_256) + self.hashlib_test_long_calls(shake256, shake_256) + self.hashlib_test_many_calls(shake256, shake_256) class TestShakeCrypto(unittest.TestCase): @@ -40,5 +40,5 @@ def pycryptodome_test_read_chunks(self, Shake, ShakeCrypto): self.assertEqual(Shake.read(chunk), ShakeCrypto.read(chunk)) def test_pycryptodome_shake(self): - self.pycryptodome_test_read_chunks(Shake128, SHAKE128_XOF()) - self.pycryptodome_test_read_chunks(Shake256, SHAKE256_XOF()) + self.pycryptodome_test_read_chunks(shake128, SHAKE128_XOF()) + self.pycryptodome_test_read_chunks(shake256, SHAKE256_XOF())