Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Critical Security Vulnerability Caused by Circom Operator Misuse #7

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions circuits/ctr.circom
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,33 @@ pragma circom 2.1.9;

include "cipher.circom";
include "transformations.circom";
include "circomlib/circuits/comparators.circom";
include "circomlib/circuits/bitify.circom";

template EncryptCTR(l,nk){
signal input plainText[l];
signal input iv[16];
signal input key[nk * 4];
signal output cipher[l];

component checkPlainText[l];
for (var i = 0; i < l; i++) {
checkPlainText[i] = Num2Bits(8);
checkPlainText[i].in <== plainText[i];
}

component checkIv[16];
for (var i = 0; i < 16; i++) {
checkIv[i] = Num2Bits(8);
checkIv[i].in <== iv[i];
}

component checkKey[nk * 4];
for (var i = 0; i < nk * 4; i++) {
checkKey[i] = Num2Bits(8);
checkKey[i].in <== key[i];
}

var n = l\16;
if(l%16 > 0){
n = n + 1;
Expand Down Expand Up @@ -114,31 +134,52 @@ template AddCipher(){
}
}

template ByteInc() {
signal input in;
signal input control;
signal output out;
signal output carry;

signal added;
added <== in + control;

signal addedDiff;
addedDiff <== added - 256;
carry <== IsZero()(addedDiff);

out <== added - carry * 256;
}

// converts iv to counter blocks
// iv is 16 bytes
template GenerateCounterBlocks(n){
assert(n < 0xffffffff);
signal input iv[16];
signal blockNonce[n][16];
signal output counterBlocks[n][4][4];

var ivr[16] = iv;

component toBlocks[n];

component ivByteInc[n-1][16];


toBlocks[0] = ToBlocks(16);
toBlocks[0].stream <== iv;
counterBlocks[0] <== toBlocks[0].blocks[0];

for (var i = 1; i < n; i++) {
for (var j = 15; j >= 0; j--) {
ivByteInc[i-1][j] = ByteInc();
ivByteInc[i-1][j].in <== toBlocks[i-1].stream[j];
if (j==15) {
ivByteInc[i-1][j].control <== 1;
} else {
ivByteInc[i-1][j].control <== ivByteInc[i-1][j+1].carry;
}
blockNonce[i][j] <== ivByteInc[i-1][j].out;
}

for (var i = 0; i < n; i++) {
toBlocks[i] = ToBlocks(16);
toBlocks[i].stream <-- ivr;
toBlocks[i].stream <== blockNonce[i];
counterBlocks[i] <== toBlocks[i].blocks[0];
ivr[15] = (ivr[15] + 1)%256;
if (ivr[15] == 0){
ivr[14] = (ivr[14] + 1)%256;
if (ivr[14] == 0){
ivr[13] = (ivr[13] + 1)%256;
if (ivr[13] == 0){
ivr[12] = (ivr[12] + 1)%256;
}
}
}

}
}
169 changes: 169 additions & 0 deletions circuits/finite_field.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
pragma circom 2.1.9;

include "circomlib/circuits/bitify.circom";

// Finite field addition, the signal variable plus a compile-time constant
template FieldAddConst(c) {
signal input in[8];
// control bit, if 0, then do not perform addition
signal input control;
signal output out[8];

for (var i=0; i<8; i++) {
if(c & (1<<i) != 0) {
// XOR operation
out[i] <== in[i] + control - 2 * in[i] * control;
} else {
out[i] <== in[i];
}
}
}

// Finite field multiplication by 2 operation for AES. This involves left-shifting 'input' by 1 (input << 1),
// and then XORing with 0x1B if the most significate bit is 1. This is because the irreducible polynomial
// for AES's finite field (GF(2^8)) is x^8 + x^4 + x^3 + x + 1.
template FieldMul2() {
signal input in;
signal output out;

signal inBits[8];
inBits <== Num2Bits(8)(in);

component reduce = FieldAddConst(0x1b);
reduce.in[0] <== 0;
for (var i = 1; i < 8; i++) {
reduce.in[i] <== inBits[i-1];
}
reduce.control <== inBits[7];
out <== Bits2Num(8)(reduce.out);
}

// Finite field multiplication by 3 operation for AES. This involves (input << 1) ⊕ input and then XORing
// with 0x1B if the most significate bit is 1.
template FieldMul3() {
signal input in;
signal output out;

signal inBits[8] <== Num2Bits(8)(in);

component reduce = FieldAddConst(0x1b);
reduce.in[0] <== inBits[0];
for (var i = 1; i < 8; i++) {
reduce.in[i] <== inBits[i-1] + inBits[i] - 2 * inBits[i-1] * inBits[i];
}
reduce.control <== inBits[7];
out <== Bits2Num(8)(reduce.out);
}

// Determine the parity (odd or even) of an integer that can be accommodated within 'nBits' bits.
template IsOdd(nBits) {
signal input in;
signal output out;
if (nBits == 1) {
out <== in;
} else {
signal bits[nBits] <== Num2Bits(nBits)(in);
out <== bits[0];
}
}

// Finite field multiplication.
template FieldMul() {
signal input a;
signal input b;
signal inBits[2][8];
signal output out;

inBits[0] <== Num2Bits(8)(a);
inBits[1] <== Num2Bits(8)(b);

// List of finite field elements obtained by successively doubling, starting from 1.
var power[15] = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a];

signal mulMatrix[8][8];
var outLinesLc[8];
for (var i = 0; i < 8; i++) {
outLinesLc[i] = 0;
}
// Apply elementary multiplication
for (var i = 0; i < 8; i++) {
for (var j = 0; j < 8; j++) {
mulMatrix[i][j] <== inBits[0][i] * inBits[1][j];
for (var t = 0; t < 8; t++) {
if (power[i+j] & (1 << t) != 0) {
outLinesLc[t] += mulMatrix[i][j];
}
}
}
}
signal outBitsUnreduced[8];
signal outBits[8];
for (var i = 0; i < 8; i++) {
outBitsUnreduced[i] <== outLinesLc[i];
// Each element in 'outLinesLc' is incremented by a known constant number of
// elements from 'mulMatrix', less than 31.
outBits[i] <== IsOdd(6)(outBitsUnreduced[i]);
}

out <== Bits2Num(8)(outBits);
}

// Finite Field Inversion. Specially, if the input is 0, the output is also 0.
template FieldInv() {
signal input in;
signal output out;

var inv[256] = [0x00, 0x01, 0x8d, 0xf6, 0xcb, 0x52, 0x7b, 0xd1, 0xe8, 0x4f, 0x29, 0xc0, 0xb0, 0xe1, 0xe5, 0xc7,
0x74, 0xb4, 0xaa, 0x4b, 0x99, 0x2b, 0x60, 0x5f, 0x58, 0x3f, 0xfd, 0xcc, 0xff, 0x40, 0xee, 0xb2,
0x3a, 0x6e, 0x5a, 0xf1, 0x55, 0x4d, 0xa8, 0xc9, 0xc1, 0x0a, 0x98, 0x15, 0x30, 0x44, 0xa2, 0xc2,
0x2c, 0x45, 0x92, 0x6c, 0xf3, 0x39, 0x66, 0x42, 0xf2, 0x35, 0x20, 0x6f, 0x77, 0xbb, 0x59, 0x19,
0x1d, 0xfe, 0x37, 0x67, 0x2d, 0x31, 0xf5, 0x69, 0xa7, 0x64, 0xab, 0x13, 0x54, 0x25, 0xe9, 0x09,
0xed, 0x5c, 0x05, 0xca, 0x4c, 0x24, 0x87, 0xbf, 0x18, 0x3e, 0x22, 0xf0, 0x51, 0xec, 0x61, 0x17,
0x16, 0x5e, 0xaf, 0xd3, 0x49, 0xa6, 0x36, 0x43, 0xf4, 0x47, 0x91, 0xdf, 0x33, 0x93, 0x21, 0x3b,
0x79, 0xb7, 0x97, 0x85, 0x10, 0xb5, 0xba, 0x3c, 0xb6, 0x70, 0xd0, 0x06, 0xa1, 0xfa, 0x81, 0x82,
0x83, 0x7e, 0x7f, 0x80, 0x96, 0x73, 0xbe, 0x56, 0x9b, 0x9e, 0x95, 0xd9, 0xf7, 0x02, 0xb9, 0xa4,
0xde, 0x6a, 0x32, 0x6d, 0xd8, 0x8a, 0x84, 0x72, 0x2a, 0x14, 0x9f, 0x88, 0xf9, 0xdc, 0x89, 0x9a,
0xfb, 0x7c, 0x2e, 0xc3, 0x8f, 0xb8, 0x65, 0x48, 0x26, 0xc8, 0x12, 0x4a, 0xce, 0xe7, 0xd2, 0x62,
0x0c, 0xe0, 0x1f, 0xef, 0x11, 0x75, 0x78, 0x71, 0xa5, 0x8e, 0x76, 0x3d, 0xbd, 0xbc, 0x86, 0x57,
0x0b, 0x28, 0x2f, 0xa3, 0xda, 0xd4, 0xe4, 0x0f, 0xa9, 0x27, 0x53, 0x04, 0x1b, 0xfc, 0xac, 0xe6,
0x7a, 0x07, 0xae, 0x63, 0xc5, 0xdb, 0xe2, 0xea, 0x94, 0x8b, 0xc4, 0xd5, 0x9d, 0xf8, 0x90, 0x6b,
0xb1, 0x0d, 0xd6, 0xeb, 0xc6, 0x0e, 0xcf, 0xad, 0x08, 0x4e, 0xd7, 0xe3, 0x5d, 0x50, 0x1e, 0xb3,
0x5b, 0x23, 0x38, 0x34, 0x68, 0x46, 0x03, 0x8c, 0xdd, 0x9c, 0x7d, 0xa0, 0xcd, 0x1a, 0x41, 0x1c];

// Obtain an unchecked result from a lookup table
out <-- inv[in];
// Compute the product of the input and output, expected to be 1
signal checkRes <== FieldMul()(in, out);
// For the special case when the input is 0, both input and output should be 0
signal isZeroIn <== IsZero()(in);
signal isZeroOut <== IsZero()(out);
signal checkZero <== isZeroIn * isZeroOut;
// Ensure that either the product is 1 or both input and output are 0, satisfying at least one condition
(1 - checkRes) * (1 - checkZero) === 0;
}

// AffineTransform required by the S-box computation.
template AffineTransform() {
signal input inBits[8];
signal output outBits[8];

var matrix[8][8] = [[1, 0, 0, 0, 1, 1, 1, 1],
[1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1],
[1, 1, 1, 1, 1, 0, 0, 0],
[0, 1, 1, 1, 1, 1, 0, 0],
[0, 0, 1, 1, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 1, 1]];
var offset[8] = [1, 1, 0, 0, 0, 1, 1, 0];
for (var i = 0; i < 8; i++) {
var lc = 0;
for (var j = 0; j < 8; j++) {
if (matrix[i][j] == 1) {
lc += inBits[j];
}
}
lc += offset[i];
outBits[i] <== IsOdd(3)(lc);
}
}
10 changes: 3 additions & 7 deletions circuits/key_expansion.circom
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,14 @@ template KeyExpansion(nk,nr) {

for (var round = 1; round <= effectiveRounds; round++) {
var outputWordLen = round == effectiveRounds ? 4 : nk;
nextRound[round - 1] = NextRound(nk, outputWordLen);
nextRound[round - 1] = NextRound(nk, outputWordLen, round);

for (var i = 0; i < nk; i++) {
for (var j = 0; j < 4; j++) {
nextRound[round - 1].key[i][j] <== keyExpanded[(round * nk) + i - nk][j];
}
}

nextRound[round - 1].round <== round;

for (var i = 0; i < outputWordLen; i++) {
for (var j = 0; j < 4; j++) {
keyExpanded[(round * nk) + i][j] <== nextRound[round - 1].nextKey[i][j];
Expand All @@ -79,9 +77,8 @@ template KeyExpansion(nk,nr) {

// @param nk: number of keys which can be 4, 6, 8
// @param o: number of output words which can be 4 or nk
template NextRound(nk, o){
template NextRound(nk, o, round){
signal input key[nk][4];
signal input round;
signal output nextKey[o][4];

component rotateWord = Rotate(1, 4);
Expand All @@ -93,8 +90,7 @@ template NextRound(nk, o){
substituteWord[0] = SubstituteWord();
substituteWord[0].bytes <== rotateWord.rotated;

component rcon = RCon();
rcon.round <== round;
component rcon = RCon(round);

component xorWord[o + 1];
xorWord[0] = XorWord();
Expand Down
18 changes: 9 additions & 9 deletions circuits/mix_columns.circom
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ template S0(){
}

num2bits[0] = Num2Bits(8);
num2bits[0].in <-- TBox(0, in[0]);
num2bits[0].in <== TBox(0)(in[0]);

num2bits[1] = Num2Bits(8);
num2bits[1].in <-- TBox(1, in[1]);
num2bits[1].in <== TBox(1)(in[1]);

xor[0] = XorBits();
xor[0].a <== num2bits[0].out;
Expand Down Expand Up @@ -92,10 +92,10 @@ template S1(){
num2bits[0].in <== in[0];

num2bits[1] = Num2Bits(8);
num2bits[1].in <-- TBox(0, in[1]);
num2bits[1].in <== TBox(0)(in[1]);

num2bits[2] = Num2Bits(8);
num2bits[2].in <-- TBox(1, in[2]);
num2bits[2].in <== TBox(1)(in[2]);

num2bits[3] = Num2Bits(8);
num2bits[3].in <== in[3];
Expand Down Expand Up @@ -134,10 +134,10 @@ template S2() {
}

num2bits[2] = Num2Bits(8);
num2bits[2].in <-- TBox(0, in[2]);
num2bits[2].in <== TBox(0)(in[2]);

num2bits[3] = Num2Bits(8);
num2bits[3].in <-- TBox(1, in[3]);
num2bits[3].in <== TBox(1)(in[3]);

xor[0] = XorBits();
xor[0].a <== num2bits[0].out;
Expand Down Expand Up @@ -173,10 +173,10 @@ template S3() {
}

num2bits[0] = Num2Bits(8);
num2bits[0].in <-- TBox(1, in[0]);
num2bits[0].in <== TBox(1)(in[0]);

num2bits[3] = Num2Bits(8);
num2bits[3].in <-- TBox(0, in[3]);
num2bits[3].in <== TBox(0)(in[3]);

xor[0] = XorBits();
xor[0].a <== num2bits[0].out;
Expand All @@ -187,7 +187,7 @@ template S3() {
xor[1].b <== num2bits[2].out;

xor[2] = XorBits();
xor[2].a <-- num2bits[3].out;
xor[2].a <== num2bits[3].out;
xor[2].b <== xor[1].out;

component b2n = Bits2Num(8);
Expand Down
Loading
Loading