From f3580aff85a99c27bcd31b914c1f981aaba4a132 Mon Sep 17 00:00:00 2001 From: Paul Miller Date: Tue, 24 Sep 2024 19:59:17 +0000 Subject: [PATCH] ML-DSA: expose internal and non-internal apis --- src/ml-dsa.ts | 47 +++++++++++++++++++++++++++++++++++++++++------ src/utils.ts | 3 ++- test/avcp.test.js | 8 ++++++-- 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/ml-dsa.ts b/src/ml-dsa.ts index d43e73e..2cc84cf 100644 --- a/src/ml-dsa.ts +++ b/src/ml-dsa.ts @@ -10,6 +10,7 @@ import { randomBytes, splitCoder, vecCoder, + concatBytes, } from './utils.js'; /* @@ -65,10 +66,13 @@ const { mod, smod, NTT, bitsCoder } = genCrystals({ brvBits: 8, }); -const polyCoder = (d: number, compress?: (n: number) => number) => +const id = (n: T): T => n; +type IdNum = (n: number) => number; + +const polyCoder = (d: number, compress: IdNum = id, verify: IdNum = id) => bitsCoder(d, { - encode: (i: number) => (compress ? compress(i) : i), - decode: (i: number) => (compress ? compress(i) : i), + encode: (i: number) => compress(verify(i)), + decode: (i: number) => verify(compress(i)), }); const polyAdd = (a: Poly, b: Poly) => { @@ -117,6 +121,8 @@ function RejNTTPoly(xof: XofGet) { return r; } +const EMPTY = new Uint8Array(0); + type DilithiumOpts = { K: number; L: number; @@ -132,7 +138,7 @@ type DilithiumOpts = { XOF256: XOF; }; -function getDilithium(opts: DilithiumOpts): Signer { +function getDilithium(opts: DilithiumOpts) { const { K, L, GAMMA1, GAMMA2, TAU, ETA, OMEGA } = opts; const { CRH_BYTES, TR_BYTES, C_TILDE_BYTES, XOF128, XOF256 } = opts; @@ -214,7 +220,15 @@ function getDilithium(opts: DilithiumOpts): Signer { }, }; - const ETACoder = polyCoder(ETA === 2 ? 3 : 4, (i: number) => ETA - i); + const ETACoder = polyCoder( + ETA === 2 ? 3 : 4, + (i: number) => ETA - i, + (i: number) => { + if (!(-ETA <= i && i <= ETA)) + throw new Error(`malformed key s1/s3 ${i} outside of ETA range [${-ETA}, ${ETA}]`); + return i; + } + ); const T0Coder = polyCoder(13, (i: number) => (1 << (D - 1)) - i); const T1Coder = polyCoder(10); // Requires smod. Need to fix! @@ -307,7 +321,7 @@ function getDilithium(opts: DilithiumOpts): Signer { const signRandBytes = 32; const seedCoder = splitCoder(32, 64, 32); // API & argument positions are exactly as in FIPS204. - return { + const internal: Signer = { signRandBytes, keygen: (seed = randomBytes(32)) => { // H(𝜉||IntegerToBytes(𝑘, 1)||IntegerToBytes(ℓ, 1), 128) 2: ▷ expand seed @@ -486,6 +500,27 @@ function getDilithium(opts: DilithiumOpts): Signer { return equalBytes(cTilde, c2); }, }; + const getMessage = (msg: Uint8Array, ctx = EMPTY) => { + ensureBytes(msg); + ensureBytes(ctx); + if (ctx.length > 255) throw new Error('context should be less than 255 bytes'); + return concatBytes(new Uint8Array([0, ctx.length]), ctx, msg); + }; + // TODO: no hash-dsa vectors for now, so we don't implement it yet + return { + internal, + keygen: internal.keygen, + signRandBytes: internal.signRandBytes, + sign: (secretKey: Uint8Array, msg: Uint8Array, ctx = EMPTY, random?: Uint8Array) => { + const M = getMessage(msg, ctx); + const res = internal.sign(secretKey, M, random); + M.fill(0); + return res; + }, + verify: (publicKey: Uint8Array, msg: Uint8Array, sig: Uint8Array, ctx = EMPTY) => { + return internal.verify(publicKey, getMessage(msg, ctx), sig); + }, + }; } // ML-DSA diff --git a/src/utils.ts b/src/utils.ts index 1349ae3..e1840a5 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,9 +1,10 @@ /*! noble-post-quantum - MIT License (c) 2024 Paul Miller (paulmillr.com) */ import { bytes as abytes } from '@noble/hashes/_assert'; -import { TypedArray, randomBytes as randb } from '@noble/hashes/utils'; +import { TypedArray, randomBytes as randb, concatBytes } from '@noble/hashes/utils'; export const ensureBytes = abytes; export const randomBytes = randb; +export { concatBytes }; // Compares 2 u8a-s in kinda constant time export function equalBytes(a: Uint8Array, b: Uint8Array) { diff --git a/test/avcp.test.js b/test/avcp.test.js index b6ed52c..70e39c8 100644 --- a/test/avcp.test.js +++ b/test/avcp.test.js @@ -105,7 +105,7 @@ describe('AVCP', () => { const mldsa = NAMES[g.info.p.parameterSet]; for (const t of g.tests) { const rnd = t.p.rnd ? hexToBytes(t.p.rnd) : undefined; - const sig = mldsa.sign(hexToBytes(t.p.sk), hexToBytes(t.p.message), rnd); + const sig = mldsa.internal.sign(hexToBytes(t.p.sk), hexToBytes(t.p.message), rnd); deepStrictEqual(sig, hexToBytes(t.er.signature)); } } @@ -115,7 +115,11 @@ describe('AVCP', () => { const mldsa = NAMES[g.info.p.parameterSet]; const pk = hexToBytes(g.info.p.pk); for (const t of g.tests) { - const valid = mldsa.verify(pk, hexToBytes(t.p.message), hexToBytes(t.p.signature)); + const valid = mldsa.internal.verify( + pk, + hexToBytes(t.p.message), + hexToBytes(t.p.signature) + ); deepStrictEqual(valid, t.er.testPassed); } }