Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sm2: Fix heap allocation #1099

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sm2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features =
primeorder = { version = "=0.14.0-pre.2", optional = true, path = "../primeorder" }
rfc6979 = { version = "=0.5.0-pre.4", optional = true }
serdect = { version = "0.3.0-rc.0", optional = true, default-features = false }
signature = { version = "=2.3.0-pre.4", optional = true, features = ["rand_core"] }
signature = { version = "=2.3.0-pre.4", optional = true, features = ["rand_core", "digest"] }
sm3 = { version = "=0.5.0-pre.4", optional = true, default-features = false }

[dev-dependencies]
Expand Down
29 changes: 14 additions & 15 deletions sm2/src/pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,19 @@

use core::cmp::min;

use crate::AffinePoint;

#[cfg(feature = "alloc")]
use alloc::vec;

use elliptic_curve::{
bigint::{Encoding, Uint, U256},
pkcs8::der::{
asn1::UintRef, Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer,
},
};

use elliptic_curve::{
pkcs8::der::{asn1::OctetStringRef, EncodeValue},
sec1::ToEncodedPoint,
Result,
sec1::{ModulusSize, ToEncodedPoint},
CurveArithmetic, FieldBytesSize, Result,
};
use sm3::digest::DynDigest;
use primeorder::{AffinePoint, PrimeCurveParams};
use signature::digest::{FixedOutputReset, Output, Update};

#[cfg(feature = "arithmetic")]
mod decrypting;
Expand Down Expand Up @@ -131,22 +126,26 @@ impl<'a> DecodeValue<'a> for Cipher<'a> {
}

/// Performs key derivation using a hash function and elliptic curve point.
fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> {
fn kdf<D, C>(hasher: &mut D, kpb: AffinePoint<C>, c2: &mut [u8]) -> Result<()>
where
D: Update + FixedOutputReset,
C: CurveArithmetic + PrimeCurveParams,
FieldBytesSize<C>: ModulusSize,
AffinePoint<C>: ToEncodedPoint<C>,
{
let klen = c2.len();
let mut ct: i32 = 0x00000001;
let mut offset = 0;
let digest_size = hasher.output_size();
let mut ha = vec![0u8; digest_size];
let digest_size = D::output_size();
let mut ha = Output::<D>::default();
let encode_point = kpb.to_encoded_point(false);

while offset < klen {
hasher.update(encode_point.x().ok_or(elliptic_curve::Error)?);
hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?);
hasher.update(&ct.to_be_bytes());

hasher
.finalize_into_reset(&mut ha)
.map_err(|_e| elliptic_curve::Error)?;
hasher.finalize_into_reset(&mut ha);

let xor_len = min(digest_size, klen - offset);
xor(c2, &ha, offset, xor_len);
Expand Down
25 changes: 12 additions & 13 deletions sm2/src/pke/decrypting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ use elliptic_curve::{
};
use primeorder::PrimeField;

use sm3::{digest::DynDigest, Digest, Sm3};
use signature::digest::{Digest, FixedOutputReset, Output, OutputSizeUser, Update};
use sm3::Sm3;

use super::{encrypting::EncryptingKey, kdf, vec, Cipher, Mode};
use super::{encrypting::EncryptingKey, kdf, Cipher, Mode};
/// Represents a decryption key used for decrypting messages using elliptic curve cryptography.
#[derive(Clone)]
pub struct DecryptingKey {
Expand Down Expand Up @@ -91,7 +92,7 @@ impl DecryptingKey {
/// Decrypts a ciphertext in-place using the specified digest algorithm.
pub fn decrypt_digest<D>(&self, ciphertext: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Digest + OutputSizeUser + Update + FixedOutputReset,
{
let mut digest = D::new();
decrypt(&self.secret_scalar, self.mode, &mut digest, ciphertext)
Expand All @@ -105,7 +106,7 @@ impl DecryptingKey {
/// Decrypts a ciphertext in-place from ASN.1 format using the specified digest algorithm.
pub fn decrypt_der_digest<D>(&self, ciphertext: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Digest + OutputSizeUser + Update + FixedOutputReset,
{
let cipher = Cipher::from_der(ciphertext).map_err(elliptic_curve::pkcs8::Error::from)?;
let prefix: &[u8] = &[0x04];
Expand Down Expand Up @@ -153,12 +154,10 @@ impl PartialEq for DecryptingKey {
}
}

fn decrypt(
secret_scalar: &Scalar,
mode: Mode,
hasher: &mut dyn DynDigest,
cipher: &[u8],
) -> Result<Vec<u8>> {
fn decrypt<D>(secret_scalar: &Scalar, mode: Mode, hasher: &mut D, cipher: &[u8]) -> Result<Vec<u8>>
where
D: Update + OutputSizeUser + FixedOutputReset,
{
let q = U256::from_be_hex(FieldElement::MODULUS);
let c1_len = (q.bits() + 7) / 8 * 2 + 1;

Expand All @@ -177,7 +176,7 @@ fn decrypt(

// B3: compute [𝑑𝐵]𝐶1 = (𝑥2, 𝑦2)
c1_point = (c1_point * secret_scalar).to_affine();
let digest_size = hasher.output_size();
let digest_size = D::output_size();
let (c2, c3) = match mode {
Mode::C1C3C2 => {
let (c3, c2) = c.split_at(digest_size);
Expand All @@ -192,12 +191,12 @@ fn decrypt(
kdf(hasher, c1_point, &mut c2)?;

// compute 𝑢 = 𝐻𝑎𝑠ℎ(𝑥2 ∥ 𝑀′∥ 𝑦2).
let mut u = vec![0u8; digest_size];
let mut u = Output::<D>::default();
let encode_point = c1_point.to_encoded_point(false);
hasher.update(encode_point.x().ok_or(Error)?);
hasher.update(&c2);
hasher.update(encode_point.y().ok_or(Error)?);
hasher.finalize_into_reset(&mut u).map_err(|_e| Error)?;
hasher.finalize_into_reset(&mut u);
let checked = u
.iter()
.zip(c3)
Expand Down
40 changes: 17 additions & 23 deletions sm2/src/pke/encrypting.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use core::fmt::Debug;

use crate::{
arithmetic::field::FieldElement,
pke::{kdf, vec},
AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2,
arithmetic::field::FieldElement, pke::kdf, AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2,
};

#[cfg(feature = "alloc")]
Expand All @@ -18,12 +16,10 @@ use elliptic_curve::{
};

use primeorder::PrimeField;
use sm3::{
digest::{Digest, DynDigest},
Sm3,
};
use sm3::Sm3;

use super::{Cipher, Mode};
use signature::digest::{Digest, FixedOutputReset, Output, OutputSizeUser, Update};
/// Represents an encryption key used for encrypting messages using elliptic curve cryptography.
#[derive(Clone, Debug)]
pub struct EncryptingKey {
Expand Down Expand Up @@ -91,7 +87,7 @@ impl EncryptingKey {
/// Encrypts a message using a specified digest algorithm.
pub fn encrypt_digest<D>(&self, msg: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Digest + Update + FixedOutputReset,
{
let mut digest = D::new();
encrypt(&self.public_key, self.mode, &mut digest, msg)
Expand All @@ -100,11 +96,11 @@ impl EncryptingKey {
/// Encrypts a message using a specified digest algorithm and returns the result in ASN.1 format.
pub fn encrypt_der_digest<D>(&self, msg: &[u8]) -> Result<Vec<u8>>
where
D: 'static + Digest + DynDigest + Send + Sync,
D: Update + OutputSizeUser + Digest + FixedOutputReset,
{
let mut digest = D::new();
let cipher = encrypt(&self.public_key, self.mode, &mut digest, msg)?;
let digest_size = digest.output_size();
let digest_size = <D as OutputSizeUser>::output_size();
let (_, cipher) = cipher.split_at(1);
let (x, cipher) = cipher.split_at(32);
let (y, cipher) = cipher.split_at(32);
Expand Down Expand Up @@ -133,14 +129,13 @@ impl From<PublicKey> for EncryptingKey {
}

/// Encrypts a message using the specified public key, mode, and digest algorithm.
fn encrypt(
public_key: &PublicKey,
mode: Mode,
digest: &mut dyn DynDigest,
msg: &[u8],
) -> Result<Vec<u8>> {
fn encrypt<D>(public_key: &PublicKey, mode: Mode, digest: &mut D, msg: &[u8]) -> Result<Vec<u8>>
where
D: Update + FixedOutputReset,
{
const N_BYTES: u32 = (Sm2::ORDER.bits() + 7) / 8;
let mut c1 = vec![0; (N_BYTES * 2 + 1) as usize];
#[allow(unused_assignments)]
let mut c1 = Default::default();
let mut c2 = msg.to_owned();
let mut hpb: AffinePoint;
loop {
Expand All @@ -167,24 +162,23 @@ fn encrypt(
// // If 𝑡 is an all-zero bit string, go to A1.
// if all of t are 0, xor(c2) == c2
if c2.iter().zip(msg).any(|(pre, cur)| pre != cur) {
let uncompress_kg = kg.to_encoded_point(false);
c1.copy_from_slice(uncompress_kg.as_bytes());
c1 = kg.to_encoded_point(false);
break;
}
}
let encode_point = hpb.to_encoded_point(false);

// A7: compute 𝐶3 = 𝐻𝑎𝑠ℎ(𝑥2||𝑀||𝑦2)
let mut c3 = vec![0; digest.output_size()];
let mut c3 = Output::<D>::default();
digest.update(encode_point.x().ok_or(Error)?);
digest.update(msg);
digest.update(encode_point.y().ok_or(Error)?);
digest.finalize_into_reset(&mut c3).map_err(|_e| Error)?;
digest.finalize_into_reset(&mut c3);

// A8: output the ciphertext 𝐶 = 𝐶1||𝐶2||𝐶3.
Ok(match mode {
Mode::C1C2C3 => [c1.as_slice(), &c2, &c3].concat(),
Mode::C1C3C2 => [c1.as_slice(), &c3, &c2].concat(),
Mode::C1C2C3 => [c1.as_bytes(), &c2, &c3].concat(),
Mode::C1C3C2 => [c1.as_bytes(), &c3, &c2].concat(),
})
}

Expand Down