Skip to content

Commit

Permalink
solve make_hint mystery
Browse files Browse the repository at this point in the history
  • Loading branch information
GiacomoPope committed Jul 25, 2024
1 parent 42c0767 commit 5617b7c
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/dilithium_py/dilithium/dilithium.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def sign(self, sk_bytes, m):

w0_minus_cs2_plus_ct0 = w0_minus_cs2 + c_t0

h = w0_minus_cs2_plus_ct0.make_hint(w1, alpha)
h = w0_minus_cs2_plus_ct0.make_hint_optimised(w1, alpha)
if h.sum_hint() > self.omega:
continue

Expand Down
36 changes: 21 additions & 15 deletions src/dilithium_py/ml_dsa/ml_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def sign(self, sk_bytes, m, deterministic=False):
rho_prime = self._h(K + rnd + mu, 64)

# Precompute NTT representation
s1 = s1.to_ntt()
s2 = s2.to_ntt()
t0 = t0.to_ntt()
s1_hat = s1.to_ntt()
s2_hat = s2.to_ntt()
t0_hat = t0.to_ntt()

alpha = self.gamma_2 << 1
while True:
Expand All @@ -246,33 +246,39 @@ def sign(self, sk_bytes, m, deterministic=False):

w = (A_hat @ y_hat).from_ntt()

# Extract out both the high and low bits
w1, w0 = w.decompose(alpha)
# NOTE: there is an optimisation possible where both the high and
# low bits of w are extracted here, which speeds up some checks
# below and requires the use of make_hint_optimised() -- to see the
# implementation of this, look at the signing algorithm for
# dilithium. We include this slower version to mirror the FIPS 204
# document precisely.
# Extract out only the high bits
w1 = w.high_bits(alpha)

# Create challenge polynomial
w1_bytes = w1.bit_pack_w(self.gamma_2)
c_tilde = self._h(mu + w1_bytes, self.c_tilde_bytes)
c_seed_bytes = c_tilde[:32]
c = self.R.sample_in_ball(c_seed_bytes, self.tau)
c_hat = c.to_ntt()

# Store c in NTT form
c = c.to_ntt()

z = y + (s1.scale(c)).from_ntt()
# NOTE: unlike FIPS 204 we start again as soon as a vector
# fails the norm bound to reduce any unneeded computations.
c_s1 = s1_hat.scale(c_hat).from_ntt()
z = y + c_s1
if z.check_norm_bound(self.gamma_1 - self.beta):
continue

w0_minus_cs2 = w0 - s2.scale(c).from_ntt()
if w0_minus_cs2.check_norm_bound(self.gamma_2 - self.beta):
c_s2 = s2_hat.scale(c_hat).from_ntt()
r0 = (w - c_s2).low_bits(alpha)
if r0.check_norm_bound(self.gamma_2 - self.beta):
continue

c_t0 = t0.scale(c).from_ntt()
c_t0 = t0_hat.scale(c_hat).from_ntt()
if c_t0.check_norm_bound(self.gamma_2):
continue

w0_minus_cs2_plus_ct0 = w0_minus_cs2 + c_t0

h = w0_minus_cs2_plus_ct0.make_hint(w1, alpha)
h = (-c_t0).make_hint(w - c_s2 + c_t0, alpha)
if h.sum_hint() > self.omega:
continue

Expand Down
11 changes: 11 additions & 0 deletions src/dilithium_py/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,17 @@ def make_hint(self, other, alpha):
]
return self.parent(matrix)

def make_hint_optimised(self, other, alpha):
"""
Figure 3 (Supporting algorithms for Dilithium)
https://pq-crystals.org/dilithium/data/dilithium-specification-round3-20210208.pdf
"""
matrix = [
[p.make_hint_optimised(q, alpha) for p, q in zip(r1, r2)]
for r1, r2 in zip(self._data, other._data)
]
return self.parent(matrix)

def use_hint(self, other, alpha):
"""
Figure 3 (Supporting algorithms for Dilithium)
Expand Down
9 changes: 8 additions & 1 deletion src/dilithium_py/polynomials/polynomials.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
check_norm_bound,
)
from ..shake.shake_wrapper import Shake128, Shake256
from ..utilities.utils import make_hint, use_hint
from ..utilities.utils import make_hint, make_hint_optimised, use_hint


class PolynomialRingDilithium(PolynomialRing):
Expand Down Expand Up @@ -372,6 +372,13 @@ def make_hint(self, other, alpha):
]
return self.parent(coeffs)

def make_hint_optimised(self, other, alpha):
coeffs = [
make_hint_optimised(r, z, alpha, 8380417)
for r, z in zip(self.coeffs, other.coeffs)
]
return self.parent(coeffs)

def use_hint(self, other, alpha):
coeffs = [
use_hint(h, r, alpha, 8380417) for h, r in zip(self.coeffs, other.coeffs)
Expand Down
30 changes: 11 additions & 19 deletions src/dilithium_py/utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,11 @@ def reduce_mod_pm(x, n):
for n odd:
-(n-1)/2 < r <= (n-1)/2
for n even:
- n / 2 < r <= n / 2
- n / 2 <= r <= n / 2
"""
x = x % n
if x > (n >> 1):
x -= n

# Asserts to try and understand __broken_make_hint()
# assert x > -(n >> 1)
# assert x <= (n >> 1)

return x


Expand All @@ -41,11 +36,9 @@ def decompose(r, a, q):
else:
r1 = (rp - r0) // a

# Asserts to try and understand __broken_make_hint()
# assert r0 > -(a >> 1)
# assert r0 <= (a >> 1)
# assert r % q == (r0 + r1 * a) % q

return r1, r0


Expand All @@ -59,20 +52,19 @@ def low_bits(r, a, q):
return r0


# def __broken_make_hint(z, r, a, q):
# r1 = high_bits(r, a, q)
# v1 = high_bits(r + z, a, q)
# return int(r1 != v1)


def make_hint(z0, r1, a, q):
def make_hint(z, r, a, q):
"""
Check whether the top bit of z will change when r is added
"""
The above function from the documentation
fails sometimes, but this seems to work...
r1 = high_bits(r, a, q)
v1 = high_bits(r + z, a, q)
return int(r1 != v1)

This assumes that

TODO: learn what the edge case is for the above function
def make_hint_optimised(z0, r1, a, q):
"""
Optimised version of the above used when the low bits w0 are extracted from
`w = (A_hat @ y_hat).from_ntt()` during signing
"""
gamma2 = a >> 1
if z0 <= gamma2 or z0 > (q - gamma2) or (z0 == (q - gamma2) and r1 == 0):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest
from dilithium_py.utilities.utils import reduce_mod_pm
from random import randint


class TestUtils(unittest.TestCase):
def test_reduce_mod_pm_even(self):
for _ in range(100):
modulus = 2 * randint(0, 100)
for i in range(modulus):
x = reduce_mod_pm(i, modulus)
self.assertTrue(x <= modulus // 2)
self.assertTrue(x > -modulus // 2)

def test_reduce_mod_pm_odd(self):
for _ in range(100):
modulus = 2 * randint(0, 100) + 1
for i in range(modulus):
x = reduce_mod_pm(i, modulus)
self.assertTrue(x <= (modulus - 1) // 2)
self.assertTrue(x >= -(modulus - 1) // 2)

0 comments on commit 5617b7c

Please sign in to comment.