diff --git a/core/src/cost_update_service.rs b/core/src/cost_update_service.rs index 662348b949..900cce52cb 100644 --- a/core/src/cost_update_service.rs +++ b/core/src/cost_update_service.rs @@ -9,8 +9,9 @@ use { solana_measure::measure::Measure, solana_program_runtime::timings::ExecuteTimings, solana_runtime::{bank::Bank, cost_model::CostModel}, - solana_sdk::timing::timestamp, + solana_sdk::{pubkey::Pubkey, timing::timestamp}, std::{ + collections::HashMap, sync::{ atomic::{AtomicBool, Ordering}, Arc, RwLock, @@ -112,8 +113,8 @@ impl CostUpdateService { cost_update_receiver: CostUpdateReceiver, ) { let mut cost_update_service_timing = CostUpdateServiceTiming::default(); - let mut dirty: bool; let mut update_count: u64; + let mut updated_program_costs = HashMap::::new(); let wait_timer = Duration::from_millis(100); loop { @@ -121,7 +122,6 @@ impl CostUpdateService { break; } - dirty = false; update_count = 0_u64; let mut update_cost_model_time = Measure::start("update_cost_model_time"); for cost_update in cost_update_receiver.try_iter() { @@ -132,7 +132,8 @@ impl CostUpdateService { CostUpdate::ExecuteTiming { mut execute_timings, } => { - dirty |= Self::update_cost_model(&cost_model, &mut execute_timings); + updated_program_costs = + Self::update_cost_model(&cost_model, &mut execute_timings); update_count += 1; } } @@ -140,9 +141,7 @@ impl CostUpdateService { update_cost_model_time.stop(); let mut persist_cost_table_time = Measure::start("persist_cost_table_time"); - if dirty { - Self::persist_cost_table(&blockstore, &cost_model); - } + Self::persist_cost_table(&blockstore, &updated_program_costs); persist_cost_table_time.stop(); cost_update_service_timing.update( @@ -158,62 +157,58 @@ impl CostUpdateService { fn update_cost_model( cost_model: &RwLock, execute_timings: &mut ExecuteTimings, - ) -> bool { - let mut dirty = false; - { - for (program_id, program_timings) in &mut execute_timings.details.per_program_timings { - let current_estimated_program_cost = - cost_model.read().unwrap().find_instruction_cost(program_id); - program_timings.coalesce_error_timings(current_estimated_program_cost); + ) -> HashMap { + let mut updated_program_costs = HashMap::::new(); + for (program_id, program_timings) in &mut execute_timings.details.per_program_timings { + let current_estimated_program_cost = + cost_model.read().unwrap().find_instruction_cost(program_id); + program_timings.coalesce_error_timings(current_estimated_program_cost); - if program_timings.count < 1 { - continue; + if program_timings.count < 1 { + continue; + } + + let units = program_timings.accumulated_units / program_timings.count as u64; + match cost_model + .write() + .unwrap() + .upsert_instruction_cost(program_id, units) + { + Ok(cost) => { + debug!( + "after replayed into bank, instruction {:?} has averaged cost {}", + program_id, cost + ); + updated_program_costs.insert(*program_id, cost); } - - let units = program_timings.accumulated_units / program_timings.count as u64; - match cost_model - .write() - .unwrap() - .upsert_instruction_cost(program_id, units) - { - Ok(c) => { - debug!( - "after replayed into bank, instruction {:?} has averaged cost {}", - program_id, c - ); - dirty = true; - } - Err(err) => { - debug!( + Err(err) => { + debug!( "after replayed into bank, instruction {:?} failed to update cost, err: {}", program_id, err ); - } } } } - debug!( - "after replayed into bank, updated cost model instruction cost table, current values: {:?}", - cost_model.read().unwrap().get_instruction_cost_table() - ); - dirty + updated_program_costs } - fn persist_cost_table(blockstore: &Blockstore, cost_model: &RwLock) { - let cost_model_read = cost_model.read().unwrap(); - let cost_table = cost_model_read.get_instruction_cost_table(); + fn persist_cost_table(blockstore: &Blockstore, updated_program_costs: &HashMap) { + if updated_program_costs.is_empty() { + return; + } + let db_records = blockstore.read_program_costs().expect("read programs"); // delete records from blockstore if they are no longer in cost_table db_records.iter().for_each(|(pubkey, _)| { - if cost_table.get(pubkey).is_none() { + if !updated_program_costs.contains_key(pubkey) { blockstore .delete_program_cost(pubkey) .expect("delete old program"); } }); - for (key, cost) in cost_table.iter() { + for (key, cost) in updated_program_costs.iter() { blockstore .write_program_cost(key, cost) .expect("persist program costs to blockstore"); @@ -229,15 +224,9 @@ mod tests { fn test_update_cost_model_with_empty_execute_timings() { let cost_model = Arc::new(RwLock::new(CostModel::default())); let mut empty_execute_timings = ExecuteTimings::default(); - CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings); - - assert_eq!( - 0, - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .len() + assert!( + CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings) + .is_empty() ); } @@ -255,7 +244,7 @@ mod tests { let accumulated_units: u64 = 100; let total_errored_units = 0; let count: u32 = 10; - expected_cost = accumulated_units / count as u64; + expected_cost = accumulated_units / count as u64; // = 10 execute_timings.details.per_program_timings.insert( program_key_1, @@ -267,22 +256,12 @@ mod tests { total_errored_units, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); - assert_eq!( - 1, - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .len() - ); + let updated_program_costs = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!(1, updated_program_costs.len()); assert_eq!( Some(&expected_cost), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + updated_program_costs.get(&program_key_1) ); } @@ -291,8 +270,9 @@ mod tests { let accumulated_us: u64 = 2000; let accumulated_units: u64 = 200; let count: u32 = 10; - // to expect new cost is Average(new_value, existing_value) - expected_cost = ((accumulated_units / count as u64) + expected_cost) / 2; + // to expect new cost = (mean + 2 * std) of [10, 20] = 25, where + // mean = (10+20)/2 = 15; std=5 + expected_cost = 25; execute_timings.details.per_program_timings.insert( program_key_1, @@ -304,22 +284,12 @@ mod tests { total_errored_units: 0, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); - assert_eq!( - 1, - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .len() - ); + let updated_program_costs = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!(1, updated_program_costs.len()); assert_eq!( Some(&expected_cost), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + updated_program_costs.get(&program_key_1) ); } } @@ -343,14 +313,33 @@ mod tests { total_errored_units: 0, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); // If both the `errored_txs_compute_consumed` is empty and `count == 0`, then // nothing should be inserted into the cost model - assert!(cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .is_empty()); + assert!( + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings).is_empty() + ); + } + + // set up current instruction cost to 100 + let current_program_cost = 100; + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: current_program_cost, + count: 1, + errored_txs_compute_consumed: vec![], + total_errored_units: 0, + }, + ); + let updated_program_costs = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!(1, updated_program_costs.len()); + assert_eq!( + Some(¤t_program_cost), + updated_program_costs.get(&program_key_1) + ); } // Test updating cost model with only erroring compute costs where the `cost_per_error` is @@ -370,22 +359,19 @@ mod tests { total_errored_units, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + let updated_program_costs = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + + // expected_cost = (mean + 2*std) of data points: + // [ + // 100, // original program_cost + // 1000, // cost_per_error + // ] + let expected_cost = 1450u64; + assert_eq!(1, updated_program_costs.len()); assert_eq!( - 1, - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .len() - ); - assert_eq!( - Some(&cost_per_error), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + Some(&expected_cost), + updated_program_costs.get(&program_key_1) ); } @@ -406,22 +392,20 @@ mod tests { total_errored_units, }, ); - CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + let updated_program_costs = + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + + // expected_cost = (mean = 2*std) of data points: + // [ + // 100, // original program cost, + // 1000, // cost_per_error from above test + // 1450, // the smaller_cost_per_error will be coalesced to prev cost + // ] + let expected_cost = 1973u64; + assert_eq!(1, updated_program_costs.len()); assert_eq!( - 1, - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .len() - ); - assert_eq!( - Some(&cost_per_error), - cost_model - .read() - .unwrap() - .get_instruction_cost_table() - .get(&program_key_1) + Some(&expected_cost), + updated_program_costs.get(&program_key_1) ); } } diff --git a/runtime/src/cost_model.rs b/runtime/src/cost_model.rs index cd04a22732..b500480294 100644 --- a/runtime/src/cost_model.rs +++ b/runtime/src/cost_model.rs @@ -11,7 +11,6 @@ use { instruction::CompiledInstruction, program_utils::limited_deserialize, pubkey::Pubkey, system_instruction::SystemInstruction, system_program, transaction::SanitizedTransaction, }, - std::collections::HashMap, }; const MAX_WRITABLE_ACCOUNTS: usize = 256; @@ -79,28 +78,9 @@ impl CostModel { .map(|(key, cost)| (key, cost)) .chain(BUILT_IN_INSTRUCTION_COSTS.iter()) .for_each(|(program_id, cost)| { - match self - .instruction_execution_cost_table - .upsert(program_id, *cost) - { - Some(c) => { - debug!( - "initiating cost table, instruction {:?} has cost {}", - program_id, c - ); - } - None => { - debug!( - "initiating cost table, failed for instruction {:?}", - program_id - ); - } - } + self.instruction_execution_cost_table + .upsert(program_id, *cost); }); - debug!( - "restored cost model instruction cost table from blockstore, current values: {:?}", - self.get_instruction_cost_table() - ); } pub fn calculate_cost(&self, transaction: &SanitizedTransaction) -> TransactionCost { @@ -124,22 +104,18 @@ impl CostModel { self.instruction_execution_cost_table .upsert(program_key, cost); match self.instruction_execution_cost_table.get_cost(program_key) { - Some(cost) => Ok(*cost), + Some(cost) => Ok(cost), None => Err("failed to upsert to ExecuteCostTable"), } } - pub fn get_instruction_cost_table(&self) -> &HashMap { - self.instruction_execution_cost_table.get_cost_table() - } - pub fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 { match self.instruction_execution_cost_table.get_cost(program_key) { - Some(cost) => *cost, + Some(cost) => cost, None => { - let default_value = self.instruction_execution_cost_table.get_mode(); + let default_value = self.instruction_execution_cost_table.get_default(); debug!( - "Program key {:?} does not have assigned cost, using mode {}", + "Program key {:?} does not have assigned cost, using default value {}", program_key, default_value ); default_value @@ -304,7 +280,7 @@ mod tests { // unknown program is assigned with default cost assert_eq!( - testee.instruction_execution_cost_table.get_mode(), + testee.instruction_execution_cost_table.get_default(), testee.find_instruction_cost( &Pubkey::from_str("unknown111111111111111111111111111111111111").unwrap() ) @@ -439,7 +415,7 @@ mod tests { let result = testee.get_transaction_cost(&tx); // expected cost for two random/unknown program is - let expected_cost = testee.instruction_execution_cost_table.get_mode() * 2; + let expected_cost = testee.instruction_execution_cost_table.get_default() * 2; assert_eq!(expected_cost, result); } @@ -483,7 +459,7 @@ mod tests { let mut cost_model = CostModel::default(); // Using default cost for unknown instruction assert_eq!( - cost_model.instruction_execution_cost_table.get_mode(), + cost_model.instruction_execution_cost_table.get_default(), cost_model.find_instruction_cost(&key1) ); @@ -522,7 +498,8 @@ mod tests { let key1 = Pubkey::new_unique(); let cost1 = 100; let cost2 = 200; - let updated_cost = (cost1 + cost2) / 2; + // updated_cost = (mean + 2*std) = 150 + 2 * 50 = 250 + let updated_cost = 250; let mut cost_model = CostModel::default(); diff --git a/runtime/src/execute_cost_table.rs b/runtime/src/execute_cost_table.rs index c1ef3449cb..3b01f0fedb 100644 --- a/runtime/src/execute_cost_table.rs +++ b/runtime/src/execute_cost_table.rs @@ -15,10 +15,17 @@ const OCCURRENCES_WEIGHT: i64 = 100; const DEFAULT_CAPACITY: usize = 1024; -#[derive(AbiExample, Debug)] +#[derive(Debug, Default)] +struct AggregatedVarianceStats { + count: u64, + mean: f64, + squared_mean_distance: f64, +} + +#[derive(Debug)] pub struct ExecuteCostTable { capacity: usize, - table: HashMap, + table: HashMap, occurrences: HashMap, } @@ -37,55 +44,50 @@ impl ExecuteCostTable { } } - pub fn get_cost_table(&self) -> &HashMap { - &self.table - } - + // number of programs in table pub fn get_count(&self) -> usize { self.table.len() } - // instead of assigning unknown program with a configured/hard-coded cost - // use average or mode function to make a educated guess. - pub fn get_average(&self) -> u64 { - if self.table.is_empty() { - 0 - } else { - self.table.iter().map(|(_, value)| value).sum::() / self.get_count() as u64 - } - } - - pub fn get_mode(&self) -> u64 { - if self.occurrences.is_empty() { - 0 - } else { - let key = self - .occurrences - .iter() - .max_by_key(|&(_, count)| count) - .map(|(key, _)| key) - .expect("cannot find mode from cost table"); - - *self.table.get(key).unwrap() - } + // default prorgam cost to max + pub fn get_default(&self) -> u64 { + // default max comoute units per program + 200_000u64 } // returns None if program doesn't exist in table. In this case, - // client is advised to call `get_average()` or `get_mode()` to - // assign a 'default' value for new program. - pub fn get_cost(&self, key: &Pubkey) -> Option<&u64> { - self.table.get(key) + // it is advised to call `get_default()` for default program costdefault/ + // using Welford's Algorithm to calculate mean and std: + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + // Program cost is estimated as 2 standard deviations above mean, eg + // cost = (mean + 2 * std) + pub fn get_cost(&self, key: &Pubkey) -> Option { + let aggregated = self.table.get(key)?; + if aggregated.count < 1 { + None + } else { + let variance = aggregated.squared_mean_distance / aggregated.count as f64; + Some((aggregated.mean + 2.0 * variance.sqrt()).ceil() as u64) + } } - pub fn upsert(&mut self, key: &Pubkey, value: u64) -> Option { - let need_to_add = self.table.get(key).is_none(); + pub fn upsert(&mut self, key: &Pubkey, value: u64) { + let need_to_add = !self.table.contains_key(key); let current_size = self.get_count(); if current_size == self.capacity && need_to_add { self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize)); } - let program_cost = self.table.entry(*key).or_insert(value); - *program_cost = (*program_cost + value) / 2; + // Welford's algorithm + let aggregated = self + .table + .entry(*key) + .or_insert_with(AggregatedVarianceStats::default); + aggregated.count += 1; + let delta = value as f64 - aggregated.mean; + aggregated.mean += delta / aggregated.count as f64; + let delta_2 = value as f64 - aggregated.mean; + aggregated.squared_mean_distance += delta * delta_2; let (count, timestamp) = self .occurrences @@ -93,8 +95,6 @@ impl ExecuteCostTable { .or_insert((0, Self::micros_since_epoch())); *count += 1; *timestamp = Self::micros_since_epoch(); - - Some(*program_cost) } // prune the old programs so the table contains `new_size` of records, @@ -219,25 +219,21 @@ mod tests { // insert one record testee.upsert(&key1, cost1); assert_eq!(1, testee.get_count()); - assert_eq!(cost1, testee.get_average()); - assert_eq!(cost1, testee.get_mode()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); // insert 2nd record testee.upsert(&key2, cost2); assert_eq!(2, testee.get_count()); - assert_eq!((cost1 + cost2) / 2_u64, testee.get_average()); - assert_eq!(cost2, testee.get_mode()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); // update 1st record testee.upsert(&key1, cost2); assert_eq!(2, testee.get_count()); - assert_eq!(((cost1 + cost2) / 2 + cost2) / 2, testee.get_average()); - assert_eq!((cost1 + cost2) / 2, testee.get_mode()); - assert_eq!(&((cost1 + cost2) / 2), testee.get_cost(&key1).unwrap()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); + // expected key1 cost = (mean + 2*std) = (105 + 2*5) = 115 + let expected_cost = 115; + assert_eq!(expected_cost, testee.get_cost(&key1).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); } #[test] @@ -258,33 +254,31 @@ mod tests { // insert one record testee.upsert(&key1, cost1); assert_eq!(1, testee.get_count()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); // insert 2nd record testee.upsert(&key2, cost2); assert_eq!(2, testee.get_count()); - assert_eq!(&cost1, testee.get_cost(&key1).unwrap()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); + assert_eq!(cost1, testee.get_cost(&key1).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); // insert 3rd record, pushes out the oldest (eg 1st) record testee.upsert(&key3, cost3); assert_eq!(2, testee.get_count()); - assert_eq!((cost2 + cost3) / 2_u64, testee.get_average()); - assert_eq!(cost3, testee.get_mode()); assert!(testee.get_cost(&key1).is_none()); - assert_eq!(&cost2, testee.get_cost(&key2).unwrap()); - assert_eq!(&cost3, testee.get_cost(&key3).unwrap()); + assert_eq!(cost2, testee.get_cost(&key2).unwrap()); + assert_eq!(cost3, testee.get_cost(&key3).unwrap()); // update 2nd record, so the 3rd becomes the oldest // add 4th record, pushes out 3rd key testee.upsert(&key2, cost1); testee.upsert(&key4, cost4); - assert_eq!(((cost1 + cost2) / 2 + cost4) / 2_u64, testee.get_average()); - assert_eq!((cost1 + cost2) / 2, testee.get_mode()); assert_eq!(2, testee.get_count()); assert!(testee.get_cost(&key1).is_none()); - assert_eq!(&((cost1 + cost2) / 2), testee.get_cost(&key2).unwrap()); + // expected key2 cost = (mean + 2*std) = (105 + 2*5) = 115 + let expected_cost_2 = 115; + assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap()); assert!(testee.get_cost(&key3).is_none()); - assert_eq!(&cost4, testee.get_cost(&key4).unwrap()); + assert_eq!(cost4, testee.get_cost(&key4).unwrap()); } }