diff --git a/core/src/banking_stage.rs b/core/src/banking_stage.rs index b55febb58c..461cc294a1 100644 --- a/core/src/banking_stage.rs +++ b/core/src/banking_stage.rs @@ -2784,6 +2784,10 @@ mod tests { .. } = create_slow_genesis_config(lamports); let bank = Arc::new(Bank::new_no_wallclock_throttle_for_tests(&genesis_config)); + // set cost tracker limits to MAX so it will not filter out TXs + bank.write_cost_tracker() + .unwrap() + .set_limits(std::u64::MAX, std::u64::MAX, std::u64::MAX); // Transfer more than the balance of the mint keypair, should cause a // InstructionError::InsufficientFunds that is then committed. Needs to be @@ -2840,6 +2844,10 @@ mod tests { .. } = create_slow_genesis_config(10_000); let bank = Arc::new(Bank::new_no_wallclock_throttle_for_tests(&genesis_config)); + // set cost tracker limits to MAX so it will not filter out TXs + bank.write_cost_tracker() + .unwrap() + .set_limits(std::u64::MAX, std::u64::MAX, std::u64::MAX); // Make all repetitive transactions that conflict on the `mint_keypair`, so only 1 should be executed let mut transactions = vec![ diff --git a/core/src/cost_update_service.rs b/core/src/cost_update_service.rs index 55fb0e738d..24df05121e 100644 --- a/core/src/cost_update_service.rs +++ b/core/src/cost_update_service.rs @@ -10,16 +10,14 @@ use { solana_runtime::{bank::Bank, cost_model::CostModel}, solana_sdk::timing::timestamp, std::{ - sync::{ - atomic::{AtomicBool, Ordering}, - mpsc::Receiver, - Arc, RwLock, - }, + sync::{mpsc::Receiver, Arc, RwLock}, thread::{self, Builder, JoinHandle}, - time::Duration, }, }; +// Update blockstore persistence storage when accumulated cost_table updates count exceeds the threshold +const PERSIST_THRESHOLD: u64 = 1_000; + #[derive(Default)] pub struct CostUpdateServiceTiming { last_print: u64, @@ -31,20 +29,25 @@ pub struct CostUpdateServiceTiming { impl CostUpdateServiceTiming { fn update( &mut self, - update_cost_model_count: u64, - update_cost_model_elapsed: u64, - 
persist_cost_table_elapsed: u64, + update_cost_model_count: Option, + update_cost_model_elapsed: Option, + persist_cost_table_elapsed: Option, ) { - self.update_cost_model_count += update_cost_model_count; - self.update_cost_model_elapsed += update_cost_model_elapsed; - self.persist_cost_table_elapsed += persist_cost_table_elapsed; + if let Some(update_cost_model_count) = update_cost_model_count { + self.update_cost_model_count += update_cost_model_count; + } + if let Some(update_cost_model_elapsed) = update_cost_model_elapsed { + self.update_cost_model_elapsed += update_cost_model_elapsed; + } + if let Some(persist_cost_table_elapsed) = persist_cost_table_elapsed { + self.persist_cost_table_elapsed += persist_cost_table_elapsed; + } let now = timestamp(); let elapsed_ms = now - self.last_print; if elapsed_ms > 1000 { datapoint_info!( "cost-update-service-stats", - ("total_elapsed_us", elapsed_ms * 1000, i64), ( "update_cost_model_count", self.update_cost_model_count as i64, @@ -86,7 +89,6 @@ pub struct CostUpdateService { impl CostUpdateService { #[allow(clippy::new_ret_no_self)] pub fn new( - exit: Arc, blockstore: Arc, cost_model: Arc>, cost_update_receiver: CostUpdateReceiver, @@ -94,7 +96,7 @@ impl CostUpdateService { let thread_hdl = Builder::new() .name("solana-cost-update-service".to_string()) .spawn(move || { - Self::service_loop(exit, blockstore, cost_model, cost_update_receiver); + Self::service_loop(blockstore, cost_model, cost_update_receiver); }) .unwrap(); @@ -106,118 +108,99 @@ impl CostUpdateService { } fn service_loop( - exit: Arc, blockstore: Arc, cost_model: Arc>, cost_update_receiver: CostUpdateReceiver, ) { let mut cost_update_service_timing = CostUpdateServiceTiming::default(); - let mut dirty: bool; - let mut update_count: u64; - let wait_timer = Duration::from_millis(100); + let mut update_count = 0_u64; - loop { - if exit.load(Ordering::Relaxed) { - break; - } + for cost_update in cost_update_receiver.iter() { + match cost_update { + 
CostUpdate::FrozenBank { bank } => { + bank.read_cost_tracker().unwrap().report_stats(bank.slot()); + } + CostUpdate::ExecuteTiming { + mut execute_timings, + } => { + let mut update_cost_model_time = Measure::start("update_cost_model_time"); + update_count += Self::update_cost_model(&cost_model, &mut execute_timings); + update_cost_model_time.stop(); + cost_update_service_timing.update( + Some(update_count), + Some(update_cost_model_time.as_us()), + None, + ); - dirty = false; - update_count = 0_u64; - let mut update_cost_model_time = Measure::start("update_cost_model_time"); - for cost_update in cost_update_receiver.try_iter() { - match cost_update { - CostUpdate::FrozenBank { bank } => { - bank.read_cost_tracker().unwrap().report_stats(bank.slot()); - } - CostUpdate::ExecuteTiming { - mut execute_timings, - } => { - dirty |= Self::update_cost_model(&cost_model, &mut execute_timings); - update_count += 1; + if update_count > PERSIST_THRESHOLD { + let mut persist_cost_table_time = Measure::start("persist_cost_table_time"); + Self::persist_cost_table(&blockstore, &cost_model); + update_count = 0_u64; + persist_cost_table_time.stop(); + cost_update_service_timing.update( + None, + None, + Some(persist_cost_table_time.as_us()), + ); } } } - update_cost_model_time.stop(); - - let mut persist_cost_table_time = Measure::start("persist_cost_table_time"); - if dirty { - Self::persist_cost_table(&blockstore, &cost_model); - } - persist_cost_table_time.stop(); - - cost_update_service_timing.update( - update_count, - update_cost_model_time.as_us(), - persist_cost_table_time.as_us(), - ); - - thread::sleep(wait_timer); } } + // Normalize `program_timings` with current estimated cost, update instruction_cost table + // Returns number of updates applied fn update_cost_model( cost_model: &RwLock, execute_timings: &mut ExecuteTimings, - ) -> bool { - let mut dirty = false; - { - for (program_id, program_timings) in &mut execute_timings.details.per_program_timings { - let 
current_estimated_program_cost = - cost_model.read().unwrap().find_instruction_cost(program_id); - program_timings.coalesce_error_timings(current_estimated_program_cost); + ) -> u64 { + let mut update_count = 0_u64; + for (program_id, program_timings) in &mut execute_timings.details.per_program_timings { + let current_estimated_program_cost = + cost_model.read().unwrap().find_instruction_cost(program_id); + program_timings.coalesce_error_timings(current_estimated_program_cost); - if program_timings.count < 1 { - continue; - } - - let units = program_timings.accumulated_units / program_timings.count as u64; - match cost_model - .write() - .unwrap() - .upsert_instruction_cost(program_id, units) - { - Ok(c) => { - debug!( - "after replayed into bank, instruction {:?} has averaged cost {}", - program_id, c - ); - dirty = true; - } - Err(err) => { - debug!( - "after replayed into bank, instruction {:?} failed to update cost, err: {}", - program_id, err - ); - } - } + if program_timings.count < 1 { + continue; } + + let units = program_timings.accumulated_units / program_timings.count as u64; + cost_model + .write() + .unwrap() + .upsert_instruction_cost(program_id, units); + update_count += 1; + debug!( + "After replayed into bank, updated cost for instruction {:?}, update_value {}, pre_aggregated_value {}", + program_id, units, current_estimated_program_cost + ); } - debug!( - "after replayed into bank, updated cost model instruction cost table, current values: {:?}", - cost_model.read().unwrap().get_instruction_cost_table() - ); - dirty + update_count } + // 1. Remove obsolete program entries from persisted table to limit its size + // 2. Update persisted program cost. 
This involves EMA cost calculation at + // execute_cost_table.get_cost() fn persist_cost_table(blockstore: &Blockstore, cost_model: &RwLock) { - let cost_model_read = cost_model.read().unwrap(); - let cost_table = cost_model_read.get_instruction_cost_table(); let db_records = blockstore.read_program_costs().expect("read programs"); + let cost_model = cost_model.read().unwrap(); + let active_program_keys = cost_model.get_program_keys(); // delete records from blockstore if they are no longer in cost_table db_records.iter().for_each(|(pubkey, _)| { - if cost_table.get(pubkey).is_none() { + if !active_program_keys.contains(&pubkey) { blockstore .delete_program_cost(pubkey) .expect("delete old program"); } }); - for (key, cost) in cost_table.iter() { + active_program_keys.iter().for_each(|program_id| { + let cost = cost_model.find_instruction_cost(program_id); blockstore - .write_program_cost(key, cost) + .write_program_cost(program_id, &cost) .expect("persist program costs to blockstore"); - } + }); } } @@ -229,15 +212,9 @@ mod tests { fn test_update_cost_model_with_empty_execute_timings() { let cost_model = Arc::new(RwLock::new(CostModel::default())); let mut empty_execute_timings = ExecuteTimings::default(); - CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings); - assert_eq!( - 0, - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .len() + CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings), + 0 ); } @@ -255,7 +232,7 @@ mod tests { let accumulated_units: u64 = 100; let total_errored_units = 0; let count: u32 = 10; - expected_cost = accumulated_units / count as u64; + expected_cost = accumulated_units / count as u64; // = 10 execute_timings.details.per_program_timings.insert( program_key_1, @@ -267,22 +244,15 @@ mod tests { total_errored_units, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + let update_count = + CostUpdateService::update_cost_model(&cost_model, 
&mut execute_timings); + assert_eq!(1, update_count); assert_eq!( - 1, + expected_cost, cost_model .read() .unwrap() - .get_instruction_cost_table() - .len() - ); - assert_eq!( - Some(&expected_cost), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + .find_instruction_cost(&program_key_1) ); } @@ -291,8 +261,8 @@ mod tests { let accumulated_us: u64 = 2000; let accumulated_units: u64 = 200; let count: u32 = 10; - // to expect new cost is Average(new_value, existing_value) - expected_cost = ((accumulated_units / count as u64) + expected_cost) / 2; + // to expect new cost = (mean + 2 * std) of [10, 20] + expected_cost = 13; execute_timings.details.per_program_timings.insert( program_key_1, @@ -304,22 +274,15 @@ mod tests { total_errored_units: 0, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + let update_count = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!(1, update_count); assert_eq!( - 1, + expected_cost, cost_model .read() .unwrap() - .get_instruction_cost_table() - .len() - ); - assert_eq!( - Some(&expected_cost), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + .find_instruction_cost(&program_key_1) ); } } @@ -343,20 +306,49 @@ mod tests { total_errored_units: 0, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); // If both the `errored_txs_compute_consumed` is empty and `count == 0`, then // nothing should be inserted into the cost model - assert!(cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .is_empty()); + assert_eq!( + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings), + 0 + ); + } + + // set up current instruction cost to 100 + let current_program_cost = 100; + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: current_program_cost, + count: 
1, + errored_txs_compute_consumed: vec![], + total_errored_units: 0, + }, + ); + let update_count = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!(1, update_count); + assert_eq!( + current_program_cost, + cost_model + .read() + .unwrap() + .find_instruction_cost(&program_key_1) + ); } // Test updating cost model with only erroring compute costs where the `cost_per_error` is // greater than the current instruction cost for the program. Should update with the // new erroring compute costs let cost_per_error = 1000; + // expected_cost = (mean + 2*std) of data points: + // [ + // 100, // original program_cost + // 1000, // cost_per_error + // ] + let expected_cost = 289u64; { let errored_txs_compute_consumed = vec![cost_per_error; 3]; let total_errored_units = errored_txs_compute_consumed.iter().sum(); @@ -370,29 +362,23 @@ mod tests { total_errored_units, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + let update_count = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + + assert_eq!(1, update_count); assert_eq!( - 1, + expected_cost, cost_model .read() .unwrap() - .get_instruction_cost_table() - .len() - ); - assert_eq!( - Some(&cost_per_error), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + .find_instruction_cost(&program_key_1) ); } // Test updating cost model with only erroring compute costs where the error cost is // `smaller_cost_per_error`, less than the current instruction cost for the program. 
// The cost should not decrease for these new lesser errors - let smaller_cost_per_error = cost_per_error - 10; + let smaller_cost_per_error = expected_cost - 10; { let errored_txs_compute_consumed = vec![smaller_cost_per_error; 3]; let total_errored_units = errored_txs_compute_consumed.iter().sum(); @@ -406,22 +392,23 @@ mod tests { total_errored_units, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + let update_count = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + + // expected_cost = (mean + 2*std) of data points: + // [ + // 100, // original program cost, + // 1000, // cost_per_error from above test + // 289, // the smaller_cost_per_error will be coalesced to prev cost + // ] + let expected_cost = 293u64; + assert_eq!(1, update_count); assert_eq!( - 1, + expected_cost, cost_model .read() .unwrap() - .get_instruction_cost_table() - .len() - ); - assert_eq!( - Some(&cost_per_error), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + .find_instruction_cost(&program_key_1) ); } } diff --git a/core/src/tvu.rs b/core/src/tvu.rs index e161231e01..4203496ed5 100644 --- a/core/src/tvu.rs +++ b/core/src/tvu.rs @@ -309,12 +309,8 @@ impl Tvu { ); let (cost_update_sender, cost_update_receiver) = channel(); - let cost_update_service = CostUpdateService::new( - exit.clone(), - blockstore.clone(), - cost_model.clone(), - cost_update_receiver, - ); + let cost_update_service = + CostUpdateService::new(blockstore.clone(), cost_model.clone(), cost_update_receiver); let (drop_bank_sender, drop_bank_receiver) = channel(); diff --git a/runtime/src/cost_model.rs b/runtime/src/cost_model.rs index 6891e1e868..7625123256 100644 --- a/runtime/src/cost_model.rs +++ b/runtime/src/cost_model.rs @@ -8,7 +8,6 @@ use { crate::{block_cost_limits::*, execute_cost_table::ExecuteCostTable}, log::*, solana_sdk::{pubkey::Pubkey, transaction::SanitizedTransaction}, - std::collections::HashMap, }; 
const MAX_WRITABLE_ACCOUNTS: usize = 256; @@ -74,28 +73,9 @@ impl CostModel { .map(|(key, cost)| (key, cost)) .chain(BUILT_IN_INSTRUCTION_COSTS.iter()) .for_each(|(program_id, cost)| { - match self - .instruction_execution_cost_table - .upsert(program_id, *cost) - { - Some(c) => { - debug!( - "initiating cost table, instruction {:?} has cost {}", - program_id, c - ); - } - None => { - debug!( - "initiating cost table, failed for instruction {:?}", - program_id - ); - } - } + self.instruction_execution_cost_table + .upsert(program_id, *cost); }); - debug!( - "restored cost model instruction cost table from blockstore, current values: {:?}", - self.get_instruction_cost_table() - ); } pub fn calculate_cost(&self, transaction: &SanitizedTransaction) -> TransactionCost { @@ -110,30 +90,20 @@ impl CostModel { tx_cost } - pub fn upsert_instruction_cost( - &mut self, - program_key: &Pubkey, - cost: u64, - ) -> Result { + // update-or-insert op is always successful. However the result of upsert, eg the aggregated + // value, requires additional calculation, which should only be invoked when needed. 
+ pub fn upsert_instruction_cost(&mut self, program_key: &Pubkey, cost: u64) { self.instruction_execution_cost_table .upsert(program_key, cost); - match self.instruction_execution_cost_table.get_cost(program_key) { - Some(cost) => Ok(*cost), - None => Err("failed to upsert to ExecuteCostTable"), - } - } - - pub fn get_instruction_cost_table(&self) -> &HashMap { - self.instruction_execution_cost_table.get_cost_table() } pub fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 { match self.instruction_execution_cost_table.get_cost(program_key) { - Some(cost) => *cost, + Some(cost) => cost, None => { - let default_value = self.instruction_execution_cost_table.get_mode(); + let default_value = self.instruction_execution_cost_table.get_default(); debug!( - "Program key {:?} does not have assigned cost, using mode {}", + "instruction {:?} does not have aggregated cost, using default {}", program_key, default_value ); default_value @@ -141,6 +111,10 @@ impl CostModel { } } + pub fn get_program_keys(&self) -> Vec<&Pubkey> { + self.instruction_execution_cost_table.get_program_keys() + } + fn get_signature_cost(&self, transaction: &SanitizedTransaction) -> u64 { transaction.signatures().len() as u64 * SIGNATURE_COST } @@ -207,6 +181,7 @@ mod tests { transaction::Transaction, }, std::{ + collections::HashMap, str::FromStr, sync::{Arc, RwLock}, thread::{self, JoinHandle}, @@ -230,24 +205,51 @@ mod tests { let mut testee = CostModel::default(); let known_key = Pubkey::from_str("known11111111111111111111111111111111111111").unwrap(); - testee.upsert_instruction_cost(&known_key, 100).unwrap(); + testee.upsert_instruction_cost(&known_key, 100); // find cost for known programs assert_eq!(100, testee.find_instruction_cost(&known_key)); - testee - .upsert_instruction_cost(&bpf_loader::id(), 1999) - .unwrap(); + testee.upsert_instruction_cost(&bpf_loader::id(), 1999); assert_eq!(1999, testee.find_instruction_cost(&bpf_loader::id())); // unknown program is assigned with default 
cost assert_eq!( - testee.instruction_execution_cost_table.get_mode(), + testee.instruction_execution_cost_table.get_default(), testee.find_instruction_cost( &Pubkey::from_str("unknown111111111111111111111111111111111111").unwrap() ) ); } + #[test] + fn test_iterating_instruction_cost_by_program_keys() { + solana_logger::setup(); + let mut testee = CostModel::default(); + + let mut test_key_and_cost = HashMap::::new(); + (0u64..10u64).for_each(|n| { + test_key_and_cost.insert(Pubkey::new_unique(), n); + }); + + test_key_and_cost.iter().for_each(|(key, cost)| { + testee.upsert_instruction_cost(key, *cost); + info!("key {:?} cost {}", key, cost); + }); + + let keys = testee.get_program_keys(); + // verify each key has pre-set value + keys.iter().for_each(|key| { + let expected_cost = test_key_and_cost.get(key).unwrap(); + info!( + "check key {:?} expect {} find {}", + key, + expected_cost, + testee.find_instruction_cost(key) + ); + assert_eq!(*expected_cost, testee.find_instruction_cost(key)); + }); + } + #[test] fn test_cost_model_simple_transaction() { let (mint_keypair, start_hash) = test_setup(); @@ -265,9 +267,7 @@ mod tests { let expected_cost = 8; let mut testee = CostModel::default(); - testee - .upsert_instruction_cost(&system_program::id(), expected_cost) - .unwrap(); + testee.upsert_instruction_cost(&system_program::id(), expected_cost); assert_eq!( expected_cost, testee.get_transaction_cost(&simple_transaction) @@ -295,9 +295,7 @@ mod tests { let expected_cost = program_cost * 2; let mut testee = CostModel::default(); - testee - .upsert_instruction_cost(&system_program::id(), program_cost) - .unwrap(); + testee.upsert_instruction_cost(&system_program::id(), program_cost); assert_eq!(expected_cost, testee.get_transaction_cost(&tx)); } @@ -329,7 +327,7 @@ mod tests { let result = testee.get_transaction_cost(&tx); // expected cost for two random/unknown program is - let expected_cost = testee.instruction_execution_cost_table.get_mode() * 2; + let 
expected_cost = testee.instruction_execution_cost_table.get_default() * 2; assert_eq!(expected_cost, result); } @@ -373,12 +371,12 @@ mod tests { let mut cost_model = CostModel::default(); // Using default cost for unknown instruction assert_eq!( - cost_model.instruction_execution_cost_table.get_mode(), + cost_model.instruction_execution_cost_table.get_default(), cost_model.find_instruction_cost(&key1) ); // insert instruction cost to table - assert!(cost_model.upsert_instruction_cost(&key1, cost1).is_ok()); + cost_model.upsert_instruction_cost(&key1, cost1); // now it is known insturction with known cost assert_eq!(cost1, cost_model.find_instruction_cost(&key1)); @@ -398,9 +396,7 @@ mod tests { let expected_execution_cost = 8; let mut cost_model = CostModel::default(); - cost_model - .upsert_instruction_cost(&system_program::id(), expected_execution_cost) - .unwrap(); + cost_model.upsert_instruction_cost(&system_program::id(), expected_execution_cost); let tx_cost = cost_model.calculate_cost(&tx); assert_eq!(expected_account_cost, tx_cost.write_lock_cost); assert_eq!(expected_execution_cost, tx_cost.execution_cost); @@ -412,16 +408,17 @@ mod tests { let key1 = Pubkey::new_unique(); let cost1 = 100; let cost2 = 200; - let updated_cost = (cost1 + cost2) / 2; + // updated_cost = (mean + 2*std) of [100, 200] => 120.899 + let updated_cost = 121; let mut cost_model = CostModel::default(); // insert instruction cost to table - assert!(cost_model.upsert_instruction_cost(&key1, cost1).is_ok()); + cost_model.upsert_instruction_cost(&key1, cost1); assert_eq!(cost1, cost_model.find_instruction_cost(&key1)); // update instruction cost - assert!(cost_model.upsert_instruction_cost(&key1, cost2).is_ok()); + cost_model.upsert_instruction_cost(&key1, cost2); assert_eq!(updated_cost, cost_model.find_instruction_cost(&key1)); } @@ -463,8 +460,8 @@ mod tests { if i == 5 { thread::spawn(move || { let mut cost_model = cost_model.write().unwrap(); - 
assert!(cost_model.upsert_instruction_cost(&prog1, cost1).is_ok()); - assert!(cost_model.upsert_instruction_cost(&prog2, cost2).is_ok()); + cost_model.upsert_instruction_cost(&prog1, cost1); + cost_model.upsert_instruction_cost(&prog2, cost2); }) } else { thread::spawn(move || { diff --git a/runtime/src/execute_cost_table.rs b/runtime/src/execute_cost_table.rs index c1ef3449cb..d45bce0ffb 100644 --- a/runtime/src/execute_cost_table.rs +++ b/runtime/src/execute_cost_table.rs @@ -4,7 +4,10 @@ /// When its capacity limit is reached, it prunes old and less-used programs /// to make room for new ones. use log::*; -use {solana_sdk::pubkey::Pubkey, std::collections::HashMap}; +use { + solana_sdk::pubkey::Pubkey, + std::collections::{hash_map::Entry, HashMap}, +}; // prune is rather expensive op, free up bulk space in each operation // would be more efficient. PRUNE_RATIO defines the after prune table @@ -15,10 +18,22 @@ const OCCURRENCES_WEIGHT: i64 = 100; const DEFAULT_CAPACITY: usize = 1024; -#[derive(AbiExample, Debug)] +// The coefficient represents the degree of weighting decrease in EMA, +// a constant smoothing factor between 0 and 1. A higher alpha +// discounts older observations faster. +// Setting it using `2/(N+1)` where N is 200 samples +const COEFFICIENT: f64 = 0.01; + +#[derive(Debug, Default)] +struct AggregatedVarianceStats { + ema: f64, + ema_var: f64, +} + +#[derive(Debug)] pub struct ExecuteCostTable { capacity: usize, - table: HashMap, + table: HashMap, occurrences: HashMap, } @@ -37,55 +52,59 @@ impl ExecuteCostTable { } } - pub fn get_cost_table(&self) -> &HashMap { - &self.table - } - + // number of programs in table pub fn get_count(&self) -> usize { self.table.len() } - // instead of assigning unknown program with a configured/hard-coded cost - // use average or mode function to make a educated guess. 
- pub fn get_average(&self) -> u64 { - if self.table.is_empty() { - 0 - } else { - self.table.iter().map(|(_, value)| value).sum::() / self.get_count() as u64 - } - } - - pub fn get_mode(&self) -> u64 { - if self.occurrences.is_empty() { - 0 - } else { - let key = self - .occurrences - .iter() - .max_by_key(|&(_, count)| count) - .map(|(key, _)| key) - .expect("cannot find mode from cost table"); - - *self.table.get(key).unwrap() - } + // default program cost to max + pub fn get_default(&self) -> u64 { + // default max compute units per program + 200_000u64 } // returns None if program doesn't exist in table. In this case, - // client is advised to call `get_average()` or `get_mode()` to - // assign a 'default' value for new program. - pub fn get_cost(&self, key: &Pubkey) -> Option<&u64> { - self.table.get(key) + // it is advised to call `get_default()` for default program cost. + // Program cost is estimated as 2 standard deviations above mean, eg + // cost = (mean + 2 * std) + pub fn get_cost(&self, key: &Pubkey) -> Option { + let aggregated = self.table.get(key)?; + let cost_f64 = (aggregated.ema + 2.0 * aggregated.ema_var.sqrt()).ceil(); + + // check if cost:f64 can be losslessly converted to u64, otherwise return None + let cost_u64 = cost_f64 as u64; + if cost_f64 == cost_u64 as f64 { + Some(cost_u64) + } else { + None + } } - pub fn upsert(&mut self, key: &Pubkey, value: u64) -> Option { - let need_to_add = self.table.get(key).is_none(); + pub fn upsert(&mut self, key: &Pubkey, value: u64) { + let need_to_add = !self.table.contains_key(key); let current_size = self.get_count(); if current_size == self.capacity && need_to_add { self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize)); } - let program_cost = self.table.entry(*key).or_insert(value); - *program_cost = (*program_cost + value) / 2; + // exponential moving average algorithm + // https://en.wikipedia.org/wiki/Moving_average#Exponentially_weighted_moving_variance_and_standard_deviation + match 
self.table.entry(*key) { + Entry::Occupied(mut entry) => { + let aggregated = entry.get_mut(); + let theta = value as f64 - aggregated.ema; + aggregated.ema += theta * COEFFICIENT; + aggregated.ema_var = + (1.0 - COEFFICIENT) * (aggregated.ema_var + COEFFICIENT * theta * theta); + } + Entry::Vacant(entry) => { + // the starting values + entry.insert(AggregatedVarianceStats { + ema: value as f64, + ema_var: 0.0, + }); + } + } let (count, timestamp) = self .occurrences @@ -93,8 +112,10 @@ impl ExecuteCostTable { .or_insert((0, Self::micros_since_epoch())); *count += 1; *timestamp = Self::micros_since_epoch(); + } - Some(*program_cost) + pub fn get_program_keys(&self) -> Vec<&Pubkey> { + self.table.keys().collect() } // prune the old programs so the table contains `new_size` of records, @@ -184,9 +205,9 @@ mod tests { let key2 = Pubkey::new_unique(); let key3 = Pubkey::new_unique(); - // simulate a lot of occurences to key1, so even there're longer than + // simulate a lot of occurrences to key1, so even there're longer than // usual delay between upsert(key1..) and upsert(key2, ..), test - // would still satisfy as key1 has enough occurences to compensate + // would still satisfy as key1 has enough occurrences to compensate // its age. 
for i in 0..1000 { testee.upsert(&key1, i); @@ -219,25 +240,21 @@ mod tests { // insert one record testee.upsert(&key1, cost1); assert_eq!(1, testee.get_count()); - assert_eq!(cost1, testee.get_average()); - assert_eq!(cost1, testee.get_mode()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); // insert 2nd record testee.upsert(&key2, cost2); assert_eq!(2, testee.get_count()); - assert_eq!((cost1 + cost2) / 2_u64, testee.get_average()); - assert_eq!(cost2, testee.get_mode()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); // update 1st record testee.upsert(&key1, cost2); assert_eq!(2, testee.get_count()); - assert_eq!(((cost1 + cost2) / 2 + cost2) / 2, testee.get_average()); - assert_eq!((cost1 + cost2) / 2, testee.get_mode()); - assert_eq!(&((cost1 + cost2) / 2), testee.get_cost(&key1).unwrap()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); + // expected key1 cost = (mean + 2*std) of [100, 110] with alpha=0.01 => 103 + let expected_cost = 103; + assert_eq!(expected_cost, testee.get_cost(&key1).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); } #[test] @@ -258,33 +275,50 @@ mod tests { // insert one record testee.upsert(&key1, cost1); assert_eq!(1, testee.get_count()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); // insert 2nd record testee.upsert(&key2, cost2); assert_eq!(2, testee.get_count()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); // insert 3rd record, pushes out the oldest (eg 1st) record testee.upsert(&key3, cost3); assert_eq!(2, testee.get_count()); - assert_eq!((cost2 + cost3) 
/ 2_u64, testee.get_average()); - assert_eq!(cost3, testee.get_mode()); assert!(testee.get_cost(&key1).is_none()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); - assert_eq!(&cost3, testee.get_cost(&key3).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); + assert_eq!(cost3, testee.get_cost(&key3).unwrap()); // update 2nd record, so the 3rd becomes the oldest // add 4th record, pushes out 3rd key testee.upsert(&key2, cost1); testee.upsert(&key4, cost4); - assert_eq!(((cost1 + cost2) / 2 + cost4) / 2_u64, testee.get_average()); - assert_eq!((cost1 + cost2) / 2, testee.get_mode()); assert_eq!(2, testee.get_count()); assert!(testee.get_cost(&key1).is_none()); - assert_eq!(&((cost1 + cost2) / 2), testee.get_cost(&key2).unwrap()); + // expected key2 cost = (mean + 2*std) of [110, 100] => 112 + let expected_cost_2 = 112; + assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap()); assert!(testee.get_cost(&key3).is_none()); - assert_eq!(&cost4, testee.get_cost(&key4).unwrap()); + assert_eq!(cost4, testee.get_cost(&key4).unwrap()); + } + + #[test] + fn test_get_cost_overflow_u64() { + solana_logger::setup(); + let mut testee = ExecuteCostTable::default(); + + let key1 = Pubkey::new_unique(); + let cost1: u64 = f64::MAX as u64; + let cost2: u64 = u64::MAX / 2; // create large variance so the final result will overflow + + // insert one record + testee.upsert(&key1, cost1); + assert_eq!(1, testee.get_count()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); + + // update cost + testee.upsert(&key1, cost2); + assert!(testee.get_cost(&key1).is_none()); } }