diff --git a/core/src/cost_update_service.rs b/core/src/cost_update_service.rs
index a7f1571aca..900cce52cb 100644
--- a/core/src/cost_update_service.rs
+++ b/core/src/cost_update_service.rs
@@ -270,8 +270,9 @@ mod tests {
             let accumulated_us: u64 = 2000;
             let accumulated_units: u64 = 200;
             let count: u32 = 10;
-            // to expect new cost = (mean + 2 * std)
-            expected_cost = 24;
+            // to expect new cost = (mean + 2 * std) of [10, 20] = 25, where
+            // mean = (10+20)/2 = 15; std=5
+            expected_cost = 25;
 
             execute_timings.details.per_program_timings.insert(
                 program_key_1,
@@ -366,7 +367,7 @@ mod tests {
         // 100, // original program_cost
         // 1000, // cost_per_error
         // ]
-        let expected_cost = 1342u64;
+        let expected_cost = 1450u64;
         assert_eq!(1, updated_program_costs.len());
         assert_eq!(
             Some(&expected_cost),
@@ -400,7 +401,7 @@ mod tests {
         // 1000, // cost_per_error from above test
         // 1450, // the smaller_cost_per_error will be coalesced to prev cost
         // ]
-        let expected_cost = 1915u64;
+        let expected_cost = 1973u64;
         assert_eq!(1, updated_program_costs.len());
         assert_eq!(
             Some(&expected_cost),
diff --git a/runtime/src/cost_model.rs b/runtime/src/cost_model.rs
index 534dae0470..b500480294 100644
--- a/runtime/src/cost_model.rs
+++ b/runtime/src/cost_model.rs
@@ -498,8 +498,8 @@ mod tests {
         let key1 = Pubkey::new_unique();
         let cost1 = 100;
         let cost2 = 200;
-        // updated_cost = (mean + 2*std)
-        let updated_cost = 238;
+        // updated_cost = (mean + 2*std) = 150 + 2 * 50 = 250
+        let updated_cost = 250;
 
         let mut cost_model = CostModel::default();
 
diff --git a/runtime/src/execute_cost_table.rs b/runtime/src/execute_cost_table.rs
index c779164609..3b01f0fedb 100644
--- a/runtime/src/execute_cost_table.rs
+++ b/runtime/src/execute_cost_table.rs
@@ -15,15 +15,11 @@ const OCCURRENCES_WEIGHT: i64 = 100;
 
 const DEFAULT_CAPACITY: usize = 1024;
 
-// The coefficient represents the degree of weighting decrease in EMA,
-// a constant smoothing factor between 0 and 1. A higher alpha
-// discounts older observations faster.
-const COEFFICIENT: f64 = 0.4;
-
 #[derive(Debug, Default)]
 struct AggregatedVarianceStats {
-    ema: f64,
-    ema_var: f64,
+    count: u64,
+    mean: f64,
+    squared_mean_distance: f64,
 }
 
 #[derive(Debug)]
@@ -61,11 +57,18 @@ impl ExecuteCostTable {
 
     // returns None if program doesn't exist in table. In this case,
     // it is advised to call `get_default()` for default program costdefault/
+    // using Welford's Algorithm to calculate mean and std:
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
     // Program cost is estimated as 2 standard deviations above mean, eg
     // cost = (mean + 2 * std)
     pub fn get_cost(&self, key: &Pubkey) -> Option<u64> {
         let aggregated = self.table.get(key)?;
-        Some((aggregated.ema + 2.0 * aggregated.ema_var.sqrt()).ceil() as u64)
+        if aggregated.count < 1 {
+            None
+        } else {
+            let variance = aggregated.squared_mean_distance / aggregated.count as f64;
+            Some((aggregated.mean + 2.0 * variance.sqrt()).ceil() as u64)
+        }
     }
 
     pub fn upsert(&mut self, key: &Pubkey, value: u64) {
@@ -75,24 +78,16 @@ impl ExecuteCostTable {
             self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize));
         }
 
-        // exponential moving average algorithm
-        // https://en.wikipedia.org/wiki/Moving_average#Exponentially_weighted_moving_variance_and_standard_deviation
-        if self.table.contains_key(key) {
-            let aggregated = self.table.get_mut(key).unwrap();
-            let theta = value as f64 - aggregated.ema;
-            aggregated.ema += theta * COEFFICIENT;
-            aggregated.ema_var =
-                (1.0 - COEFFICIENT) * (aggregated.ema_var + COEFFICIENT * theta * theta)
-        } else {
-            // the starting values
-            self.table.insert(
-                *key,
-                AggregatedVarianceStats {
-                    ema: value as f64,
-                    ema_var: 0.0,
-                },
-            );
-        }
+        // Welford's algorithm
+        let aggregated = self
+            .table
+            .entry(*key)
+            .or_insert_with(AggregatedVarianceStats::default);
+        aggregated.count += 1;
+        let delta = value as f64 - aggregated.mean;
+        aggregated.mean += delta / aggregated.count as f64;
+        let delta_2 = value as f64 - aggregated.mean;
+        aggregated.squared_mean_distance += delta * delta_2;
 
         let (count, timestamp) = self
             .occurrences
@@ -236,7 +231,7 @@ mod tests {
         testee.upsert(&key1, cost2);
         assert_eq!(2, testee.get_count());
         // expected key1 cost = (mean + 2*std) = (105 + 2*5) = 115
-        let expected_cost = 114;
+        let expected_cost = 115;
         assert_eq!(expected_cost, testee.get_cost(&key1).unwrap());
         assert_eq!(cost2, testee.get_cost(&key2).unwrap());
     }
@@ -281,7 +276,7 @@ mod tests {
         assert_eq!(2, testee.get_count());
         assert!(testee.get_cost(&key1).is_none());
         // expected key2 cost = (mean + 2*std) = (105 + 2*5) = 115
-        let expected_cost_2 = 116;
+        let expected_cost_2 = 115;
         assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap());
         assert!(testee.get_cost(&key3).is_none());
         assert_eq!(cost4, testee.get_cost(&key4).unwrap());