diff --git a/examples/demo-rollup/celestia_rollup_config.toml b/examples/demo-rollup/celestia_rollup_config.toml
index 23f9be3ff..e950ee8c5 100644
--- a/examples/demo-rollup/celestia_rollup_config.toml
+++ b/examples/demo-rollup/celestia_rollup_config.toml
@@ -21,3 +21,6 @@ start_height = 3
 # the host and port to bind the rpc server for
 bind_host = "127.0.0.1"
 bind_port = 12345
+
+[prover_service]
+aggregated_proof_block_jump = 1
\ No newline at end of file
diff --git a/examples/demo-rollup/mock_rollup_config.toml b/examples/demo-rollup/mock_rollup_config.toml
index 2f25d2f48..116021f0c 100644
--- a/examples/demo-rollup/mock_rollup_config.toml
+++ b/examples/demo-rollup/mock_rollup_config.toml
@@ -13,4 +13,7 @@ start_height = 0
 [runner.rpc_config]
 # the host and port to bind the rpc server for
 bind_host = "127.0.0.1"
-bind_port = 12345
\ No newline at end of file
+bind_port = 12345
+
+[prover_service]
+aggregated_proof_block_jump = 1
\ No newline at end of file
diff --git a/examples/demo-rollup/src/celestia_rollup.rs b/examples/demo-rollup/src/celestia_rollup.rs
index a26b269a6..5ff8e99ca 100644
--- a/examples/demo-rollup/src/celestia_rollup.rs
+++ b/examples/demo-rollup/src/celestia_rollup.rs
@@ -12,7 +12,9 @@ use sov_risc0_adapter::host::Risc0Host;
 use sov_rollup_interface::zk::ZkvmHost;
 use sov_state::storage_manager::ProverStorageManager;
 use sov_state::{DefaultStorageSpec, Storage, ZkStorage};
-use sov_stf_runner::{ParallelProverService, RollupConfig, RollupProverConfig};
+use sov_stf_runner::{
+    ParallelProverService, ProverServiceConfig, RollupConfig, RollupProverConfig,
+};
 
 use crate::{ROLLUP_BATCH_NAMESPACE, ROLLUP_PROOF_NAMESPACE};
 
@@ -101,6 +103,7 @@ impl RollupBlueprint for CelestiaDemoRollup {
     async fn create_prover_service(
         &self,
         prover_config: RollupProverConfig,
+        rollup_config: &RollupConfig,
         _da_service: &Self::DaService,
     ) -> Self::ProverService {
         let vm = Risc0Host::new(risc0::ROLLUP_ELF);
@@ -117,6 +120,7 @@ impl RollupBlueprint for CelestiaDemoRollup {
             da_verifier,
             prover_config,
             zk_storage,
+            rollup_config.prover_service,
         )
     }
 }
diff --git a/examples/demo-rollup/src/mock_rollup.rs b/examples/demo-rollup/src/mock_rollup.rs
index 958820fc2..4af036748 100644
--- a/examples/demo-rollup/src/mock_rollup.rs
+++ b/examples/demo-rollup/src/mock_rollup.rs
@@ -12,7 +12,9 @@ use sov_risc0_adapter::host::Risc0Host;
 use sov_rollup_interface::zk::ZkvmHost;
 use sov_state::storage_manager::ProverStorageManager;
 use sov_state::{DefaultStorageSpec, Storage, ZkStorage};
-use sov_stf_runner::{ParallelProverService, RollupConfig, RollupProverConfig};
+use sov_stf_runner::{
+    ParallelProverService, ProverServiceConfig, RollupConfig, RollupProverConfig,
+};
 
 /// Rollup with MockDa
 pub struct MockDemoRollup {}
@@ -92,6 +94,7 @@ impl RollupBlueprint for MockDemoRollup {
     async fn create_prover_service(
         &self,
         prover_config: RollupProverConfig,
+        rollup_config: &RollupConfig,
         _da_service: &Self::DaService,
     ) -> Self::ProverService {
         let vm = Risc0Host::new(risc0::MOCK_DA_ELF);
@@ -105,6 +108,7 @@ impl RollupBlueprint for MockDemoRollup {
             da_verifier,
             prover_config,
             zk_storage,
+            rollup_config.prover_service,
         )
     }
 }
diff --git a/examples/demo-rollup/tests/test_helpers.rs b/examples/demo-rollup/tests/test_helpers.rs
index d54d5b786..909b2e878 100644
--- a/examples/demo-rollup/tests/test_helpers.rs
+++ b/examples/demo-rollup/tests/test_helpers.rs
@@ -4,7 +4,9 @@ use demo_stf::genesis_config::GenesisPaths;
 use sov_demo_rollup::MockDemoRollup;
 use sov_mock_da::{MockAddress, MockDaConfig};
 use sov_modules_rollup_blueprint::RollupBlueprint;
-use sov_stf_runner::{RollupConfig, RollupProverConfig, RpcConfig, RunnerConfig, StorageConfig};
+use sov_stf_runner::{
+    ProverServiceConfig, RollupConfig, RollupProverConfig, RpcConfig, RunnerConfig, StorageConfig,
+};
 use tokio::sync::oneshot;
 
 pub async fn start_rollup(
@@ -29,6 +31,9 @@ pub async fn start_rollup(
         da: MockDaConfig {
             sender_address: MockAddress::from([0; 32]),
         },
+        prover_service: ProverServiceConfig {
+            aggregated_proof_block_jump: 1,
+        },
     };
 
     let mock_demo_rollup = MockDemoRollup {};
diff --git a/full-node/sov-stf-runner/src/config.rs b/full-node/sov-stf-runner/src/config.rs
index 5fcb90b86..ee6575d16 100644
--- a/full-node/sov-stf-runner/src/config.rs
+++ b/full-node/sov-stf-runner/src/config.rs
@@ -30,6 +30,13 @@ pub struct StorageConfig {
     pub path: PathBuf,
 }
 
+/// Configuration of the prover service.
+#[derive(Debug, Clone, PartialEq, Deserialize, Copy)]
+pub struct ProverServiceConfig {
+    /// Number of consecutive DA blocks covered by a single aggregated proof.
+    pub aggregated_proof_block_jump: u64,
+}
+
 /// Rollup Configuration
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct RollupConfig {
@@ -39,6 +46,8 @@ pub struct RollupConfig {
     pub runner: RunnerConfig,
     /// Data Availability service configuration.
     pub da: DaServiceConfig,
+    /// Prover service configuration.
+    pub prover_service: ProverServiceConfig,
 }
 
 /// Reads toml file as a specific type.
@@ -85,6 +94,8 @@ mod tests {
             [runner.rpc_config]
             bind_host = "127.0.0.1"
             bind_port = 12345
+            [prover_service]
+            aggregated_proof_block_jump = 22
         "#;
 
         let config_file = create_config_from(config);
@@ -109,6 +120,9 @@ mod tests {
             storage: StorageConfig {
                 path: PathBuf::from("/tmp"),
             },
+            prover_service: ProverServiceConfig {
+                aggregated_proof_block_jump: 22,
+            },
         };
         assert_eq!(config, expected);
     }
diff --git a/full-node/sov-stf-runner/src/lib.rs b/full-node/sov-stf-runner/src/lib.rs
index 6aaf1a9e9..bc3b2c67c 100644
--- a/full-node/sov-stf-runner/src/lib.rs
+++ b/full-node/sov-stf-runner/src/lib.rs
@@ -23,7 +23,7 @@ pub use prover_service::*;
 #[cfg(feature = "native")]
 mod runner;
 #[cfg(feature = "native")]
-pub use config::{from_toml_path, RollupConfig, RunnerConfig, StorageConfig};
+pub use config::{from_toml_path, ProverServiceConfig, RollupConfig, RunnerConfig, StorageConfig};
 #[cfg(feature = "native")]
 pub use runner::*;
 use serde::de::DeserializeOwned;
diff --git a/full-node/sov-stf-runner/src/prover_service/parallel/mod.rs b/full-node/sov-stf-runner/src/prover_service/parallel/mod.rs
index 50508400f..1fb470418 100644
--- a/full-node/sov-stf-runner/src/prover_service/parallel/mod.rs
+++ b/full-node/sov-stf-runner/src/prover_service/parallel/mod.rs
@@ -1,4 +1,5 @@
 mod prover;
+mod prover_manager;
 use std::sync::Arc;
 
 use async_trait::async_trait;
@@ -13,8 +14,8 @@ use sov_rollup_interface::zk::ZkvmHost;
 use super::{ProverService, ProverServiceError};
 use crate::verifier::StateTransitionVerifier;
 use crate::{
-    ProofGenConfig, ProofProcessingStatus, ProofSubmissionStatus, RollupProverConfig,
-    StateTransitionData, WitnessSubmissionStatus,
+    ProofGenConfig, ProofProcessingStatus, ProofSubmissionStatus, ProverServiceConfig,
+    RollupProverConfig, StateTransitionData, WitnessSubmissionStatus,
 };
 
 /// Prover service that generates proofs in parallel.
@@ -28,6 +29,7 @@ where
 {
     vm: Vm,
     prover_config: Arc>,
+    prover_service_config: ProverServiceConfig,
     zk_storage: V::PreState,
     prover_state: Prover,
 }
@@ -49,6 +51,7 @@ where
         config: RollupProverConfig,
         zk_storage: V::PreState,
         num_threads: usize,
+        prover_service_config: ProverServiceConfig,
     ) -> Self {
         let stf_verifier = StateTransitionVerifier::::new(zk_stf, da_verifier);
 
@@ -65,7 +68,11 @@ where
         Self {
             vm,
             prover_config,
-            prover_state: Prover::new(num_threads),
+            prover_state: Prover::new(
+                num_threads,
+                prover_service_config.aggregated_proof_block_jump,
+            ),
+            prover_service_config,
             zk_storage,
         }
     }
@@ -77,11 +84,20 @@ where
         da_verifier: Da::Verifier,
         config: RollupProverConfig,
         zk_storage: V::PreState,
+        prover_service_config: ProverServiceConfig,
     ) -> Self {
         let num_cpus = num_cpus::get();
         assert!(num_cpus > 1, "Unable to create parallel prover service");
 
-        Self::new(vm, zk_stf, da_verifier, config, zk_storage, num_cpus - 1)
+        Self::new(
+            vm,
+            zk_stf,
+            da_verifier,
+            config,
+            zk_storage,
+            num_cpus - 1,
+            prover_service_config,
+        )
     }
 }
diff --git a/full-node/sov-stf-runner/src/prover_service/parallel/prover.rs b/full-node/sov-stf-runner/src/prover_service/parallel/prover.rs
index fabebf8f2..8dcd8ac63 100644
--- a/full-node/sov-stf-runner/src/prover_service/parallel/prover.rs
+++ b/full-node/sov-stf-runner/src/prover_service/parallel/prover.rs
@@ -1,8 +1,7 @@
-use std::collections::hash_map::Entry;
-use std::collections::HashMap;
 use std::ops::Deref;
 use std::sync::{Arc, RwLock};
 
+use super::prover_manager::{ProverManager, ProverStatus};
 use serde::de::DeserializeOwned;
 use serde::Serialize;
 use sov_rollup_interface::da::{BlockHeaderTrait, DaSpec};
@@ -16,90 +15,10 @@ use crate::{
     WitnessSubmissionStatus,
 };
 
-enum ProverStatus {
-    WitnessSubmitted(StateTransitionData),
-    ProvingInProgress,
-    Proved(Proof),
-    Err(anyhow::Error),
-}
-
-struct ProverState {
-    prover_status: HashMap>,
-    pending_tasks_count: usize,
-}
-
-impl ProverState {
-    fn remove(&mut self, hash: &Da::SlotHash) -> Option> {
-        self.prover_status.remove(hash)
-    }
-
-    fn set_to_proving(
-        &mut self,
-        hash: Da::SlotHash,
-    ) -> Option> {
-        self.prover_status
-            .insert(hash, ProverStatus::ProvingInProgress)
-    }
-
-    fn set_to_proved(
-        &mut self,
-        hash: Da::SlotHash,
-        proof: Result,
-    ) -> Option> {
-        match proof {
-            Ok(p) => self.prover_status.insert(hash, ProverStatus::Proved(p)),
-            Err(e) => self.prover_status.insert(hash, ProverStatus::Err(e)),
-        }
-    }
-
-    fn get_prover_status(
-        &self,
-        hash: Da::SlotHash,
-    ) -> Option<&ProverStatus> {
-        self.prover_status.get(&hash)
-    }
-
-    fn inc_task_count_if_not_busy(&mut self, num_threads: usize) -> bool {
-        if self.pending_tasks_count >= num_threads {
-            return false;
-        }
-
-        self.pending_tasks_count += 1;
-        true
-    }
-
-    fn dec_task_count(&mut self) {
-        assert!(self.pending_tasks_count > 0);
-        self.pending_tasks_count -= 1;
-    }
-
-    fn set_to_witness_submitted(
-        &mut self,
-        header_hash: Da::SlotHash,
-        state_transition_data: StateTransitionData,
-    ) -> WitnessSubmissionStatus {
-        let entry = self.prover_status.entry(header_hash);
-        let data = ProverStatus::WitnessSubmitted(state_transition_data);
-
-        match entry {
-            Entry::Occupied(_) => WitnessSubmissionStatus::WitnessExist,
-            Entry::Vacant(v) => {
-                v.insert(data);
-                WitnessSubmissionStatus::SubmittedForProving
-            }
-        }
-    }
-}
-
-#[derive(Clone)]
-struct ProverManager {
-    prover_state: Arc>>,
-}
-
 // A prover that generates proofs in parallel using a thread pool. If the pool is saturated,
 // the prover will reject new jobs.
 pub(crate) struct Prover {
-    prover_manager: Arc>>,
+    prover_manager: Arc>>,
     num_threads: usize,
     pool: rayon::ThreadPool,
 }
@@ -110,7 +29,7 @@ where
     StateRoot: Serialize + DeserializeOwned + Clone + AsRef<[u8]> + Send + Sync + 'static,
     Witness: Serialize + DeserializeOwned + Send + Sync + 'static,
 {
-    pub(crate) fn new(num_threads: usize) -> Self {
+    pub(crate) fn new(num_threads: usize, jump: u64) -> Self {
         Self {
             num_threads,
             pool: rayon::ThreadPoolBuilder::new()
@@ -118,10 +37,7 @@ where
                 .build()
                 .unwrap(),
 
-            prover_manager: Arc::new(RwLock::new(ProverState {
-                prover_status: Default::default(),
-                pending_tasks_count: Default::default(),
-            })),
+            prover_manager: Arc::new(RwLock::new(ProverManager::new(jump))),
         }
     }
 
@@ -130,8 +46,12 @@ where
         state_transition_data: StateTransitionData,
     ) -> WitnessSubmissionStatus {
         let header_hash = state_transition_data.da_block_header.hash();
-        let mut prover_state = self.prover_manager.write().expect("Lock was poisoned");
-        prover_state.set_to_witness_submitted(header_hash, state_transition_data)
+        let height = state_transition_data.da_block_header.height();
+        self.prover_manager.write().unwrap().submit_witness(
+            height,
+            header_hash,
+            state_transition_data,
+        )
     }
 
     pub(crate) fn start_proving(
@@ -146,28 +66,29 @@ where
         V: StateTransitionFunction + Send + Sync + 'static,
         V::PreState: Send + Sync + 'static,
     {
-        let prover_state_clone = self.prover_manager.clone();
-        let mut prover_state = self.prover_manager.write().expect("Lock was poisoned");
+        let prover_manager_clone = self.prover_manager.clone();
+        let mut prover_manager = self.prover_manager.write().expect("Lock was poisoned");
 
-        let prover_status = prover_state
+        let (prover_status, state_transition_data) = prover_manager
             .remove(&block_header_hash)
             .ok_or_else(|| anyhow::anyhow!("Missing witness for block: {:?}", block_header_hash))?;
 
         match prover_status {
-            ProverStatus::WitnessSubmitted(state_transition_data) => {
-                let start_prover = prover_state.inc_task_count_if_not_busy(self.num_threads);
+            ProverStatus::WitnessSubmitted => {
+                let start_prover = prover_manager.inc_task_count_if_not_busy(self.num_threads);
                 // Initiate a new proving job only if the prover is not busy.
                 if start_prover {
-                    prover_state.set_to_proving(block_header_hash.clone());
                     vm.add_hint(state_transition_data);
+                    prover_manager.set_to_proving(block_header_hash.clone());
+
                     self.pool.spawn(move || {
                         tracing::info_span!("guest_execution").in_scope(|| {
                             let proof = make_proof(vm, config, zk_storage);
 
                             let mut prover_state =
-                                prover_state_clone.write().expect("Lock was poisoned");
+                                prover_manager_clone.write().expect("Lock was poisoned");
 
                             prover_state.set_to_proved(block_header_hash, proof);
                             prover_state.dec_task_count();
@@ -197,18 +118,19 @@ where
         &self,
         block_header_hash: ::SlotHash,
     ) -> Result {
-        let mut prover_state = self.prover_manager.write().unwrap();
-        let status = prover_state.get_prover_status(block_header_hash.clone());
+        let mut prover_manager = self.prover_manager.write().unwrap();
+        let status = prover_manager.get_prover_status(&block_header_hash);
 
         match status {
             Some(ProverStatus::ProvingInProgress) => {
                 Ok(ProofSubmissionStatus::ProofGenerationInProgress)
             }
             Some(ProverStatus::Proved(_)) => {
-                prover_state.remove(&block_header_hash);
+                // TODO: keep the entry until the aggregated proof covering this height is produced.
+                //prover_manager.remove(&block_header_hash);
                 Ok(ProofSubmissionStatus::Success)
             }
-            Some(ProverStatus::WitnessSubmitted(_)) => Err(anyhow::anyhow!(
+            Some(ProverStatus::WitnessSubmitted) => Err(anyhow::anyhow!(
                 "Witness for {:?} was submitted, but the proof generation is not triggered.",
                 block_header_hash
             )),
diff --git a/full-node/sov-stf-runner/src/prover_service/parallel/prover_manager.rs b/full-node/sov-stf-runner/src/prover_service/parallel/prover_manager.rs
new file mode 100644
index 000000000..62c7f04ca
--- /dev/null
+++ b/full-node/sov-stf-runner/src/prover_service/parallel/prover_manager.rs
@@ -0,0 +1,213 @@
+use std::collections::hash_map::Entry;
+use std::collections::HashMap;
+
+use sov_rollup_interface::da::DaSpec;
+use sov_rollup_interface::zk::Proof;
+
+use crate::{ProofSubmissionStatus, StateTransitionData, WitnessSubmissionStatus};
+
+pub(crate) enum ProverStatus {
+    WitnessSubmitted,
+    ProvingInProgress,
+    Proved(Proof),
+    Err(anyhow::Error),
+}
+
+struct ProverState {
+    prover_status: HashMap,
+    witness: HashMap>,
+    pending_tasks_count: usize,
+}
+
+impl ProverState {
+    fn remove(&mut self, hash: &Da::SlotHash) -> Option {
+        self.prover_status.remove(hash)
+    }
+
+    fn remove_witness(
+        &mut self,
+        hash: &Da::SlotHash,
+    ) -> Option> {
+        self.witness.remove(hash)
+    }
+
+    fn set_to_proving(&mut self, hash: Da::SlotHash) -> Option {
+        self.prover_status
+            .insert(hash, ProverStatus::ProvingInProgress)
+    }
+
+    fn set_to_proved(
+        &mut self,
+        hash: Da::SlotHash,
+        proof: Result,
+    ) -> Option {
+        match proof {
+            Ok(p) => self.prover_status.insert(hash, ProverStatus::Proved(p)),
+            Err(e) => self.prover_status.insert(hash, ProverStatus::Err(e)),
+        }
+    }
+
+    fn get_prover_status(&self, hash: &Da::SlotHash) -> Option<&ProverStatus> {
+        self.prover_status.get(hash)
+    }
+
+    fn inc_task_count_if_not_busy(&mut self, num_threads: usize) -> bool {
+        if self.pending_tasks_count >= num_threads {
+            return false;
+        }
+
+        self.pending_tasks_count += 1;
+        true
+    }
+
+    fn dec_task_count(&mut self) {
+        assert!(self.pending_tasks_count > 0);
+        self.pending_tasks_count -= 1;
+    }
+}
+
+#[derive(Default)]
+struct AggregatedProofInfo {
+    height_to_slot_hash: HashMap,
+    start_height: u64,
+    jump: u64,
+}
+
+impl AggregatedProofInfo {}
+
+pub(crate) struct ProverManager {
+    prover_state: ProverState,
+    aggregated_proof_info: AggregatedProofInfo,
+}
+
+impl ProverManager {
+    pub(crate) fn new(jump: u64) -> Self {
+        Self {
+            prover_state: ProverState {
+                prover_status: Default::default(),
+                pending_tasks_count: Default::default(),
+                witness: Default::default(),
+            },
+            aggregated_proof_info: AggregatedProofInfo {
+                height_to_slot_hash: Default::default(),
+                start_height: 0,
+                jump,
+            },
+        }
+    }
+
+    pub(crate) fn set_to_proving(&mut self, hash: Da::SlotHash) -> Option {
+        self.prover_state.set_to_proving(hash)
+    }
+
+    pub(crate) fn set_to_proved(
+        &mut self,
+        hash: Da::SlotHash,
+        proof: Result,
+    ) -> Option {
+        self.prover_state.set_to_proved(hash, proof)
+    }
+
+    pub(crate) fn inc_task_count_if_not_busy(&mut self, num_threads: usize) -> bool {
+        self.prover_state.inc_task_count_if_not_busy(num_threads)
+    }
+
+    pub(crate) fn dec_task_count(&mut self) {
+        self.prover_state.dec_task_count()
+    }
+
+    pub(crate) fn get_witness(
+        &mut self,
+        hash: &Da::SlotHash,
+    ) -> &StateTransitionData {
+        self.prover_state.witness.get(hash).unwrap()
+    }
+
+    pub(crate) fn submit_witness(
+        &mut self,
+        height: u64,
+        header_hash: Da::SlotHash,
+        state_transition_data: StateTransitionData,
+    ) -> WitnessSubmissionStatus {
+        let entry = self.prover_state.prover_status.entry(header_hash.clone());
+        let data = ProverStatus::WitnessSubmitted;
+
+        match entry {
+            Entry::Occupied(_) => WitnessSubmissionStatus::WitnessExist,
+            Entry::Vacant(v) => {
+                v.insert(data);
+                // TODO: assert that this is the first insertion for `height`.
+                self.aggregated_proof_info
+                    .height_to_slot_hash
+                    .insert(height, header_hash.clone());
+
+                self.prover_state
+                    .witness
+                    .insert(header_hash, state_transition_data);
+
+                WitnessSubmissionStatus::SubmittedForProving
+            }
+        }
+    }
+
+    // TODO: rename; this removes both the prover status and the witness for `hash`.
+    pub(crate) fn remove(
+        &mut self,
+        hash: &Da::SlotHash,
+    ) -> Option<(ProverStatus, StateTransitionData)> {
+        let status = self.prover_state.remove(hash)?;
+        let witness = self.prover_state.remove_witness(hash)?;
+        Some((status, witness))
+    }
+
+    pub(crate) fn get_prover_status(&mut self, hash: &Da::SlotHash) -> Option<&ProverStatus> {
+        self.prover_state.get_prover_status(hash)
+    }
+
+    fn get_aggregated_proof(&mut self) -> Result {
+        let jump = self.aggregated_proof_info.jump;
+        let start_height = self.aggregated_proof_info.start_height;
+
+        let mut proofs_data = Vec::default();
+
+        for height in start_height..start_height + jump {
+            let hash = self
+                .aggregated_proof_info
+                .height_to_slot_hash
+                .get(&height)
+                .unwrap();
+
+            let state = self.prover_state.get_prover_status(hash).unwrap();
+            match state {
+                ProverStatus::WitnessSubmitted => {
+                    return Err(anyhow::anyhow!(
+                        "Witness for {:?} was submitted, but the proof generation is not triggered.",
+                        hash
+                    ))
+                }
+                ProverStatus::ProvingInProgress => {
+                    return Ok(ProofSubmissionStatus::ProofGenerationInProgress)
+                }
+                ProverStatus::Proved(proof) => proofs_data.push(proof),
+                ProverStatus::Err(e) => return Err(anyhow::anyhow!(e.to_string())),
+            }
+        }
+
+        todo!()
+    }
+}
+
+struct AggregatedProofWitness {
+    proof: Proof,
+    pre_state: StateRoot,
+    post_state_root: StateRoot,
+    da_block_hash: SlotHash,
+    height: u64,
+}
+
+struct AggregatedProofPublicInput {
+    initial_state: u64,
+    final_state_root: u64,
+    initial_height: u64,
+    final_height: u64,
+}
+
+struct AggregatedProof {}
diff --git a/full-node/sov-stf-runner/tests/prover_tests.rs b/full-node/sov-stf-runner/tests/prover_tests.rs
index a125ff467..4a9f43aa9 100644
--- a/full-node/sov-stf-runner/tests/prover_tests.rs
+++ b/full-node/sov-stf-runner/tests/prover_tests.rs
@@ -6,7 +6,8 @@ use sov_rollup_interface::da::Time;
 use sov_stf_runner::mock::MockStf;
 use sov_stf_runner::{
     ParallelProverService, ProofProcessingStatus, ProofSubmissionStatus, ProverService,
-    ProverServiceError, RollupProverConfig, StateTransitionData, WitnessSubmissionStatus,
+    ProverServiceConfig, ProverServiceError, RollupProverConfig, StateTransitionData,
+    WitnessSubmissionStatus,
 };
 
 #[tokio::test]
@@ -202,6 +203,9 @@ fn make_new_prover() -> TestProver {
             prover_config,
             (),
             num_threads,
+            ProverServiceConfig {
+                aggregated_proof_block_jump: 1,
+            },
         ),
         vm,
         num_worker_threads: num_threads,
diff --git a/module-system/sov-modules-rollup-blueprint/src/lib.rs b/module-system/sov-modules-rollup-blueprint/src/lib.rs
index 687d67604..e6b091eef 100644
--- a/module-system/sov-modules-rollup-blueprint/src/lib.rs
+++ b/module-system/sov-modules-rollup-blueprint/src/lib.rs
@@ -92,6 +92,7 @@ pub trait RollupBlueprint: Sized + Send + Sync {
     async fn create_prover_service(
         &self,
         prover_config: RollupProverConfig,
+        rollup_config: &RollupConfig,
         da_service: &Self::DaService,
     ) -> Self::ProverService;
 
@@ -118,9 +119,12 @@ pub trait RollupBlueprint: Sized + Send + Sync {
         ::Storage: NativeStorage,
     {
         let da_service = self.create_da_service(&rollup_config).await;
-        let prover_service = self.create_prover_service(prover_config, &da_service).await;
+        let prover_service = self
+            .create_prover_service(prover_config, &rollup_config, &da_service)
+            .await;
 
         let ledger_db = self.create_ledger_db(&rollup_config);
+
         let genesis_config = self.create_genesis_config(genesis_paths, &rollup_config)?;
         let storage_manager = self.create_storage_manager(&rollup_config)?;
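Usage note (illustrative, not part of the patch): the new `[prover_service]` table is deserialized into `ProverServiceConfig` through the same serde path as the rest of `RollupConfig`, and the parsed value is then handed to `ParallelProverService`. The minimal sketch below mirrors the unit test added in `config.rs`; it assumes only the `toml` crate and the `ProverServiceConfig` re-export introduced above.

```rust
use sov_stf_runner::ProverServiceConfig;

fn main() {
    // Body of the `[prover_service]` table as it appears in the rollup TOML files.
    let raw = "aggregated_proof_block_jump = 1";

    // `ProverServiceConfig` derives `Deserialize` and `Copy`, so it can be parsed
    // directly and passed by value into `ParallelProverService::new`.
    let config: ProverServiceConfig = toml::from_str(raw).expect("valid [prover_service] table");
    assert_eq!(config.aggregated_proof_block_jump, 1);
}
```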