Skip to content

Commit

Permalink
use xoflib
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Aug 1, 2024
1 parent b4b2c01 commit 302b769
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 80 deletions.
31 changes: 22 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,24 @@ deterministic CSRNG. The reference implementation uses
AES256 CTR DRBG. I have implemented this in [`ase256_ctr_drbg.py`](src/dilithium_py/drbg/ase256_ctr_drbg.py).
However, I have not implemented AES itself, instead I import this from `pycryptodome`.

To install dependencies, run `pip -r install requirements`.
To install dependencies, run `pip install -r requirements.txt`.

If you're happy to use system randomness (`os.urandom`) then you don't need
this dependency.

#### `xoflib`

There is an additional optional dependency of
[`xoflib`](https://github.com/GiacomoPope/xoflib), which is a Python package with
bindings to many Rust implementations of eXtendable-Output Functions (XOFs). The
creation of this package was inspired by this repository, as Dilithium needs a streaming API from the SHAKE XOFs, which `hashlib` doesn't support.

`xoflib` can be installed by running `pip install xoflib` or by installing from requirements as above.

If you do not wish to install this dependency, then we include a small
[`shake_wrapper`](src/dilithium_py/shake/shake_wrapper.py) to mimic `xoflib` but
with a much higher memory consumption due to the limitations of `hashlib`.

## Using dilithium-py

### ML DSA
Expand Down Expand Up @@ -126,12 +139,12 @@ The above example would also work with the other NIST levels

Some very rough benchmarks to give an idea about performance:

| 500 Iterations | `ML_DSA_44` | `ML_DSA_65` | `ML_DSA_87` |
| 1000 Iterations | `ML_DSA_44` | `ML_DSA_65` | `ML_DSA_87` |
|--------------------------|--------------|--------------|--------------|
| `KeyGen()` Median Time | 6 ms | 10 ms | 16 ms |
| `Sign()` Median Time | 29 ms | 52 ms | 61 ms |
| `Sign()` Average Time | 36 ms | 64 ms | 75 ms |
| `Verify()` Median Time | 8 ms | 12 ms | 18 ms |
| `KeyGen()` Median Time | 6 ms | 10 ms | 14 ms |
| `Sign()` Median Time | 29 ms | 49 ms | 59 ms |
| `Sign()` Average Time | 36 ms | 62 ms | 75 ms |
| `Verify()` Median Time | 8 ms | 11 ms | 17 ms |

All times recorded using an Intel Core i7-9750H CPU averaged over 1000 calls.

Expand Down Expand Up @@ -177,12 +190,12 @@ The above example would also work with the other NIST levels

Some very rough benchmarks to give an idea about performance:

| 500 Iterations | `Dilithium2` | `Dilithium3` | `Dilithium5` |
| 1000 Iterations | `Dilithium2` | `Dilithium3` | `Dilithium5` |
|--------------------------|---------------|--------------|--------------|
| `KeyGen()` Median Time | 6 ms | 10 ms | 16 ms |
| `KeyGen()` Median Time | 6 ms | 9 ms | 15 ms |
| `Sign()` Median Time | 27 ms | 46 ms | 58 ms |
| `Sign()` Average Time | 35 ms | 58 ms | 72 ms |
| `Verify()` Median Time | 8 ms | 12 ms | 18 ms |
| `Verify()` Median Time | 7 ms | 11 ms | 18 ms |

All times recorded using an Intel Core i7-9750H CPU averaged over 1000 calls.

Expand Down
8 changes: 4 additions & 4 deletions benchmarks/benchmark_dilithium.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def benchmark_dilithium(Dilithium, name, count):
# I used 1000 calls for the README, but you might want to
# shrink this down if you're playing
count = 1000
benchmark_dilithium(Dilithium2, "Dilithium2", count)
benchmark_dilithium(Dilithium3, "Dilithium3", count)
benchmark_dilithium(Dilithium5, "Dilithium5", count)
# benchmark_dilithium(Dilithium2, "Dilithium2", count)
# benchmark_dilithium(Dilithium3, "Dilithium3", count)
# benchmark_dilithium(Dilithium5, "Dilithium5", count)

# profile_dilithium(Dilithium2)
profile_dilithium(Dilithium2)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pycryptodome == 3.14.1
pycryptodome == 3.14.1
xoflib
8 changes: 6 additions & 2 deletions src/dilithium_py/dilithium/dilithium.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os
from ..modules.modules import ModuleDilithium
from ..shake.shake_wrapper import Shake256

try:
from xoflib import shake256
except ImportError:
from ..shake.shake_wrapper import shake256


class Dilithium:
Expand Down Expand Up @@ -56,7 +60,7 @@ def _h(input_bytes, length):
"""
H: B^* -> B^*
"""
return Shake256.digest(input_bytes, length)
return shake256(input_bytes).read(length)

def _expand_matrix_from_seed(self, rho):
"""
Expand Down
8 changes: 6 additions & 2 deletions src/dilithium_py/ml_dsa/ml_dsa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os
from ..modules.modules import ModuleDilithium
from ..shake.shake_wrapper import Shake256

try:
from xoflib import shake256
except ImportError:
from ..shake.shake_wrapper import shake256


class ML_DSA:
Expand Down Expand Up @@ -57,7 +61,7 @@ def _h(input_bytes, length):
"""
H: B^* -> B^*
"""
return Shake256.digest(input_bytes, length)
return shake256(input_bytes).read(length)

def _expand_matrix_from_seed(self, rho):
"""
Expand Down
22 changes: 13 additions & 9 deletions src/dilithium_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
decompose,
check_norm_bound,
)
from ..shake.shake_wrapper import Shake128, Shake256
from ..utilities.utils import make_hint, make_hint_optimised, use_hint

try:
from xoflib import shake128, shake256
except ImportError:
from ..shake.shake_wrapper import shake128, shake256


class PolynomialRingDilithium(PolynomialRing):
def __init__(self):
Expand Down Expand Up @@ -53,19 +57,19 @@ def rejection_sample(i, xof):
return j

# Initialise the XOF
Shake256.absorb(seed)
xof = shake256(seed)

# Set the first 8 bytes for the sign, and leave the rest for
# sampling.
sign_bytes = Shake256.read(8)
sign_bytes = xof.read(8)
sign_int = int.from_bytes(sign_bytes, "little")

# Set the list of coeffs to be 0
coeffs = [0 for _ in range(256)]

# Now set tau values of coeffs to be ±1
for i in range(256 - tau, 256):
j = rejection_sample(i, Shake256)
j = rejection_sample(i, xof)
coeffs[i] = coeffs[j]
coeffs[j] = 1 - 2 * (sign_int & 1)
sign_int >>= 1
Expand Down Expand Up @@ -93,8 +97,8 @@ def rejection_sample(xof):

# Initialise the XOF
seed = rho + bytes([j, i])
Shake128.absorb(seed)
coeffs = [rejection_sample(Shake128) for _ in range(256)]
xof = shake128(seed)
coeffs = [rejection_sample(xof) for _ in range(256)]
return self(coeffs, is_ntt=True)

def rejection_bounded_poly(self, rho_prime, i, eta):
Expand All @@ -116,14 +120,14 @@ def coefficient_from_half_byte(j, eta):

# Initialise the XOF
seed = rho_prime + int.to_bytes(i, 2, "little")
Shake256.absorb(seed)
xof = shake256(seed)

# Sample bytes for all n coeffs
i = 0
coeffs = [0 for _ in range(256)]
while i < 256:
# Consider two values for each byte (top and bottom four bits)
j = Shake256.read(1)[0]
j = xof.read(1)[0]

c0 = coefficient_from_half_byte(j % 16, eta)
if c0 is not False:
Expand Down Expand Up @@ -151,7 +155,7 @@ def sample_mask_polynomial(self, rho_prime, i, kappa, gamma_1):

# Initialise the XOF
seed = rho_prime + int.to_bytes(kappa + i, 2, "little")
xof_bytes = Shake256.digest(seed, total_bytes)
xof_bytes = shake256(seed).read(total_bytes)
r = int.from_bytes(xof_bytes, "little")
mask = (1 << bit_count) - 1
coeffs = [gamma_1 - ((r >> bit_count * i) & mask) for i in range(self.n)]
Expand Down
68 changes: 22 additions & 46 deletions src/dilithium_py/shake/shake_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,69 +17,45 @@ class Shake:
def __init__(self, algorithm, block_length):
self.algorithm = algorithm
self.block_length = block_length
self.index = 0
self.read_blocks = 0
self.bytes_left = 0
self.read_data = b""
self.buf = b""
self.len_buf = 0

def absorb(self, input_bytes):
"""
Initialise the XOF with the seed
and reset other init.
Initialise the XOF with the seed and reset other init.
"""
self.read_data = b""
self.read_blocks = 0
self.bytes_left = 0
# Initialize the buffer
self.index = 0
self.xof = self.algorithm(input_bytes)

def digest(self, input_bytes, length):
"""
Sometimes we just want n bytes, so rather than read
them slowly, we can just pull them straight out.
"""
return self.algorithm(input_bytes).digest(length)
# Set the reading method from hashlib digest
self.xof_read = self.algorithm(input_bytes).digest

def get_n_blocks(self, n):
"""
Requests n blocks from Shake and stores them
Ignores any bytes previously read
"""
# Because of hashlib we need to request ALL bytes even
# if we only want 5 more blocks
byte_count = self.block_length * (self.read_blocks + n)
xof_data = self.xof.digest(byte_count)

# include the extra blocks and remove the read ones
self.read_data = (
self.read_data[self.index :] + xof_data[-self.block_length * n :]
)
self.read_blocks += n
self.bytes_left += self.block_length * n
self.index = 0
# Start by requesting 5 blocks from the XOF
self.buf = self.xof_read(5 * self.block_length)
self.len_buf = 5 * self.block_length

def read(self, n):
"""
Rad n bytes from the XOF
Read n bytes from the XOF
"""
# Make sure there are enough bytes to read
if n > self.bytes_left:
# If we don't need many bytes, just get 5 blocks
if (n - self.bytes_left) < 5 * self.block_length:
self.get_n_blocks(5)
# Otherwise get as many as we need
else:
self.get_n_blocks(n // self.block_length + 1)
while self.index + n > self.len_buf:
# double the size of the buffer
self.len_buf *= 2
self.buf = self.xof_read(self.len_buf)

# Read from the buffer data the bytes requested
send = self.read_data[self.index : self.index + n]
send = self.buf[self.index : self.index + n]

# Store that we've read the bytes and shift the index
self.bytes_left -= n
# Shift the index along the buffer
self.index += n

return send

def __call__(self, input_bytes):
self.absorb(input_bytes)
return self


Shake128 = Shake(shake_128, 168)
Shake256 = Shake(shake_256, 136)
shake128 = Shake(shake_128, 168)
shake256 = Shake(shake_256, 136)
14 changes: 7 additions & 7 deletions tests/test_shake.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hashlib import shake_128, shake_256
from dilithium_py.shake.shake_wrapper import Shake128, Shake256
from dilithium_py.shake.shake_wrapper import shake128, shake256
from Crypto.Hash.SHAKE128 import SHAKE128_XOF
from Crypto.Hash.SHAKE256 import SHAKE256_XOF

Expand All @@ -22,12 +22,12 @@ def hashlib_test_many_calls(self, Shake, shake_hashlib):
self.assertEqual(shake_hashlib(absorb_bytes).digest(l), output)

def test_hashlib_shake128(self):
self.hashlib_test_long_calls(Shake128, shake_128)
self.hashlib_test_many_calls(Shake128, shake_128)
self.hashlib_test_long_calls(shake128, shake_128)
self.hashlib_test_many_calls(shake128, shake_128)

def test_hashlib_shake256(self):
self.hashlib_test_long_calls(Shake256, shake_256)
self.hashlib_test_many_calls(Shake256, shake_256)
self.hashlib_test_long_calls(shake256, shake_256)
self.hashlib_test_many_calls(shake256, shake_256)


class TestShakeCrypto(unittest.TestCase):
Expand All @@ -40,5 +40,5 @@ def pycryptodome_test_read_chunks(self, Shake, ShakeCrypto):
self.assertEqual(Shake.read(chunk), ShakeCrypto.read(chunk))

def test_pycryptodome_shake(self):
self.pycryptodome_test_read_chunks(Shake128, SHAKE128_XOF())
self.pycryptodome_test_read_chunks(Shake256, SHAKE256_XOF())
self.pycryptodome_test_read_chunks(shake128, SHAKE128_XOF())
self.pycryptodome_test_read_chunks(shake256, SHAKE256_XOF())

0 comments on commit 302b769

Please sign in to comment.