diff --git a/src/dilithium_py/dilithium/default_parameters.py b/src/dilithium_py/dilithium/default_parameters.py index 9518e4f..d20cca0 100644 --- a/src/dilithium_py/dilithium/default_parameters.py +++ b/src/dilithium_py/dilithium/default_parameters.py @@ -2,43 +2,34 @@ DEFAULT_PARAMETERS = { "dilithium2": { - "n": 256, - "q": 8380417, "d": 13, "k": 4, "l": 4, "eta": 2, - "eta_bound": 15, "tau": 39, "omega": 80, "gamma_1": 131072, # 2^17 "gamma_2": 95232, # (q-1)/88 }, "dilithium3": { - "n": 256, - "q": 8380417, "d": 13, "k": 6, "l": 5, "eta": 4, - "eta_bound": 9, "tau": 49, "omega": 55, "gamma_1": 524288, # 2^19 - "gamma_2": 261888, # (q-1)/88 + "gamma_2": 261888, # (q-1)/32 }, "dilithium5": { - "n": 256, - "q": 8380417, "d": 13, "k": 8, "l": 7, "eta": 2, - "eta_bound": 15, "tau": 60, "omega": 75, "gamma_1": 524288, # 2^19 - "gamma_2": 261888, # (q-1)/88 + "gamma_2": 261888, # (q-1)/32 }, } diff --git a/src/dilithium_py/dilithium/dilithium.py b/src/dilithium_py/dilithium/dilithium.py index 2ba50cf..bef4043 100644 --- a/src/dilithium_py/dilithium/dilithium.py +++ b/src/dilithium_py/dilithium/dilithium.py @@ -1,28 +1,22 @@ import os - -from ..polynomials.polynomials import PolynomialRingDilithium from ..modules.modules import ModuleDilithium -from ..shake.shake_wrapper import Shake128, Shake256 -from ..utilities.utils import make_hint, use_hint +from ..shake.shake_wrapper import Shake256 class Dilithium: def __init__(self, parameter_set): - self.n = parameter_set["n"] - self.q = parameter_set["q"] self.d = parameter_set["d"] self.k = parameter_set["k"] self.l = parameter_set["l"] self.eta = parameter_set["eta"] - self.eta_bound = parameter_set["eta_bound"] self.tau = parameter_set["tau"] self.omega = parameter_set["omega"] self.gamma_1 = parameter_set["gamma_1"] self.gamma_2 = parameter_set["gamma_2"] self.beta = self.tau * self.eta - self.R = PolynomialRingDilithium() self.M = ModuleDilithium() + self.R = self.M.ring # Use system randomness by default, for deterministic randomness # use the method `set_drbg_seed()` @@ -64,217 +58,36 @@ def _h(input_bytes, length): """ return Shake256.digest(input_bytes, length) - """ - Figure 3 (Supporting algorithms for Dilithium) - `_make_hint/_use_hint` is applied to matrices and `_make_hint_poly/_use_hint_poly` - applies to the polynomials, which are elements of the matrices. - - `_make_hint_poly/_use_hint_poly` uses the util functions `use_hint/make_hint` - which works on field elements (see utils.py) - - https://pq-crystals.org/dilithium/data/dilithium-specification-round3-20210208.pdf - """ - - def _make_hint(self, v1, v2, alpha): - matrix = [ - [ - self._make_hint_poly(p1, p2, alpha) - for p1, p2 in zip(v1._data[i], v2._data[i]) - ] - for i in range(len(v1._data)) - ] - return self.M(matrix) - - def _use_hint(self, v1, v2, alpha): - matrix = [ - [ - self._use_hint_poly(p1, p2, alpha) - for p1, p2 in zip(v1._data[i], v2._data[i]) - ] - for i in range(len(v1._data)) - ] - return self.M(matrix) - - def _make_hint_poly(self, p1, p2, alpha): - coeffs = [make_hint(r, z, alpha, self.q) for r, z in zip(p1.coeffs, p2.coeffs)] - return self.R(coeffs) - - def _use_hint_poly(self, p1, p2, alpha): - coeffs = [use_hint(h, r, alpha, self.q) for h, r in zip(p1.coeffs, p2.coeffs)] - return self.R(coeffs) - - @staticmethod - def _sum_hint(hint): - """ - Helper function to count the number of coeffs == 1 - in all the polynomials of a matrix - """ - return sum(c for row in hint._data for p in row for c in p) - - def _sample_in_ball(self, seed): - """ - Figure 2 (Sample in Ball) - https://pq-crystals.org/dilithium/data/dilithium-specification-round3-20210208.pdf - - Create a random 256-element array with τ ±1’s and (256 − τ) 0′s using - the input seed ρ (and an SHAKE256) to generate the randomness needed - """ - - def rejection_sample(i, xof): - """ - Sample random bytes from `xof_bytes` and - interpret them as integers in {0, ..., 255} - - Rejects values until a value j <= i is found - """ - while True: - j = xof.read(1) - j = int.from_bytes(j, "little") - if j <= i: - return j - - # Initialise the XOF - Shake256.absorb(seed) - - # Set the first 8 bytes for the sign, and leave the rest for - # sampling. - sign_bytes = Shake256.read(8) - sign_int = int.from_bytes(sign_bytes, "little") - - # Set the list of coeffs to be 0 - coeffs = [0 for _ in range(self.n)] - - # Now set tau values of coeffs to be ±1 - for i in range(256 - self.tau, self.n): - j = rejection_sample(i, Shake256) - coeffs[i] = coeffs[j] - coeffs[j] = 1 - 2 * (sign_int & 1) - sign_int >>= 1 - - return self.R(coeffs) - - def _sample_error_polynomial(self, rho_prime, i, is_ntt=False): - def rejection_sample(xof): - """ - Sample a random byte from `xof_bytes` and - interpret it as two integers in {0, ..., 2η} - by considering the top and bottom four bits - - Rejects values until a value j < 2η is found - """ - while True: - js = [] - - # Consider two values for each byte (top and bottom four bits) - j = xof.read(1) - j = int.from_bytes(j, "little") - j0 = j & 0x0F - j1 = j >> 4 - - # rejection sample - if j0 < self.eta_bound: - if self.eta == 2: - j0 %= 5 - js.append(self.eta - j0) - - if j1 < self.eta_bound: - if self.eta == 2: - j1 %= 5 - js.append(self.eta - j1) - - if js: - return js - - # Initialise the XOF - seed = rho_prime + int.to_bytes(i, 2, "little") - Shake256.absorb(seed) - - # Sample bytes for all n coeffs - # TODO: make this better. - coeffs = [] - while len(coeffs) < self.n: - js = rejection_sample(Shake256) - coeffs += js - - # Remove the last byte if we ended up overfilling - if len(coeffs) > self.n: - coeffs = coeffs[: self.n] - - return self.R(coeffs, is_ntt=is_ntt) - - def _sample_matrix_polynomial(self, rho, i, j, is_ntt=False): - def rejection_sample(xof): - """ - Sample three random bytes from `xof` and - interpret them as integers in {0, ..., 2^23 - 1} - - Rejects values until a value j < q is found - """ - while True: - j_bytes = xof.read(3) - j = int.from_bytes(j_bytes, "little") - j &= 0x7FFFFF - if j < self.q: - return j - - # Initialise the XOF - seed = rho + bytes([j, i]) - Shake128.absorb(seed) - coeffs = [rejection_sample(Shake128) for _ in range(self.n)] - return self.R(coeffs, is_ntt=is_ntt) - - def _sample_mask_polynomial(self, rho_prime, i, kappa, is_ntt=False): - if self.gamma_1 == (1 << 17): - bit_count = 18 - total_bytes = 576 # (256 * 18) / 8 - else: - bit_count = 20 - total_bytes = 640 # (256 * 20) / 8 - - # Initialise the XOF - seed = rho_prime + int.to_bytes(kappa + i, 2, "little") - xof_bytes = Shake256.digest(seed, total_bytes) - r = int.from_bytes(xof_bytes, "little") - mask = (1 << bit_count) - 1 - coeffs = [self.gamma_1 - ((r >> bit_count * i) & mask) for i in range(self.n)] - - return self.R(coeffs, is_ntt=is_ntt) - - def _expandA(self, rho, is_ntt=False): + def _expand_matrix_from_seed(self, rho): """ Helper function which generates a element of size k x l from a seed `rho`. - - When `transpose` is set to True, the matrix A is - built as the transpose. """ - matrix = [ - [ - self._sample_matrix_polynomial(rho, i, j, is_ntt=is_ntt) - for j in range(self.l) - ] - for i in range(self.k) - ] - return self.M(matrix) + A_data = [[0 for _ in range(self.l)] for _ in range(self.k)] + for i in range(self.k): + for j in range(self.l): + A_data[i][j] = self.R.rejection_sample_ntt_poly(rho, i, j) + return self.M(A_data) - def _expandS(self, rho_prime): + def _expand_vector_from_seed(self, rho_prime): s1_elements = [ - self._sample_error_polynomial(rho_prime, i) for i in range(self.l) + self.R.rejection_bounded_poly(rho_prime, i, self.eta) for i in range(self.l) ] s2_elements = [ - self._sample_error_polynomial(rho_prime, i) + self.R.rejection_bounded_poly(rho_prime, i, self.eta) for i in range(self.l, self.l + self.k) ] - s1 = self.M(s1_elements).transpose() - s2 = self.M(s2_elements).transpose() + s1 = self.M.vector(s1_elements) + s2 = self.M.vector(s2_elements) return s1, s2 - def _expandMask(self, rho_prime, kappa): + def _expand_mask_vector(self, rho_prime, kappa): elements = [ - self._sample_mask_polynomial(rho_prime, i, kappa) for i in range(self.l) + self.R.sample_mask_polynomial(rho_prime, i, kappa, self.gamma_1) + for i in range(self.l) ] - return self.M(elements).transpose() + return self.M.vector(elements) @staticmethod def _pack_pk(rho, t1): @@ -351,7 +164,7 @@ def _unpack_h(self, h_bytes): matrix = [] for poly_non_zero in non_zero_positions: - coeffs = [0 for _ in range(self.n)] + coeffs = [0 for _ in range(256)] for non_zero in poly_non_zero: coeffs[non_zero] = 1 matrix.append([self.R(coeffs)]) @@ -367,6 +180,9 @@ def _unpack_sig(self, sig_bytes): return c_tilde, z, h def keygen(self): + """ + Generates a public-private keyair + """ # Random seed zeta = self.random_bytes(32) @@ -376,15 +192,15 @@ def keygen(self): # Split bytes into suitible chunks rho, rho_prime, K = seed_bytes[:32], seed_bytes[32:96], seed_bytes[96:] - # Generate matrix A ∈ R^(kxl) - A = self._expandA(rho, is_ntt=True) + # Generate matrix A ∈ R^(kxl) in the NTT domain + A_hat = self._expand_matrix_from_seed(rho) # Generate the error vectors s1 ∈ R^l, s2 ∈ R^k - s1, s2 = self._expandS(rho_prime) + s1, s2 = self._expand_vector_from_seed(rho_prime) s1_hat = s1.to_ntt() # Matrix multiplication - t = (A @ s1_hat).from_ntt() + s2 + t = (A_hat @ s1_hat).from_ntt() + s2 t1, t0 = t.power_2_round(self.d) @@ -396,11 +212,14 @@ def keygen(self): return pk, sk def sign(self, sk_bytes, m): + """ + Generates a signature for a message m from a byte-encoded private key + """ # unpack the secret key rho, K, tr, s1, s2, t0 = self._unpack_sk(sk_bytes) - # Generate matrix A ∈ R^(kxl) - A = self._expandA(rho, is_ntt=True) + # Generate matrix A ∈ R^(kxl) in the NTT domain + A_hat = self._expand_matrix_from_seed(rho) # Set seeds and nonce (kappa) mu = self._h(tr + m, 64) @@ -414,13 +233,13 @@ def sign(self, sk_bytes, m): alpha = self.gamma_2 << 1 while True: - y = self._expandMask(rho_prime, kappa) + y = self._expand_mask_vector(rho_prime, kappa) y_hat = y.to_ntt() # increment the nonce kappa += self.l - w = (A @ y_hat).from_ntt() + w = (A_hat @ y_hat).from_ntt() # Extract out both the high and low bits w1, w0 = w.decompose(alpha) @@ -428,7 +247,7 @@ def sign(self, sk_bytes, m): # Create challenge polynomial w1_bytes = w1.bit_pack_w(self.gamma_2) c_tilde = self._h(mu + w1_bytes, 32) - c = self._sample_in_ball(c_tilde) + c = self.R.sample_in_ball(c_tilde, self.tau) # Store c in NTT form c = c.to_ntt() @@ -447,27 +266,31 @@ def sign(self, sk_bytes, m): w0_minus_cs2_plus_ct0 = w0_minus_cs2 + c_t0 - h = self._make_hint(w0_minus_cs2_plus_ct0, w1, alpha) - if self._sum_hint(h) > self.omega: + h = w0_minus_cs2_plus_ct0.make_hint(w1, alpha) + if h.sum_hint() > self.omega: continue return self._pack_sig(c_tilde, z, h) def verify(self, pk_bytes, m, sig_bytes): + """ + Verifies a signature for a message m from a byte encoded public key and + signature + """ rho, t1 = self._unpack_pk(pk_bytes) c_tilde, z, h = self._unpack_sig(sig_bytes) - if self._sum_hint(h) > self.omega: + if h.sum_hint() > self.omega: return False if z.check_norm_bound(self.gamma_1 - self.beta): return False - A = self._expandA(rho, is_ntt=True) + A_hat = self._expand_matrix_from_seed(rho) tr = self._h(pk_bytes, 32) mu = self._h(tr + m, 64) - c = self._sample_in_ball(c_tilde) + c = self.R.sample_in_ball(c_tilde, self.tau) # Convert to NTT for computation c = c.to_ntt() @@ -476,10 +299,10 @@ def verify(self, pk_bytes, m, sig_bytes): t1 = t1.scale(1 << self.d) t1 = t1.to_ntt() - Az_minus_ct1 = (A @ z) - t1.scale(c) + Az_minus_ct1 = (A_hat @ z) - t1.scale(c) Az_minus_ct1 = Az_minus_ct1.from_ntt() - w_prime = self._use_hint(h, Az_minus_ct1, 2 * self.gamma_2) + w_prime = h.use_hint(Az_minus_ct1, 2 * self.gamma_2) w_prime_bytes = w_prime.bit_pack_w(self.gamma_2) return c_tilde == self._h(mu + w_prime_bytes, 32) diff --git a/src/dilithium_py/modules/modules.py b/src/dilithium_py/modules/modules.py index 941977b..b6882d3 100644 --- a/src/dilithium_py/modules/modules.py +++ b/src/dilithium_py/modules/modules.py @@ -162,3 +162,32 @@ def low_bits(self, alpha, is_ntt=False): [ele.low_bits(alpha, is_ntt=is_ntt) for ele in row] for row in self._data ] return self.parent(matrix) + + def make_hint(self, other, alpha): + """ + Figure 3 (Supporting algorithms for Dilithium) + https://pq-crystals.org/dilithium/data/dilithium-specification-round3-20210208.pdf + """ + matrix = [ + [p.make_hint(q, alpha) for p, q in zip(r1, r2)] + for r1, r2 in zip(self._data, other._data) + ] + return self.parent(matrix) + + def use_hint(self, other, alpha): + """ + Figure 3 (Supporting algorithms for Dilithium) + https://pq-crystals.org/dilithium/data/dilithium-specification-round3-20210208.pdf + """ + matrix = [ + [p.use_hint(q, alpha) for p, q in zip(r1, r2)] + for r1, r2 in zip(self._data, other._data) + ] + return self.parent(matrix) + + def sum_hint(self): + """ + Helper function to count the number of coeffs == 1 + in all the polynomials of a matrix + """ + return sum(c for row in self._data for p in row for c in p) diff --git a/src/dilithium_py/polynomials/polynomials.py b/src/dilithium_py/polynomials/polynomials.py index d3ad100..db88606 100644 --- a/src/dilithium_py/polynomials/polynomials.py +++ b/src/dilithium_py/polynomials/polynomials.py @@ -6,6 +6,8 @@ decompose, check_norm_bound, ) +from ..shake.shake_wrapper import Shake128, Shake256 +from ..utilities.utils import make_hint, use_hint class PolynomialRingDilithium(PolynomialRing): @@ -29,6 +31,137 @@ def br(i, k): bin_i = bin(i & (2**k - 1))[2:].zfill(k) return int(bin_i[::-1], 2) + def sample_in_ball(self, seed, tau): + """ + Figure 2 (Sample in Ball) + https://pq-crystals.org/dilithium/data/dilithium-specification-round3-20210208.pdf + + Create a random 256-element array with τ ±1’s and (256 − τ) 0′s using + the input seed ρ (and an SHAKE256) to generate the randomness needed + """ + + def rejection_sample(i, xof): + """ + Sample random bytes from `xof_bytes` and + interpret them as integers in {0, ..., 255} + + Rejects values until a value j <= i is found + """ + while True: + j = xof.read(1)[0] + if j <= i: + return j + + # Initialise the XOF + Shake256.absorb(seed) + + # Set the first 8 bytes for the sign, and leave the rest for + # sampling. + sign_bytes = Shake256.read(8) + sign_int = int.from_bytes(sign_bytes, "little") + + # Set the list of coeffs to be 0 + coeffs = [0 for _ in range(256)] + + # Now set tau values of coeffs to be ±1 + for i in range(256 - tau, 256): + j = rejection_sample(i, Shake256) + coeffs[i] = coeffs[j] + coeffs[j] = 1 - 2 * (sign_int & 1) + sign_int >>= 1 + + return self(coeffs) + + def rejection_sample_ntt_poly(self, rho, i, j): + """ + Samples an element in the NTT domain of R^q using rejection sampling + """ + + def rejection_sample(xof): + """ + Sample three random bytes from `xof` and + interpret them as integers in {0, ..., 2^23 - 1} + + Rejects values until a value j < q is found + """ + while True: + j_bytes = xof.read(3) + j = int.from_bytes(j_bytes, "little") + j &= 0x7FFFFF + if j < 8380417: + return j + + # Initialise the XOF + seed = rho + bytes([j, i]) + Shake128.absorb(seed) + coeffs = [rejection_sample(Shake128) for _ in range(256)] + return self(coeffs, is_ntt=True) + + def rejection_bounded_poly(self, rho_prime, i, eta): + """ + Computes an element of the polynomial ring with coefficients between + -eta and eta using rejection sampling from an XOF + """ + + def coefficient_from_half_byte(j, eta): + """ + Rejects values until a value j < 2η is found + """ + if eta == 2 and j < 15: + return 2 - (j % 5) + elif j < 9: + assert eta == 4 + return 4 - j + return False + + # Initialise the XOF + seed = rho_prime + int.to_bytes(i, 2, "little") + Shake256.absorb(seed) + + # Sample bytes for all n coeffs + i = 0 + coeffs = [0 for _ in range(256)] + while i < 256: + # Consider two values for each byte (top and bottom four bits) + j = Shake256.read(1)[0] + + c0 = coefficient_from_half_byte(j % 16, eta) + if c0 is not False: + coeffs[i] = c0 + i += 1 + + c1 = coefficient_from_half_byte(j // 16, eta) + if c1 is not False and i < 256: + coeffs[i] = c1 + i += 1 + + # Remove the last byte if we ended up overfilling + if len(coeffs) > 256: + coeffs = coeffs[:256] + + return self(coeffs) + + def sample_mask_polynomial(self, rho_prime, i, kappa, gamma_1): + """ + Samples an element in the polynomial ring with elements bounded + between -gamma_1 + 1 and gamma_1. + """ + if gamma_1 == (1 << 17): + bit_count = 18 + total_bytes = 576 # (256 * 18) / 8 + else: + bit_count = 20 + total_bytes = 640 # (256 * 20) / 8 + + # Initialise the XOF + seed = rho_prime + int.to_bytes(kappa + i, 2, "little") + xof_bytes = Shake256.digest(seed, total_bytes) + r = int.from_bytes(xof_bytes, "little") + mask = (1 << bit_count) - 1 + coeffs = [gamma_1 - ((r >> bit_count * i) & mask) for i in range(self.n)] + + return self(coeffs) + def __bit_unpack(self, input_bytes, n_bits): if (len(input_bytes) * n_bits) % 8 != 0: raise ValueError( @@ -53,10 +186,10 @@ def bit_unpack_s(self, input_bytes, eta): if eta == 2: altered_coeffs = self.__bit_unpack(input_bytes, 3) # Level 3 parameter set - elif eta == 4: - altered_coeffs = self.__bit_unpack(input_bytes, 4) else: - raise ValueError("Expected eta to be either 2 or 4") + assert eta == 4, f"Expected eta to be either 2 or 4, got {eta = }" + altered_coeffs = self.__bit_unpack(input_bytes, 4) + coefficients = [eta - c for c in altered_coeffs] return self(coefficients) @@ -65,10 +198,12 @@ def bit_unpack_w(self, input_bytes, gamma_2): if gamma_2 == 95232: coefficients = self.__bit_unpack(input_bytes, 6) # Level 3 and 5 parameter set - elif gamma_2 == 261888: - coefficients = self.__bit_unpack(input_bytes, 4) else: - raise ValueError("Expected gamma_2 to be either (q-1)/88 or (q-1)/32") + assert ( + gamma_2 == 261888 + ), f"Expected gamma_2 to be either (q-1)/88 or (q-1)/32, got {gamma_2 = }" + coefficients = self.__bit_unpack(input_bytes, 4) + return self(coefficients) def bit_unpack_z(self, input_bytes, gamma_1): @@ -76,10 +211,12 @@ def bit_unpack_z(self, input_bytes, gamma_1): if gamma_1 == (1 << 17): altered_coeffs = self.__bit_unpack(input_bytes, 18) # Level 3 and 5 parameter set - elif gamma_1 == (1 << 19): - altered_coeffs = self.__bit_unpack(input_bytes, 20) else: - raise ValueError("Expected gamma_1 to be either 2^17 or 2^19") + assert gamma_1 == ( + 1 << 19 + ), f"Expected gamma_1 to be either 2^17 or 2^19, got {gamma_1 = }" + altered_coeffs = self.__bit_unpack(input_bytes, 20) + coefficients = [gamma_1 - c for c in altered_coeffs] return self(coefficients) @@ -205,20 +342,18 @@ def bit_pack_s(self, eta): if eta == 2: return self.__bit_pack(altered_coeffs, 3, 96) # Level 3 parameter set - elif eta == 4: - return self.__bit_pack(altered_coeffs, 4, 128) - else: - raise ValueError("Expected eta to be either 2 or 4") + assert eta == 4, f"Expected eta to be either 2 or 4, got {eta = }" + return self.__bit_pack(altered_coeffs, 4, 128) def bit_pack_w(self, gamma_2): # Level 2 parameter set if gamma_2 == 95232: return self.__bit_pack(self.coeffs, 6, 192) # Level 3 and 5 parameter set - elif gamma_2 == 261888: - return self.__bit_pack(self.coeffs, 4, 128) - else: - raise ValueError("Expected gamma_2 to be either (q-1)/88 or (q-1)/32") + assert ( + gamma_2 == 261888 + ), f"Expected gamma_2 to be either (q-1)/88 or (q-1)/32, got {gamma_2 = }" + return self.__bit_pack(self.coeffs, 4, 128) def bit_pack_z(self, gamma_1): altered_coeffs = [self._sub_mod_q(gamma_1, c) for c in self.coeffs] @@ -226,10 +361,22 @@ def bit_pack_z(self, gamma_1): if gamma_1 == (1 << 17): return self.__bit_pack(altered_coeffs, 18, 576) # Level 3 and 5 parameter set - elif gamma_1 == (1 << 19): - return self.__bit_pack(altered_coeffs, 20, 640) - else: - raise ValueError("Expected gamma_1 to be either 2^17 or 2^19") + assert gamma_1 == ( + 1 << 19 + ), f"Expected gamma_1 to be either 2^17 or 2^19, got: {gamma_1 = }" + return self.__bit_pack(altered_coeffs, 20, 640) + + def make_hint(self, other, alpha): + coeffs = [ + make_hint(r, z, alpha, 8380417) for r, z in zip(self.coeffs, other.coeffs) + ] + return self.parent(coeffs) + + def use_hint(self, other, alpha): + coeffs = [ + use_hint(h, r, alpha, 8380417) for h, r in zip(self.coeffs, other.coeffs) + ] + return self.parent(coeffs) class PolynomialDilithiumNTT(PolynomialDilithium): diff --git a/src/dilithium_py/utilities/utils.py b/src/dilithium_py/utilities/utils.py index da4f8b9..bb3b6db 100644 --- a/src/dilithium_py/utilities/utils.py +++ b/src/dilithium_py/utilities/utils.py @@ -1,26 +1,22 @@ -from collections import deque - - -def reduce_mod_pm(n, a): +def reduce_mod_pm(x, n): """ - Takes an integer n and represents + Takes an integer 0 < x < n and represents it as an integer in the range - r = n % a + r = x % n - for a odd: - -(a-1)/2 < r <= (a-1)/2 - for a even: - - a / 2 < r <= a / 2 + for n odd: + -(n-1)/2 < r <= (n-1)/2 + for n even: + - n / 2 < r <= n / 2 """ - r = n % a - if r > (a >> 1): - r -= a + x = x % n + if x > (n >> 1): + x -= n + # assert x > -(n >> 1) + # assert x <= (n >> 1) - # assert r > -(a >> 1) - # assert r <= (a >> 1) - # assert (n % a) == (r % a) - return r + return x def decompose(r, a, q): @@ -34,13 +30,13 @@ def decompose(r, a, q): -(a << 1) < r0 <= (a << 1) """ - r = r % q - r0 = reduce_mod_pm(r, a) - r1 = r - r0 + rp = r % q + r0 = reduce_mod_pm(rp, a) + r1 = rp - r0 if r1 == q - 1: return 0, r0 - 1 r1 = r1 // a - assert r == r1 * a + r0 + return r1, r0 @@ -100,30 +96,6 @@ def check_norm_bound(n, b, q): return x >= b -def get_n_blocks(xof, n, blocks_read): - blocks_read += n - # extract last n blocks - total_bytes = 136 * blocks_read - xof_bytes = xof.digest(total_bytes)[-136 * n :] - # We use `deque` because it has a fast .popleft() - return deque(xof_bytes), blocks_read - - -def get_mask_integers(bit_count, xof, n, blocks_read): - blocks_read += n * bit_count - # extract last n*bit_count blocks - total_bytes = 136 * blocks_read - xof_bytes = xof.digest(total_bytes)[-136 * n * bit_count :] - - r = int.from_bytes(xof_bytes, "little") - mask = (1 << bit_count) - 1 - mask_integers = [] - for _ in range(256): - mask_integers.append(r & mask) - r >>= bit_count - return deque(mask_integers), blocks_read - - def xor_bytes(a, b): """ XOR two byte arrays, assume that they are