diff --git a/lib/cairo.ex b/lib/cairo.ex index c6fb908..ba86fe5 100644 --- a/lib/cairo.ex +++ b/lib/cairo.ex @@ -88,7 +88,7 @@ defmodule Cairo do to: Cairo.CairoProver, as: :program_hash - @spec felt_to_string(list(byte())) :: binary() + @spec felt_to_string(list(byte())) :: binary() | {:error, term()} defdelegate felt_to_string(felt), to: Cairo.CairoProver, as: :cairo_felt_to_string diff --git a/lib/cairo/cairo.ex b/lib/cairo/cairo.ex index 002bc55..0810bb8 100644 --- a/lib/cairo/cairo.ex +++ b/lib/cairo/cairo.ex @@ -51,6 +51,7 @@ defmodule Cairo.CairoProver do @spec program_hash(list(byte())) :: nif_result(list(byte())) def program_hash(_public_inputs), do: error() + @spec cairo_felt_to_string(list(byte())) :: nif_result(binary()) def cairo_felt_to_string(_felt), do: error() def cairo_generate_compliance_input_json( diff --git a/native/cairo_prover/Cargo.lock b/native/cairo_prover/Cargo.lock index 8cd1ef7..d8c4825 100644 --- a/native/cairo_prover/Cargo.lock +++ b/native/cairo_prover/Cargo.lock @@ -189,7 +189,7 @@ dependencies = [ [[package]] name = "cairo-platinum-prover" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "bincode 2.0.0-rc.2", "cairo-vm", @@ -254,6 +254,7 @@ dependencies = [ "starknet-crypto 0.7.1", "starknet-curve 0.5.0", "starknet-types-core", + "thiserror", ] [[package]] @@ -485,7 +486,7 @@ dependencies = [ [[package]] name = "lambdaworks-crypto" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "lambdaworks-math 0.9.0", "serde", @@ -506,7 +507,7 @@ dependencies = [ [[package]] name = "lambdaworks-math" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "rayon", "serde", @@ -858,7 +859,7 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "stark-platinum-prover" version = "0.9.0" -source = "git+https://github.com/lambdaclass/lambdaworks#c4fa1f21b98a56825c76b2c38108e3a7f79b3995" +source = "git+https://github.com/heliaxdev/lambdaworks?branch=cairo_rm#ea328f5ca24448c0e2d7816a76b86c81eadb2d9f" dependencies = [ "bincode 2.0.0-rc.2", "itertools 0.11.0", diff --git a/native/cairo_prover/Cargo.toml b/native/cairo_prover/Cargo.toml index 422a35e..79e1acd 100644 --- a/native/cairo_prover/Cargo.toml +++ b/native/cairo_prover/Cargo.toml @@ -11,9 +11,9 @@ crate-type = ["cdylib"] [dependencies] rustler = "0.31.0" -cairo-platinum-prover = { git = "https://github.com/lambdaclass/lambdaworks", version = "0.9.0"} -stark-platinum-prover = { git = "https://github.com/lambdaclass/lambdaworks", version = "0.9.0"} -lambdaworks-math = { git = "https://github.com/lambdaclass/lambdaworks", version = "0.9.0"} +cairo-platinum-prover = { git = "https://github.com/heliaxdev/lambdaworks", branch = "cairo_rm"} +stark-platinum-prover = { git = "https://github.com/heliaxdev/lambdaworks", branch = "cairo_rm"} +lambdaworks-math = { git = "https://github.com/heliaxdev/lambdaworks", branch = "cairo_rm"} bincode = "2.0.0-rc.3" serde_json = { version = "1.0", features = ["preserve_order"] } hashbrown = { version = "0.14.0", features = ["serde"] } @@ -26,3 +26,4 @@ num-integer = { version = "0.1.45", default-features = false } rand = "0.8.5" lazy_static = "1.4" serde = { version = "1.0.160", features = ["derive"] } +thiserror = "1.0" diff --git a/native/cairo_prover/src/compliance_input.rs b/native/cairo_prover/src/compliance_input.rs index 27ef73b..66690b5 100644 --- a/native/cairo_prover/src/compliance_input.rs +++ b/native/cairo_prover/src/compliance_input.rs @@ -1,4 +1,4 @@ -use crate::utils::felt_to_string; +use crate::{error::CairoError, utils::felt_to_string}; use serde::{Deserialize, Serialize}; #[derive(Debug, Default, Clone, Serialize, Deserialize)] @@ -31,32 +31,31 @@ struct PathNode { impl ComplianceInputJson { pub fn to_json_string( - input_resource: &Vec, - output_resource: &Vec, - path: &Vec>, + input_resource: Vec, + output_resource: Vec, + path: Vec>, pos: u64, - input_nf_key: &Vec, - eph_root: &Vec, - rcv: &Vec, - ) -> String { - let input = ResourceJson::from_bytes(input_resource); - let output = ResourceJson::from_bytes(output_resource); + input_nf_key: Vec, + eph_root: Vec, + rcv: Vec, + ) -> Result { + let input = ResourceJson::from_bytes(input_resource)?; + let output = ResourceJson::from_bytes(output_resource)?; - let rcv = felt_to_string(rcv); - let eph_root = felt_to_string(eph_root); - let input_nf_key = felt_to_string(input_nf_key); + let rcv = felt_to_string(rcv)?; + let eph_root = felt_to_string(eph_root)?; + let input_nf_key = felt_to_string(input_nf_key)?; let mut next_pos = pos; - let merkle_path = path - .iter() - .map(|v| { - let snd = if next_pos % 2 == 0 { false } else { true }; - next_pos >>= 1; - PathNode { - fst: felt_to_string(v), - snd, - } - }) - .collect(); + let mut merkle_path = Vec::new(); + for node in path.into_iter() { + let snd = next_pos % 2 != 0; + next_pos >>= 1; + let node = PathNode { + fst: felt_to_string(node)?, + snd, + }; + merkle_path.push(node); + } let compliance_input = Self { input, @@ -66,22 +65,22 @@ impl ComplianceInputJson { rcv, eph_root, }; - serde_json::to_string(&compliance_input).unwrap() + Ok(serde_json::to_string(&compliance_input)?) } } impl ResourceJson { - pub fn from_bytes(bytes: &Vec) -> Self { - Self { - logic: felt_to_string(&bytes[0..32].to_vec()), - label: felt_to_string(&bytes[32..64].to_vec()), - quantity: felt_to_string(&bytes[64..96].to_vec()), - data: felt_to_string(&bytes[96..128].to_vec()), - nonce: felt_to_string(&bytes[128..160].to_vec()), - npk: felt_to_string(&bytes[160..192].to_vec()), - rseed: felt_to_string(&bytes[192..224].to_vec()), - eph: if bytes[224] == 0 { false } else { true }, - } + pub fn from_bytes(bytes: Vec) -> Result { + Ok(Self { + logic: felt_to_string(bytes[0..32].to_vec())?, + label: felt_to_string(bytes[32..64].to_vec())?, + quantity: felt_to_string(bytes[64..96].to_vec())?, + data: felt_to_string(bytes[96..128].to_vec())?, + nonce: felt_to_string(bytes[128..160].to_vec())?, + npk: felt_to_string(bytes[160..192].to_vec())?, + rseed: felt_to_string(bytes[192..224].to_vec())?, + eph: bytes[224] != 0, + }) } } @@ -97,14 +96,15 @@ fn test_compliance_input_json() { let path = (0..32).map(|_| random_felt()).collect(); let json = ComplianceInputJson::to_json_string( - &random_resouce.to_vec(), - &random_resouce.to_vec(), - &path, + random_resouce.to_vec(), + random_resouce.to_vec(), + path, 0, - &random_felt(), - &random_felt(), - &random_felt(), - ); + random_felt(), + random_felt(), + random_felt(), + ) + .unwrap(); println!("compliance_input_json: {}", json); } diff --git a/native/cairo_prover/src/encryption.rs b/native/cairo_prover/src/encryption.rs index 735368d..03d6cec 100644 --- a/native/cairo_prover/src/encryption.rs +++ b/native/cairo_prover/src/encryption.rs @@ -1,3 +1,4 @@ +use crate::{error::CairoError, utils::bytes_to_felt_vec}; use starknet_crypto::{poseidon_hash, poseidon_hash_many}; use starknet_curve::curve_params::GENERATOR; use starknet_types_core::{ @@ -29,9 +30,15 @@ impl Ciphertext { &self.0 } - pub fn encrypt(messages: &[Felt], pk: &AffinePoint, sk: &Felt, encrypt_nonce: &Felt) -> Self { + pub fn encrypt( + messages: &[Felt], + pk: &AffinePoint, + sk: &Felt, + encrypt_nonce: &Felt, + ) -> Result { // Generate the secret key - let (secret_key_x, secret_key_y) = SecretKey::from_dh_exchange(pk, sk).get_coordinates(); + let secret_key = SecretKey::from_dh_exchange(pk, sk)?; + let (secret_key_x, secret_key_y) = secret_key.get_coordinates(); // Pad the messages let plaintext = Plaintext::padding(messages); @@ -56,33 +63,34 @@ impl Ciphertext { cipher.push(poseidon_state); // Add sender's public key - let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()).unwrap(); - let sender_pk = (&generator * *sk).to_affine().unwrap(); + let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()) + .map_err(|_| CairoError::InvalidAffinePoint)?; + let sender_pk = (&generator * *sk) + .to_affine() + .map_err(|_| CairoError::InvalidAffinePoint)?; cipher.push(sender_pk.x()); cipher.push(sender_pk.y()); // Add encrypt_nonce cipher.push(*encrypt_nonce); - cipher.into() - } + let ret: [Felt; CIPHERTEXT_NUM] = cipher + .try_into() + .map_err(|_| CairoError::InvalidCiphertextLength)?; - pub fn decrypt(&self, sk: &Felt) -> Option> { - let cipher_text = self.inner(); - let cipher_len = cipher_text.len(); - if cipher_len != CIPHERTEXT_NUM { - return None; - } + Ok(Self(ret)) + } - let mac = cipher_text[CIPHERTEXT_MAC]; - let pk_x = cipher_text[CIPHERTEXT_PK_X]; - let pk_y = cipher_text[CIPHERTEXT_PK_Y]; - let encrypt_nonce = cipher_text[CIPHERTEXT_NONCE]; + pub fn decrypt(&self, sk: &Felt) -> Result, CairoError> { + let mac = self.inner()[CIPHERTEXT_MAC]; + let pk_x = self.inner()[CIPHERTEXT_PK_X]; + let pk_y = self.inner()[CIPHERTEXT_PK_Y]; + let encrypt_nonce = self.inner()[CIPHERTEXT_NONCE]; if let Ok(pk) = AffinePoint::new(pk_x, pk_y) { // Generate the secret key - let (secret_key_x, secret_key_y) = - SecretKey::from_dh_exchange(&pk, sk).get_coordinates(); + let sk = SecretKey::from_dh_exchange(&pk, sk)?; + let (secret_key_x, secret_key_y) = sk.get_coordinates(); // Init poseidon sponge state let mut poseidon_state = poseidon_hash_many(&vec![ @@ -94,30 +102,28 @@ impl Ciphertext { // Decrypt let mut msg = vec![]; - for cipher_element in &cipher_text[0..PLAINTEXT_NUM] { + for cipher_element in &self.inner()[0..PLAINTEXT_NUM] { let msg_element = *cipher_element - poseidon_state; msg.push(msg_element); poseidon_state = poseidon_hash(*cipher_element, secret_key_x); } if mac != poseidon_state { - return None; + return Err(CairoError::DecryptionFailure); } - Some(msg) + Ok(msg) } else { - return None; + Err(CairoError::InvalidPublicKey) } } -} -impl From> for Ciphertext { - fn from(input_vec: Vec) -> Self { - Ciphertext( - input_vec - .try_into() - .expect("public input with incorrect length"), - ) + pub fn from_bytes(input_vec: Vec>) -> Result { + let cipher_felt = bytes_to_felt_vec(input_vec)?; + let cipher: [Felt; CIPHERTEXT_NUM] = cipher_felt + .try_into() + .map_err(|_| CairoError::InvalidCiphertextLength)?; + Ok(Self(cipher)) } } @@ -149,12 +155,13 @@ impl From> for Plaintext { } impl SecretKey { - pub fn from_dh_exchange(pk: &AffinePoint, sk: &Felt) -> Self { - Self( - (&ProjectivePoint::try_from(pk.clone()).unwrap() * *sk) - .to_affine() - .unwrap(), - ) + pub fn from_dh_exchange(pk: &AffinePoint, sk: &Felt) -> Result { + let pk_projective = + ProjectivePoint::try_from(pk.clone()).map_err(|_| CairoError::InvalidAffinePoint)?; + let key = (&pk_projective * *sk) + .to_affine() + .map_err(|_| CairoError::InvalidDHKey)?; + Ok(Self(key)) } pub fn get_coordinates(&self) -> (Felt, Felt) { @@ -173,7 +180,7 @@ fn test_encryption() { let encrypt_nonce = Felt::ONE; // Encryption - let cipher = Ciphertext::encrypt(&messages, &pk, &sender_sk, &encrypt_nonce); + let cipher = Ciphertext::encrypt(&messages, &pk, &sender_sk, &encrypt_nonce).unwrap(); // Decryption let decryption = cipher.decrypt(&Felt::ONE).unwrap(); diff --git a/native/cairo_prover/src/error.rs b/native/cairo_prover/src/error.rs new file mode 100644 index 0000000..bd3fa1a --- /dev/null +++ b/native/cairo_prover/src/error.rs @@ -0,0 +1,59 @@ +use bincode::error::{DecodeError, EncodeError}; +use rustler::{Encoder, Env, Term}; +use serde_json::error::Error as JsonError; +use starknet_crypto::SignError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CairoError { + #[error("Inputs should not be empty")] + EmptyInputs, + #[error("Bytes should be a multiple of 24 for trace or 40 for memory")] + CairoImportError, + #[error("Parse public input error: {0}")] + ParsePublicInputError(String), + #[error("Proving error")] + ProvingError, + #[error(transparent)] + EncodeError(#[from] EncodeError), + #[error(transparent)] + DecodeError(#[from] DecodeError), + #[error("Segment not found in memory(public input)")] + SegmentNotFound, + #[error("Address({0}) not found in memory(public input)")] + AddressNotFound(u64), + #[error(transparent)] + SignError(#[from] SignError), + #[error("Invalid inputs")] + InvalidInputs, + #[error("Invalid finite field: 32 bytes needed")] + InvalidFiniteField, + #[error("Invalid Point")] + InvalidAffinePoint, + #[error("Invalid signature: 64 bytes needed")] + InvalidSignatureFormat, + #[error("Signature verification failed")] + SigVerifyError, + #[error(transparent)] + JsonError(#[from] JsonError), + #[error("Invalid public key")] + InvalidPublicKey, + #[error("Invalid DH key")] + InvalidDHKey, + #[error("Invalid mac in decryption")] + DecryptionFailure, + #[error("The length of ciphertext is not correct")] + InvalidCiphertextLength, +} + +impl Encoder for CairoError { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + self.to_string().encode(env) + } +} + +impl From for rustler::Error { + fn from(e: CairoError) -> Self { + rustler::Error::Term(Box::new(e)) + } +} diff --git a/native/cairo_prover/src/errors.rs b/native/cairo_prover/src/errors.rs deleted file mode 100644 index f9f55db..0000000 --- a/native/cairo_prover/src/errors.rs +++ /dev/null @@ -1,164 +0,0 @@ -use rustler::{Encoder, Env, Term}; - -#[derive(Debug)] -pub(crate) enum CairoProveError { - RegisterStatesError(String), - CairoMemoryError(String), - ProofGenerationError(String), - PublicInputError(String), - EncodingError(String), -} - -impl std::fmt::Display for CairoProveError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoProveError::RegisterStatesError(msg) => { - write!(f, "Register states error: {}", msg) - } - CairoProveError::CairoMemoryError(msg) => write!(f, "Cairo memory error: {}", msg), - CairoProveError::ProofGenerationError(msg) => { - write!(f, "Proof generation failed: {}", msg) - } - CairoProveError::PublicInputError(msg) => write!(f, "Public input error: {}", msg), - CairoProveError::EncodingError(msg) => write!(f, "Encoding error: {}", msg), - } - } -} - -impl Encoder for CairoProveError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub(crate) enum CairoVerifyError { - ProofDecodingError(String), - PublicInputDecodingError(String), -} - -impl std::fmt::Display for CairoVerifyError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoVerifyError::ProofDecodingError(msg) => write!(f, "Proof decoding error: {}", msg), - CairoVerifyError::PublicInputDecodingError(msg) => { - write!(f, "Public input decoding error: {}", msg) - } - } - } -} - -impl Encoder for CairoVerifyError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub(crate) enum CairoGetOutputError { - DecodingError(String), - SegmentNotFound, - AddressNotFound(u64), -} - -impl std::fmt::Display for CairoGetOutputError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoGetOutputError::DecodingError(msg) => write!(f, "Decoding error: {}", msg), - CairoGetOutputError::SegmentNotFound => { - write!(f, "Output segment not found in memory segments") - } - CairoGetOutputError::AddressNotFound(addr) => { - write!(f, "Address {} not found in public memory", addr) - } - } - } -} - -impl Encoder for CairoGetOutputError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub(crate) enum CairoSignError { - SignatureGenerationError(String), -} - -impl std::fmt::Display for CairoSignError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoSignError::SignatureGenerationError(msg) => { - write!(f, "Binding Signature generation error: {}", msg) - } - } - } -} - -impl Encoder for CairoSignError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub enum CairoBindingSigVerifyError { - InputError, - VerificationError, -} - -impl std::fmt::Display for CairoBindingSigVerifyError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoBindingSigVerifyError::InputError => write!(f, "Invalid input data"), - CairoBindingSigVerifyError::VerificationError => { - write!(f, "Signature verification failed") - } - } - } -} - -impl Encoder for CairoBindingSigVerifyError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub enum CairoBindingSigError { - KeyGenerationError, -} - -impl std::fmt::Display for CairoBindingSigError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoBindingSigError::KeyGenerationError => write!(f, "Error generating key"), - } - } -} - -impl Encoder for CairoBindingSigError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} - -#[derive(Debug)] -pub enum TypeError { - DecodingError(String), -} - -impl std::fmt::Display for TypeError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - TypeError::DecodingError(msg) => write!(f, "Type error: {}", msg), - } - } -} - -impl Encoder for TypeError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} diff --git a/native/cairo_prover/src/lib.rs b/native/cairo_prover/src/lib.rs index 2641f37..b08e80a 100644 --- a/native/cairo_prover/src/lib.rs +++ b/native/cairo_prover/src/lib.rs @@ -2,16 +2,13 @@ mod compliance_input; mod encryption; -mod errors; +mod error; mod utils; use crate::{ compliance_input::ComplianceInputJson, encryption::Ciphertext, - errors::{ - CairoBindingSigError, CairoBindingSigVerifyError, CairoGetOutputError, CairoProveError, - CairoSignError, CairoVerifyError, - }, + error::CairoError, utils::{bytes_to_affine, bytes_to_felt, bytes_to_felt_vec, felt_to_string, random_felt}, }; use cairo_platinum_prover::{ @@ -27,14 +24,11 @@ use num_bigint::BigInt; use num_integer::Integer; use num_traits::Zero; use rand::{thread_rng, RngCore}; -use rustler::{Error, NifResult}; +use rustler::NifResult; use stark_platinum_prover::proof::options::{ProofOptions, SecurityLevel}; use starknet_crypto::{poseidon_hash, poseidon_hash_many, poseidon_hash_single, sign, verify}; use starknet_curve::curve_params::{EC_ORDER, GENERATOR}; -use starknet_types_core::{ - curve::{AffinePoint, ProjectivePoint}, - felt::Felt, -}; +use starknet_types_core::{curve::ProjectivePoint, felt::Felt}; use std::ops::Add; #[rustler::nif(schedule = "DirtyCpu")] @@ -43,24 +37,18 @@ fn cairo_prove( memory: Vec, public_input: Vec, ) -> NifResult<(Vec, Vec)> { + if trace.is_empty() || memory.is_empty() || public_input.is_empty() { + return Err(CairoError::EmptyInputs.into()); + } // Generating the prover args - let register_states = RegisterStates::from_bytes_le(&trace).map_err(|e| { - Error::Term(Box::new(CairoProveError::RegisterStatesError(format!( - "{:?}", - e - )))) - })?; - - let memory = CairoMemory::from_bytes_le(&memory).map_err(|e| { - Error::Term(Box::new(CairoProveError::CairoMemoryError(format!( - "{:?}", - e - )))) - })?; + let register_states = + RegisterStates::from_bytes_le(&trace).map_err(|_| CairoError::CairoImportError)?; + + let memory = CairoMemory::from_bytes_le(&memory).map_err(|_| CairoError::CairoImportError)?; // Handle public inputs let (rc_min, rc_max, public_memory, memory_segments) = parse_public_input(&public_input) - .map_err(|e| Error::Term(Box::new(CairoProveError::PublicInputError(e.to_string()))))?; + .map_err(|e| CairoError::ParsePublicInputError(e.to_string()))?; let num_steps = register_states.steps(); let mut pub_inputs = PublicInputs { @@ -81,24 +69,19 @@ fn cairo_prove( // Generating proof let proof_options = ProofOptions::new_secure(SecurityLevel::Conjecturable100Bits, 3); - let proof = generate_cairo_proof(&main_trace, &pub_inputs, &proof_options).map_err(|e| { - Error::Term(Box::new(CairoProveError::ProofGenerationError(format!( - "{:?}", - e - )))) - })?; + let proof = generate_cairo_proof(&main_trace, &pub_inputs, &proof_options) + .map_err(|_| CairoError::ProvingError)?; // Encode proof and pub_inputs let proof_bytes = bincode::serde::encode_to_vec(proof, bincode::config::standard()) - .map_err(|e| Error::Term(Box::new(CairoProveError::EncodingError(format!("{:?}", e)))))?; + .map_err(CairoError::from)?; let pub_input_bytes = bincode::serde::encode_to_vec(&pub_inputs, bincode::config::standard()) - .map_err(|e| { - Error::Term(Box::new(CairoProveError::EncodingError(format!("{:?}", e)))) - })?; + .map_err(CairoError::from)?; Ok((proof_bytes, pub_input_bytes)) } +#[allow(clippy::type_complexity)] fn parse_public_input( public_input: &[u8], ) -> Result< @@ -147,9 +130,7 @@ fn parse_public_input( let value = Felt252::from_bytes_le( public_input .get(start_index + 8..start_index + 40) - .ok_or("Input too short for public memory value")? - .try_into() - .map_err(|_| "Failed to convert public memory value bytes")?, + .ok_or("Input too short for public memory value")?, ) .map_err(|_| "Failed to create Felt252 from bytes")?; public_memory.insert(addr, value); @@ -201,20 +182,12 @@ fn cairo_verify(proof: Vec, public_input: Vec) -> NifResult { // Decode proof let proof = bincode::serde::decode_from_slice(&proof, bincode::config::standard()) - .map_err(|e| { - Error::Term(Box::new(CairoVerifyError::ProofDecodingError( - e.to_string(), - ))) - })? + .map_err(CairoError::from)? .0; // Decode public inputs let pub_inputs = bincode::serde::decode_from_slice(&public_input, bincode::config::standard()) - .map_err(|e| { - Error::Term(Box::new(CairoVerifyError::PublicInputDecodingError( - e.to_string(), - ))) - })? + .map_err(CairoError::from)? .0; Ok(verify_cairo_proof(&proof, &pub_inputs, &proof_options)) @@ -224,15 +197,14 @@ fn cairo_verify(proof: Vec, public_input: Vec) -> NifResult { fn cairo_get_output(public_input: Vec) -> NifResult>> { // Decode public inputs let (pub_inputs, _): (PublicInputs, usize) = - bincode::serde::decode_from_slice(&public_input, bincode::config::standard()).map_err( - |e| Error::Term(Box::new(CairoGetOutputError::DecodingError(e.to_string()))), - )?; + bincode::serde::decode_from_slice(&public_input, bincode::config::standard()) + .map_err(CairoError::from)?; // Get output segments let output_segments = pub_inputs .memory_segments .get(&SegmentName::Output) - .ok_or_else(|| Error::Term(Box::new(CairoGetOutputError::SegmentNotFound)))?; + .ok_or_else(|| CairoError::SegmentNotFound)?; let begin_addr: u64 = output_segments.begin_addr as u64; let stop_addr: u64 = output_segments.stop_ptr as u64; @@ -245,9 +217,7 @@ fn cairo_get_output(public_input: Vec) -> NifResult>> { if let Some(value) = pub_inputs.public_memory.get(&addr_field_element) { output_values.push(value.clone().to_bytes_be().to_vec()); } else { - return Err(Error::Term(Box::new(CairoGetOutputError::AddressNotFound( - addr, - )))); + return Err(CairoError::AddressNotFound(addr).into()); } } @@ -261,12 +231,15 @@ fn cairo_binding_sig_sign( private_key_segments: Vec, messages: Vec>, ) -> NifResult> { + if private_key_segments.is_empty() || private_key_segments.len() % 32 != 0 { + return Err(CairoError::InvalidInputs.into()); + } // Compute private key let private_key = { let result = private_key_segments .chunks(32) .fold(BigInt::zero(), |acc, key_segment| { - let key = BigInt::from_bytes_be(num_bigint::Sign::Plus, &key_segment); + let key = BigInt::from_bytes_be(num_bigint::Sign::Plus, key_segment); acc.add(key) }) .mod_floor(&EC_ORDER.to_bigint()); @@ -288,11 +261,7 @@ fn cairo_binding_sig_sign( rng.fill_bytes(&mut felt); Felt::from_bytes_be(&felt) }; - let signature = sign(&private_key, &sig_hash, &k).map_err(|e| { - Error::Term(Box::new(CairoSignError::SignatureGenerationError( - e.to_string(), - ))) - })?; + let signature = sign(&private_key, &sig_hash, &k).map_err(CairoError::from)?; // Serialize signature let mut ret = Vec::new(); @@ -311,46 +280,30 @@ fn cairo_binding_sig_verify( signature: Vec, ) -> NifResult { // Generate the public key - let pub_key = pub_key_segments - .into_iter() - .try_fold(ProjectivePoint::identity(), |acc, bytes| { - let key_x = Felt::from_bytes_be( - &bytes[0..32] - .try_into() - .map_err(|_| CairoBindingSigVerifyError::InputError)?, - ); - let key_y = Felt::from_bytes_be( - &bytes[32..64] - .try_into() - .map_err(|_| CairoBindingSigVerifyError::InputError)?, - ); - let key_segment_affine = AffinePoint::new(key_x, key_y) - .map_err(|_| CairoBindingSigVerifyError::InputError)?; - Ok(acc.add(key_segment_affine)) - }) - .map_err(|e: CairoBindingSigVerifyError| Error::Term(Box::new(e)))? + let mut pub_key = ProjectivePoint::identity(); + for pk_seg_bytes in pub_key_segments.into_iter() { + let pk_seg = bytes_to_affine(pk_seg_bytes)?; + pub_key += pk_seg; + } + let pub_key_x = pub_key .to_affine() - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::InputError)))? + .map_err(|_| CairoError::InvalidAffinePoint)? .x(); // Message digest let msg = message_digest(messages)?; // Decode the signature - let r = Felt::from_bytes_be( - signature[0..32] - .try_into() - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::InputError)))?, - ); - let s = Felt::from_bytes_be( - signature[32..64] - .try_into() - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::InputError)))?, - ); + if signature.len() != 64 { + return Err(CairoError::InvalidSignatureFormat.into()); + } + + let (r_bytes, s_bytes) = signature.split_at(32); + let r = bytes_to_felt(r_bytes.to_vec())?; + let s = bytes_to_felt(s_bytes.to_vec())?; // Verify the signature - verify(&pub_key, &msg, &r, &s) - .map_err(|_| Error::Term(Box::new(CairoBindingSigVerifyError::VerificationError))) + verify(&pub_key_x, &msg, &r, &s).map_err(|_| CairoError::SigVerifyError.into()) } // random_felt can help create private key in signature @@ -361,14 +314,14 @@ fn cairo_random_felt() -> NifResult> { #[rustler::nif] fn get_public_key(priv_key: Vec) -> NifResult> { - let priv_key_felt = Felt::from_bytes_be_slice(&priv_key); + let priv_key_felt = bytes_to_felt(priv_key)?; let generator = ProjectivePoint::from_affine(GENERATOR.x(), GENERATOR.y()) - .map_err(|_| Error::Term(Box::new(CairoBindingSigError::KeyGenerationError)))?; + .map_err(|_| CairoError::InvalidAffinePoint)?; let pub_key = (&generator * priv_key_felt) .to_affine() - .map_err(|_| Error::Term(Box::new(CairoBindingSigError::KeyGenerationError)))?; + .map_err(|_| CairoError::InvalidAffinePoint)?; let mut ret = pub_key.x().to_bytes_be().to_vec(); let mut y = pub_key.y().to_bytes_be().to_vec(); @@ -405,14 +358,12 @@ fn poseidon_many(inputs: Vec>) -> NifResult> { #[rustler::nif] fn program_hash(public_inputs: Vec) -> NifResult> { let (pub_inputs, _): (PublicInputs, usize) = - bincode::serde::decode_from_slice(&public_inputs, bincode::config::standard()).unwrap(); - let program_segments = match pub_inputs.memory_segments.get(&SegmentName::Program) { - Some(segment) => segment, - None => { - eprintln!("Error: 'Program' segment not found in memory_segments"); - return Ok(vec![]); - } - }; + bincode::serde::decode_from_slice(&public_inputs, bincode::config::standard()) + .map_err(CairoError::from)?; + let program_segments = pub_inputs + .memory_segments + .get(&SegmentName::Program) + .ok_or_else(|| CairoError::SegmentNotFound)?; let begin_addr: u64 = program_segments.begin_addr as u64; let stop_addr: u64 = program_segments.stop_ptr as u64; @@ -421,16 +372,11 @@ fn program_hash(public_inputs: Vec) -> NifResult> { for addr in begin_addr..stop_addr { // Convert addr to FieldElement (assuming this is the correct way to create a FieldElement from an address) let addr_field_element = Felt252::from(addr); - - if let Some(value) = pub_inputs.public_memory.get(&addr_field_element) { - program.push(Felt::from_raw(value.to_raw().limbs)); - } else { - eprintln!( - "Error: Address {:?} not found in public memory", - addr_field_element - ); - return Ok(vec![]); - } + let value = pub_inputs + .public_memory + .get(&addr_field_element) + .ok_or_else(|| CairoError::AddressNotFound(addr))?; + program.push(Felt::from_raw(value.to_raw().limbs)); } let program_hash = poseidon_hash_many(&program); @@ -439,8 +385,8 @@ fn program_hash(public_inputs: Vec) -> NifResult> { } #[rustler::nif] -fn cairo_felt_to_string(felt: Vec) -> String { - felt_to_string(&felt) +fn cairo_felt_to_string(felt: Vec) -> NifResult { + Ok(felt_to_string(felt)?) } #[rustler::nif] @@ -452,16 +398,16 @@ fn cairo_generate_compliance_input_json( input_nf_key: Vec, eph_root: Vec, rcv: Vec, -) -> String { - ComplianceInputJson::to_json_string( - &input_resource, - &output_resource, - &path, +) -> NifResult { + Ok(ComplianceInputJson::to_json_string( + input_resource, + output_resource, + path, pos, - &input_nf_key, - &eph_root, - &rcv, - ) + input_nf_key, + eph_root, + rcv, + )?) } #[rustler::nif] @@ -484,7 +430,7 @@ fn encrypt( let nonce_felt = bytes_to_felt(nonce)?; // Encrypt - let cipher = Ciphertext::encrypt(&msgs_felt, &pk_affine, &sk_felt, &nonce_felt); + let cipher = Ciphertext::encrypt(&msgs_felt, &pk_affine, &sk_felt, &nonce_felt)?; let cipher_bytes = cipher .inner() .iter() @@ -497,13 +443,13 @@ fn encrypt( #[rustler::nif] fn decrypt(cihper: Vec>, sk: Vec) -> NifResult>> { // Decode messages - let cipher_felt = bytes_to_felt_vec(cihper)?; + let cipher = Ciphertext::from_bytes(cihper)?; // Decode sk let sk_felt = bytes_to_felt(sk)?; // Encrypt - let plaintext = Ciphertext::from(cipher_felt).decrypt(&sk_felt).unwrap(); + let plaintext = cipher.decrypt(&sk_felt)?; let plaintext_bytes = plaintext.iter().map(|x| x.to_bytes_be().to_vec()).collect(); Ok(plaintext_bytes) diff --git a/native/cairo_prover/src/utils.rs b/native/cairo_prover/src/utils.rs index d4e5cf9..2b9b23e 100644 --- a/native/cairo_prover/src/utils.rs +++ b/native/cairo_prover/src/utils.rs @@ -1,17 +1,13 @@ -use crate::errors::TypeError; +use crate::error::CairoError; use rand::{thread_rng, RngCore}; -use rustler::{Error, NifResult}; use starknet_types_core::curve::AffinePoint; use starknet_types_core::felt::Felt; -pub fn felt_to_string(felt: &Vec) -> String { - assert_eq!(felt.len(), 32, "The felt size is not 32 bytes"); - Felt::from_bytes_be( - felt.as_slice() - .try_into() - .expect("Slice with incorrect length"), - ) - .to_hex_string() +pub fn felt_to_string(bytes: Vec) -> Result { + let felt: [u8; 32] = bytes + .try_into() + .map_err(|_| CairoError::InvalidFiniteField)?; + Ok(Felt::from_bytes_be(&felt).to_hex_string()) } pub fn random_felt() -> Vec { @@ -22,45 +18,35 @@ pub fn random_felt() -> Vec { felt.to_bytes_be().to_vec() } -pub fn bytes_to_felt_vec(bytes: Vec>) -> NifResult> { +pub fn bytes_to_felt_vec(bytes_vec: Vec>) -> Result, CairoError> { + if bytes_vec.is_empty() { + return Err(CairoError::InvalidInputs); + } let mut vec_fe = Vec::new(); - for i in bytes { - let i_bytes: [u8; 32] = i.as_slice().try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError( - "invalid felt".to_string(), - ))) - })?; - vec_fe.push(Felt::from_bytes_be(&i_bytes)) + for fe_bytes in bytes_vec { + let fe = bytes_to_felt(fe_bytes)?; + vec_fe.push(fe) } Ok(vec_fe) } -pub fn bytes_to_felt(bytes: Vec) -> NifResult { - let felt: [u8; 32] = bytes.try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError( - "invalid felt".to_string(), - ))) - })?; +pub fn bytes_to_felt(bytes: Vec) -> Result { + let felt: [u8; 32] = bytes + .try_into() + .map_err(|_| CairoError::InvalidFiniteField)?; Ok(Felt::from_bytes_be(&felt)) } -pub fn bytes_to_affine(bytes: Vec) -> NifResult { +pub fn bytes_to_affine(bytes: Vec) -> Result { if bytes.len() != 64 { - return Err(Error::Term(Box::new(TypeError::DecodingError( - "invalid pk".to_string(), - )))); + return Err(CairoError::InvalidAffinePoint); } - let key_x = - Felt::from_bytes_be(&bytes[0..32].try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError("invalid pk".to_string()))) - })?); - let key_y = - Felt::from_bytes_be(&bytes[32..64].try_into().map_err(|_| { - Error::Term(Box::new(TypeError::DecodingError("invalid pk".to_string()))) - })?); - AffinePoint::new(key_x, key_y) - .map_err(|_| Error::Term(Box::new(TypeError::DecodingError("invalid pk".to_string())))) + let (x, y) = bytes.split_at(32); + let key_x = bytes_to_felt(x.to_vec())?; + let key_y = bytes_to_felt(y.to_vec())?; + + AffinePoint::new(key_x, key_y).map_err(|_| CairoError::InvalidAffinePoint) } diff --git a/native/cairo_vm/Cargo.lock b/native/cairo_vm/Cargo.lock index 3b95479..e8d6bb5 100644 --- a/native/cairo_vm/Cargo.lock +++ b/native/cairo_vm/Cargo.lock @@ -287,6 +287,7 @@ dependencies = [ "juvix-cairo-vm", "rustler", "serde_json", + "thiserror", ] [[package]] diff --git a/native/cairo_vm/Cargo.toml b/native/cairo_vm/Cargo.toml index 98e5090..8fcac8c 100644 --- a/native/cairo_vm/Cargo.toml +++ b/native/cairo_vm/Cargo.toml @@ -14,3 +14,4 @@ rustler = "0.31.0" bincode = "2.0.0-rc.3" juvix-cairo-vm = { git = "https://github.com/anoma/juvix-cairo-vm"} serde_json = "1.0.120" +thiserror = "1.0" diff --git a/native/cairo_vm/src/error.rs b/native/cairo_vm/src/error.rs new file mode 100644 index 0000000..18f4a71 --- /dev/null +++ b/native/cairo_vm/src/error.rs @@ -0,0 +1,24 @@ +use rustler::{Encoder, Env, Term}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum CairoVMError { + #[error("Invalid program content")] + InvalidProgramContent, + #[error("Invalid input JSON")] + InvalidInputJSON, + #[error("Runtime error: {0}")] + RuntimeError(String), +} + +impl Encoder for CairoVMError { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + self.to_string().encode(env) + } +} + +impl From for rustler::Error { + fn from(e: CairoVMError) -> Self { + rustler::Error::Term(Box::new(e)) + } +} diff --git a/native/cairo_vm/src/errors.rs b/native/cairo_vm/src/errors.rs deleted file mode 100644 index 0c55249..0000000 --- a/native/cairo_vm/src/errors.rs +++ /dev/null @@ -1,24 +0,0 @@ -use rustler::{Encoder, Env, Term}; - -#[derive(Debug)] -pub(crate) enum CairoVMError { - InvalidProgramContent, - InvalidInputJSON, - RuntimeError(String), -} - -impl std::fmt::Display for CairoVMError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - CairoVMError::InvalidProgramContent => write!(f, "Invalid program content"), - CairoVMError::InvalidInputJSON => write!(f, "Invalid input JSON"), - CairoVMError::RuntimeError(msg) => write!(f, "Runtime error: {}", msg), - } - } -} - -impl Encoder for CairoVMError { - fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { - self.to_string().encode(env) - } -} diff --git a/native/cairo_vm/src/lib.rs b/native/cairo_vm/src/lib.rs index 89eb768..43c5ae6 100644 --- a/native/cairo_vm/src/lib.rs +++ b/native/cairo_vm/src/lib.rs @@ -1,11 +1,12 @@ -mod errors; +mod error; -use crate::errors::CairoVMError; +use crate::error::CairoVMError; use juvix_cairo_vm::{anoma_cairo_vm_runner, program_input::ProgramInput}; -use rustler::{Error, NifResult}; +use rustler::NifResult; use serde_json::Value; use std::collections::HashMap; +#[allow(clippy::type_complexity)] #[rustler::nif(schedule = "DirtyCpu")] fn cairo_vm_runner( program_content: String, @@ -13,18 +14,17 @@ fn cairo_vm_runner( ) -> NifResult<(String, Vec, Vec, Vec)> { // Validate program content serde_json::from_str::(&program_content) - .map_err(|_| Error::Term(Box::new(CairoVMError::InvalidProgramContent)))?; + .map_err(|_| CairoVMError::InvalidProgramContent)?; // Load program input let program_input = if inputs.is_empty() { ProgramInput::new(HashMap::new()) } else { - ProgramInput::from_json(&inputs) - .map_err(|_| Error::Term(Box::new(CairoVMError::InvalidInputJSON)))? + ProgramInput::from_json(&inputs).map_err(|_| CairoVMError::InvalidInputJSON)? }; - anoma_cairo_vm_runner(&program_content.as_bytes(), program_input) - .map_err(|e| Error::Term(Box::new(CairoVMError::RuntimeError(e.to_string())))) + anoma_cairo_vm_runner(program_content.as_bytes(), program_input) + .map_err(|e| CairoVMError::RuntimeError(e.to_string()).into()) } rustler::init!("Elixir.Cairo.CairoVM", [cairo_vm_runner]); diff --git a/test/cairo_binding_signature.exs b/test/cairo_binding_signature.exs index bdff799..0ef7f16 100644 --- a/test/cairo_binding_signature.exs +++ b/test/cairo_binding_signature.exs @@ -16,5 +16,56 @@ defmodule BindingSignatureTest do # Sign and verify signature = (priv_key_1 ++ priv_key_2) |> Cairo.sign(msg) assert true = Cairo.sig_verify(pub_keys, msg, signature) + + # Wrong pub_key + wrong_pub_key = Cairo.get_public_key(priv_key_1) + refute Cairo.sig_verify([wrong_pub_key], msg, signature) + + # Wrong msg + refute Cairo.sig_verify(pub_keys, [List.duplicate(1, 32)], signature) + + # Wrong signature + refute Cairo.sig_verify(pub_keys, msg, List.duplicate(1, 64)) + end + + test "cairo_binding_signature_invalid_input_test" do + priv_key_1 = Cairo.random_felt() + priv_key_2 = Cairo.random_felt() + + pub_keys = + [priv_key_1, priv_key_2] + |> Enum.map(fn x -> Cairo.get_public_key(x) end) + + msg = [Cairo.random_felt(), Cairo.random_felt()] + + assert {:error, "Invalid inputs"} = Cairo.sign([], msg) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sign(priv_key_1, [[]]) + + assert {:error, "Invalid inputs"} = Cairo.sign([1, 2], msg) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sign(priv_key_1, [[1, 2]]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.get_public_key([]) + + signature = (priv_key_1 ++ priv_key_2) |> Cairo.sign(msg) + assert {:error, "Invalid Point"} = Cairo.sig_verify([[]], msg, signature) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sig_verify(pub_keys, [[]], signature) + + assert {:error, "Invalid signature: 64 bytes needed"} = + Cairo.sig_verify(pub_keys, msg, []) + + assert {:error, "Invalid Point"} = Cairo.sig_verify([[1]], msg, signature) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.sig_verify(pub_keys, [[1]], signature) + + assert {:error, "Invalid signature: 64 bytes needed"} = + Cairo.sig_verify(pub_keys, msg, [1]) end end diff --git a/test/cairo_encryption.exs b/test/cairo_encryption.exs index ffcd9e1..582db47 100644 --- a/test/cairo_encryption.exs +++ b/test/cairo_encryption.exs @@ -36,8 +36,45 @@ defmodule NifTest do assert Cairo.get_output(public_input) == expected_cipher # decryption - plaintext = Cairo.decrypt(expected_cipher, felt_bytes_1) + plaintext = Cairo.decrypt(expected_cipher, sk) assert plaintext == expected_plaintext + + # decryption: wrong sk + assert {:error, "Invalid DH key"} = Cairo.decrypt(expected_cipher, felt_bytes_0) + end + + test "cairo_encryption_invalid_input_test" do + felt_bytes = List.duplicate(1, 32) + plaintext = List.duplicate(felt_bytes, 10) + pk = Cairo.get_public_key(felt_bytes) + invalid_pk = List.duplicate(1, 64) + sk = felt_bytes + nonce = felt_bytes + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.encrypt([[]], pk, sk, nonce) + + assert {:error, "Invalid Point"} = Cairo.encrypt(plaintext, [], sk, nonce) + + assert {:error, "Invalid Point"} = + Cairo.encrypt(plaintext, invalid_pk, sk, nonce) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.encrypt(plaintext, pk, [], nonce) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.encrypt(plaintext, pk, sk, []) + + cipher = List.duplicate(felt_bytes, 14) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.decrypt([[]], sk) + + assert {:error, "The length of ciphertext is not correct"} = + Cairo.decrypt([felt_bytes], sk) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.decrypt(cipher, []) end end diff --git a/test/cairo_negative_test.exs b/test/cairo_negative_test.exs index 829b389..5a97083 100644 --- a/test/cairo_negative_test.exs +++ b/test/cairo_negative_test.exs @@ -37,67 +37,51 @@ defmodule NegativeTest do assert String.starts_with?(error_message, "Runtime error:") end - test "cairo_prove with invalid trace (RegisterStatesError)" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, _trace, memory, vm_public_input} = - Cairo.cairo_vm_runner( - program, - input - ) - - invalid_trace = [0, 1, 2, 3] - - assert {:error, error_message} = - Cairo.prove(invalid_trace, memory, vm_public_input) - - assert String.starts_with?(error_message, "Register states error:") - end - - test "cairo_prove with invalid memory (CairoMemoryError)" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, trace, _memory, vm_public_input} = - Cairo.cairo_vm_runner( - program, - input - ) - - invalid_memory = [0, 1, 2, 3] - - assert {:error, error_message} = - Cairo.prove(trace, invalid_memory, vm_public_input) - - assert String.starts_with?(error_message, "Cairo memory error:") - end - - test "cairo_verify with invalid proof" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, trace, memory, vm_public_input} = - Cairo.cairo_vm_runner(program, input) - - {_proof, public_input} = Cairo.prove(trace, memory, vm_public_input) - invalid_proof = [0, 1, 2, 3] - - assert {:error, error_message} = Cairo.verify(invalid_proof, public_input) - assert String.starts_with?(error_message, "Proof decoding error:") + test "cairo_get_output" do + assert {:error, _} = Cairo.get_output([]) + assert {:error, _} = Cairo.get_output([1, 2, 3, 4]) end - test "cairo_verify with invalid public input" do - {:ok, program} = File.read("./native/cairo_vm/cairo.json") - {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") - - {_output, trace, memory, vm_public_input} = - Cairo.cairo_vm_runner(program, input) - - {proof, _public_input} = Cairo.prove(trace, memory, vm_public_input) - invalid_public_input = [] - - assert {:error, error_message} = Cairo.verify(proof, invalid_public_input) - assert String.starts_with?(error_message, "Public input decoding error:") + test "cairo_felt_to_string" do + assert "0x0" = Cairo.felt_to_string(List.duplicate(0, 32)) + + assert "0x7752582c54a42fe0fa35c40f07293bb7d8efe90e21d8d2c06a7db52d7d9b7a1" = + Cairo.felt_to_string([ + 7, + 117, + 37, + 130, + 197, + 74, + 66, + 254, + 15, + 163, + 92, + 64, + 240, + 114, + 147, + 187, + 125, + 142, + 254, + 144, + 226, + 29, + 141, + 44, + 6, + 167, + 219, + 82, + 215, + 217, + 183, + 161 + ]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.felt_to_string([1, 2, 3, 4]) end end diff --git a/test/cairo_test.exs b/test/cairo_test.exs index b15a4b1..5d44ae1 100644 --- a/test/cairo_test.exs +++ b/test/cairo_test.exs @@ -4,7 +4,7 @@ defmodule NifTest do doctest Cairo.CairoProver doctest Cairo.CairoVM - test "cairo_api_test" do + test "cairo_prove_test" do {:ok, program} = File.read("./native/cairo_vm/cairo.json") {:ok, input} = File.read("./native/cairo_vm/cairo_input.json") @@ -25,5 +25,16 @@ defmodule NifTest do Cairo.get_program_hash(public_input) |> Cairo.felt_to_string() # IO.inspect(program_hash) + + assert {:error, _} = Cairo.prove([], memory, vm_public_input) + assert {:error, _} = Cairo.prove(trace, [], vm_public_input) + assert {:error, _} = Cairo.prove(trace, memory, []) + assert {:error, _} = Cairo.prove([1], memory, vm_public_input) + assert {:error, _} = Cairo.prove(trace, [1], vm_public_input) + assert {:error, _} = Cairo.prove(trace, memory, [1]) + assert {:error, _} = Cairo.verify([], public_input) + assert {:error, _} = Cairo.verify(proof, []) + assert {:error, _} = Cairo.verify([1], public_input) + assert {:error, _} = Cairo.verify(proof, [1]) end end diff --git a/test/poseidon_test.exs b/test/poseidon_test.exs index bdb0db0..9195678 100644 --- a/test/poseidon_test.exs +++ b/test/poseidon_test.exs @@ -71,4 +71,20 @@ defmodule PoseidonTest do assert hash_bytes == output end + + test "poseidon_hash_invalid_input" do + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon_single([]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon([], List.duplicate(1, 32)) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon(List.duplicate(1, 32), []) + + assert {:error, "Invalid inputs"} = Cairo.poseidon_many([]) + + assert {:error, "Invalid finite field: 32 bytes needed"} = + Cairo.poseidon_many([[1]]) + end end