Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for ML-KEM keys #2009

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS})
add_executable(vectors_kem vectors_kem.c)
target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS})

if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS)
# workaround for Windows .dll
if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING)
target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition)
else()
target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE")
endif()
endif()

# Enable Valgrind-based timing side-channel analysis for test_kem and test_sig
if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD)
message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.")
Expand Down
114 changes: 108 additions & 6 deletions tests/vectors_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,34 @@
#include <sys/stat.h>

#include <oqs/oqs.h>

#include <oqs/sha3.h>
#include "system_info.c"

#ifdef OQS_ENABLE_KEM_ML_KEM
/* macros for sanity checks for encaps and decaps key */
#define ML_KEM_BLOCKSIZE 384
#define ML_KEM_K_MAX 4
#define ML_KEM_N 256
#define ML_KEM_1024_PK_SIZE 1568
#define ML_KEM_Q 3329
#define SHA256_OP_LEN 32
/* since x is 12 bits, max value could be 4095. the below macro uses this to implement a simple time constant mod 3329 */
#define MOD_Q(x) ((x) - ((x >= ML_KEM_Q) * ML_KEM_Q))
#endif //OQS_ENABLE_KEM_ML_KEM

struct {
const uint8_t *pos;
} prng_state = {
.pos = 0
};

/* MLKEM-specific functions */
static inline bool is_ml_kem(const char *method_name) {
return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
}

static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) {
size_t i;
fprintf(fp, "%s", S);
Expand Down Expand Up @@ -58,13 +77,75 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) {
}
}

/* ML_KEM-specific functions */
static inline bool is_ml_kem(const char *method_name) {
return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
#ifdef OQS_ENABLE_KEM_ML_KEM
static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) {
/* sanity checks */
if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name);
return false;
}
/* buffer to hold public key hash */
uint8_t pkdig[SHA256_OP_LEN] = {0};
/* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4;
/* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */
OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32);
/* compare it with public key hash stored at 768k+32 offset */
if (0 != memcmp(pkdig, sk + (ML_KEM_BLOCKSIZE * K * 2) + 32, SHA256_OP_LEN)) {
return false;
}
return true;
}

static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) {
/* sanity checks */
if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name);
return false;
}
unsigned int i, j;
/* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4;
/* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions
encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */
uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0};
/* buffer to hold encoded value */
uint8_t buffe[ML_KEM_1024_PK_SIZE] = {0};
uint16_t *buff_dec;
/* perform byte decoding as per Algo 6 of FIPS 203 */
for (i = 0; i < K; i++) {
buff_dec = &buffd[i * ML_KEM_N];
const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE];
for (j = 0; j < ML_KEM_N / 2; j++) {
buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF;
buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]);
buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF;
buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]);
}
}
/* perform byte encoding as per Algo 5 of FIPS 203 */
for (i = 0; i < K; i++) {
uint16_t t0, t1;
buff_dec = &buffd[i * ML_KEM_N];
uint8_t *buff_enc = &buffe[i * ML_KEM_BLOCKSIZE];
for (j = 0; j < ML_KEM_N / 2; j++) {
t0 = buff_dec[2 * j];
t1 = buff_dec[2 * j + 1];
buff_enc[3 * j + 0] = (uint8_t)(t0 >> 0);
buff_enc[3 * j + 1] = (uint8_t)((t0 >> 8) | (t1 << 4));
buff_enc[3 * j + 2] = (uint8_t)(t1 >> 4);
}
}
/* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */
if (0 != memcmp(buffe, pk, pkLen - 32)) {
return false;
}
return true;
}
#endif //OQS_ENABLE_KEM_ML_KEM

static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) {
(void) personalization_string;
prng_state.pos = entropy_input;
Expand Down Expand Up @@ -134,6 +215,13 @@ static OQS_STATUS kem_kg_vector(const char *method_name,
fprintBstr(fh, "ek: ", public_key, kem->length_public_key);
fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key);

#ifdef OQS_ENABLE_KEM_ML_KEM
if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) {
ret = OQS_SUCCESS;
} else {
Expand Down Expand Up @@ -208,6 +296,13 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name,
goto err;
}

#ifdef OQS_ENABLE_KEM_ML_KEM
if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) {
fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
Expand Down Expand Up @@ -273,6 +368,13 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name,
goto err;
}

#ifdef OQS_ENABLE_KEM_ML_KEM
if (false == sanityCheckSK(encdec_sk, method_name)) {
fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
Expand Down
Loading