Revert exponential moving average cost model changes (backport #23541) (#23543)

* Revert "fix tests after merge"

This reverts commit ba2d83f580.

(cherry picked from commit 0a17edcc1f)

* Revert "1. Persist to blockstore less frequently;"

This reverts commit 7aa1fb4e24.

(cherry picked from commit c878c9e2cb)

# Conflicts:
#	core/src/cost_update_service.rs
#	core/src/tvu.rs
#	runtime/src/cost_model.rs

* Revert "use EMA in place of Welford"

This reverts commit 6587dbfa47.

(cherry picked from commit 9acbfa5eb1)

* Revert "- estimate a program cost as 2 standard deviation above mean"

This reverts commit a25ac1c988.

(cherry picked from commit 5a0cd05866)

# Conflicts:
#	core/src/cost_update_service.rs
#	runtime/src/cost_model.rs

* fix merge conflicts

Co-authored-by: Carl Lin <carl@solana.com>
Co-authored-by: Tao Zhu <tao@solana.com>
This commit is contained in:
mergify[bot]
2022-03-09 03:55:36 +00:00
committed by GitHub
parent 714cf0eff2
commit 4e5d9885da
5 changed files with 293 additions and 281 deletions

View File

@ -3102,10 +3102,6 @@ mod tests {
..
} = create_slow_genesis_config(lamports);
let bank = Arc::new(Bank::new_no_wallclock_throttle_for_tests(&genesis_config));
// set cost tracker limits to MAX so it will not filter out TXs
bank.write_cost_tracker()
.unwrap()
.set_limits(std::u64::MAX, std::u64::MAX, std::u64::MAX);
// Transfer more than the balance of the mint keypair, should cause a
// InstructionError::InsufficientFunds that is then committed. Needs to be
@ -3162,10 +3158,6 @@ mod tests {
..
} = create_slow_genesis_config(10_000);
let bank = Arc::new(Bank::new_no_wallclock_throttle_for_tests(&genesis_config));
// set cost tracker limits to MAX so it will not filter out TXs
bank.write_cost_tracker()
.unwrap()
.set_limits(std::u64::MAX, std::u64::MAX, std::u64::MAX);
// Make all repetitive transactions that conflict on the `mint_keypair`, so only 1 should be executed
let mut transactions = vec![

View File

@ -10,14 +10,13 @@ use {
solana_runtime::{bank::Bank, cost_model::CostModel},
solana_sdk::timing::timestamp,
std::{
sync::atomic::{AtomicBool, Ordering},
sync::{mpsc::Receiver, Arc, RwLock},
thread::{self, Builder, JoinHandle},
time::Duration,
},
};
// Update blockstore persistence storage when accumulated cost_table updates count exceeds the threshold
const PERSIST_THRESHOLD: u64 = 1_000;
#[derive(Default)]
pub struct CostUpdateServiceTiming {
last_print: u64,
@ -29,25 +28,20 @@ pub struct CostUpdateServiceTiming {
impl CostUpdateServiceTiming {
fn update(
&mut self,
update_cost_model_count: Option<u64>,
update_cost_model_elapsed: Option<u64>,
persist_cost_table_elapsed: Option<u64>,
update_cost_model_count: u64,
update_cost_model_elapsed: u64,
persist_cost_table_elapsed: u64,
) {
if let Some(update_cost_model_count) = update_cost_model_count {
self.update_cost_model_count += update_cost_model_count;
}
if let Some(update_cost_model_elapsed) = update_cost_model_elapsed {
self.update_cost_model_elapsed += update_cost_model_elapsed;
}
if let Some(persist_cost_table_elapsed) = persist_cost_table_elapsed {
self.persist_cost_table_elapsed += persist_cost_table_elapsed;
}
let now = timestamp();
let elapsed_ms = now - self.last_print;
if elapsed_ms > 1000 {
datapoint_info!(
"cost-update-service-stats",
("total_elapsed_us", elapsed_ms * 1000, i64),
(
"update_cost_model_count",
self.update_cost_model_count as i64,
@ -89,6 +83,7 @@ pub struct CostUpdateService {
impl CostUpdateService {
#[allow(clippy::new_ret_no_self)]
pub fn new(
exit: Arc<AtomicBool>,
blockstore: Arc<Blockstore>,
cost_model: Arc<RwLock<CostModel>>,
cost_update_receiver: CostUpdateReceiver,
@ -96,7 +91,7 @@ impl CostUpdateService {
let thread_hdl = Builder::new()
.name("solana-cost-update-service".to_string())
.spawn(move || {
Self::service_loop(blockstore, cost_model, cost_update_receiver);
Self::service_loop(exit, blockstore, cost_model, cost_update_receiver);
})
.unwrap();
@ -108,14 +103,25 @@ impl CostUpdateService {
}
fn service_loop(
exit: Arc<AtomicBool>,
blockstore: Arc<Blockstore>,
cost_model: Arc<RwLock<CostModel>>,
cost_update_receiver: CostUpdateReceiver,
) {
let mut cost_update_service_timing = CostUpdateServiceTiming::default();
let mut update_count = 0_u64;
let mut dirty: bool;
let mut update_count: u64;
let wait_timer = Duration::from_millis(100);
for cost_update in cost_update_receiver.iter() {
loop {
if exit.load(Ordering::Relaxed) {
break;
}
dirty = false;
update_count = 0_u64;
let mut update_cost_model_time = Measure::start("update_cost_model_time");
for cost_update in cost_update_receiver.try_iter() {
match cost_update {
CostUpdate::FrozenBank { bank } => {
bank.read_cost_tracker().unwrap().report_stats(bank.slot());
@ -123,38 +129,35 @@ impl CostUpdateService {
CostUpdate::ExecuteTiming {
mut execute_timings,
} => {
let mut update_cost_model_time = Measure::start("update_cost_model_time");
update_count += Self::update_cost_model(&cost_model, &mut execute_timings);
dirty |= Self::update_cost_model(&cost_model, &mut execute_timings);
update_count += 1;
}
}
}
update_cost_model_time.stop();
cost_update_service_timing.update(
Some(update_count),
Some(update_cost_model_time.as_us()),
None,
);
if update_count > PERSIST_THRESHOLD {
let mut persist_cost_table_time = Measure::start("persist_cost_table_time");
if dirty {
Self::persist_cost_table(&blockstore, &cost_model);
update_count = 0_u64;
}
persist_cost_table_time.stop();
cost_update_service_timing.update(
None,
None,
Some(persist_cost_table_time.as_us()),
update_count,
update_cost_model_time.as_us(),
persist_cost_table_time.as_us(),
);
}
}
}
thread::sleep(wait_timer);
}
}
// Normalize `program_timings` with current estimated cost, update instruction_cost table
// Returns number of updates applied
fn update_cost_model(
cost_model: &RwLock<CostModel>,
execute_timings: &mut ExecuteTimings,
) -> u64 {
let mut update_count = 0_u64;
) -> bool {
let mut dirty = false;
{
for (program_id, program_timings) in &mut execute_timings.details.per_program_timings {
let current_estimated_program_cost =
cost_model.read().unwrap().find_instruction_cost(program_id);
@ -165,42 +168,53 @@ impl CostUpdateService {
}
let units = program_timings.accumulated_units / program_timings.count as u64;
cost_model
match cost_model
.write()
.unwrap()
.upsert_instruction_cost(program_id, units);
update_count += 1;
.upsert_instruction_cost(program_id, units)
{
Ok(c) => {
debug!(
"After replayed into bank, updated cost for instruction {:?}, update_value {}, pre_aggregated_value {}",
program_id, units, current_estimated_program_cost
"after replayed into bank, instruction {:?} has averaged cost {}",
program_id, c
);
dirty = true;
}
Err(err) => {
debug!(
"after replayed into bank, instruction {:?} failed to update cost, err: {}",
program_id, err
);
}
update_count
}
}
}
debug!(
"after replayed into bank, updated cost model instruction cost table, current values: {:?}",
cost_model.read().unwrap().get_instruction_cost_table()
);
dirty
}
// 1. Remove obsolete program entries from persisted table to limit its size
// 2. Update persisted program cost. This involves EMA cost calculation at
// execute_cost_table.get_cost()
fn persist_cost_table(blockstore: &Blockstore, cost_model: &RwLock<CostModel>) {
let cost_model_read = cost_model.read().unwrap();
let cost_table = cost_model_read.get_instruction_cost_table();
let db_records = blockstore.read_program_costs().expect("read programs");
let cost_model = cost_model.read().unwrap();
let active_program_keys = cost_model.get_program_keys();
// delete records from blockstore if they are no longer in cost_table
db_records.iter().for_each(|(pubkey, _)| {
if !active_program_keys.contains(&pubkey) {
if cost_table.get(pubkey).is_none() {
blockstore
.delete_program_cost(pubkey)
.expect("delete old program");
}
});
active_program_keys.iter().for_each(|program_id| {
let cost = cost_model.find_instruction_cost(program_id);
for (key, cost) in cost_table.iter() {
blockstore
.write_program_cost(program_id, &cost)
.write_program_cost(key, cost)
.expect("persist program costs to blockstore");
});
}
}
}
@ -212,9 +226,15 @@ mod tests {
fn test_update_cost_model_with_empty_execute_timings() {
let cost_model = Arc::new(RwLock::new(CostModel::default()));
let mut empty_execute_timings = ExecuteTimings::default();
CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings);
assert_eq!(
CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings),
0
0,
cost_model
.read()
.unwrap()
.get_instruction_cost_table()
.len()
);
}
@ -232,7 +252,7 @@ mod tests {
let accumulated_units: u64 = 100;
let total_errored_units = 0;
let count: u32 = 10;
expected_cost = accumulated_units / count as u64; // = 10
expected_cost = accumulated_units / count as u64;
execute_timings.details.per_program_timings.insert(
program_key_1,
@ -244,15 +264,22 @@ mod tests {
total_errored_units,
},
);
let update_count =
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings);
assert_eq!(1, update_count);
assert_eq!(
expected_cost,
1,
cost_model
.read()
.unwrap()
.find_instruction_cost(&program_key_1)
.get_instruction_cost_table()
.len()
);
assert_eq!(
Some(&expected_cost),
cost_model
.read()
.unwrap()
.get_instruction_cost_table()
.get(&program_key_1)
);
}
@ -261,8 +288,8 @@ mod tests {
let accumulated_us: u64 = 2000;
let accumulated_units: u64 = 200;
let count: u32 = 10;
// to expect new cost = (mean + 2 * std) of [10, 20]
expected_cost = 13;
// to expect new cost is Average(new_value, existing_value)
expected_cost = ((accumulated_units / count as u64) + expected_cost) / 2;
execute_timings.details.per_program_timings.insert(
program_key_1,
@ -274,15 +301,22 @@ mod tests {
total_errored_units: 0,
},
);
let update_count =
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings);
assert_eq!(1, update_count);
assert_eq!(
expected_cost,
1,
cost_model
.read()
.unwrap()
.find_instruction_cost(&program_key_1)
.get_instruction_cost_table()
.len()
);
assert_eq!(
Some(&expected_cost),
cost_model
.read()
.unwrap()
.get_instruction_cost_table()
.get(&program_key_1)
);
}
}
@ -306,49 +340,20 @@ mod tests {
total_errored_units: 0,
},
);
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings);
// If both the `errored_txs_compute_consumed` is empty and `count == 0`, then
// nothing should be inserted into the cost model
assert_eq!(
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings),
0
);
}
// set up current instruction cost to 100
let current_program_cost = 100;
{
execute_timings.details.per_program_timings.insert(
program_key_1,
ProgramTiming {
accumulated_us: 1000,
accumulated_units: current_program_cost,
count: 1,
errored_txs_compute_consumed: vec![],
total_errored_units: 0,
},
);
let update_count =
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings);
assert_eq!(1, update_count);
assert_eq!(
current_program_cost,
cost_model
assert!(cost_model
.read()
.unwrap()
.find_instruction_cost(&program_key_1)
);
.get_instruction_cost_table()
.is_empty());
}
// Test updating cost model with only erroring compute costs where the `cost_per_error` is
// greater than the current instruction cost for the program. Should update with the
// new erroring compute costs
let cost_per_error = 1000;
// expected_cost = (mean + 2*std) of data points:
// [
// 100, // original program_cost
// 1000, // cost_per_error
// ]
let expected_cost = 289u64;
{
let errored_txs_compute_consumed = vec![cost_per_error; 3];
let total_errored_units = errored_txs_compute_consumed.iter().sum();
@ -362,23 +367,29 @@ mod tests {
total_errored_units,
},
);
let update_count =
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings);
assert_eq!(1, update_count);
assert_eq!(
expected_cost,
1,
cost_model
.read()
.unwrap()
.find_instruction_cost(&program_key_1)
.get_instruction_cost_table()
.len()
);
assert_eq!(
Some(&cost_per_error),
cost_model
.read()
.unwrap()
.get_instruction_cost_table()
.get(&program_key_1)
);
}
// Test updating cost model with only erroring compute costs where the error cost is
// `smaller_cost_per_error`, less than the current instruction cost for the program.
// The cost should not decrease for these new lesser errors
let smaller_cost_per_error = expected_cost - 10;
let smaller_cost_per_error = cost_per_error - 10;
{
let errored_txs_compute_consumed = vec![smaller_cost_per_error; 3];
let total_errored_units = errored_txs_compute_consumed.iter().sum();
@ -392,23 +403,22 @@ mod tests {
total_errored_units,
},
);
let update_count =
CostUpdateService::update_cost_model(&cost_model, &mut execute_timings);
// expected_cost = (mean = 2*std) of data points:
// [
// 100, // original program cost,
// 1000, // cost_per_error from above test
// 289, // the smaller_cost_per_error will be coalesced to prev cost
// ]
let expected_cost = 293u64;
assert_eq!(1, update_count);
assert_eq!(
expected_cost,
1,
cost_model
.read()
.unwrap()
.find_instruction_cost(&program_key_1)
.get_instruction_cost_table()
.len()
);
assert_eq!(
Some(&cost_per_error),
cost_model
.read()
.unwrap()
.get_instruction_cost_table()
.get(&program_key_1)
);
}
}

View File

@ -311,8 +311,12 @@ impl Tvu {
);
let (cost_update_sender, cost_update_receiver) = channel();
let cost_update_service =
CostUpdateService::new(blockstore.clone(), cost_model.clone(), cost_update_receiver);
let cost_update_service = CostUpdateService::new(
exit.clone(),
blockstore.clone(),
cost_model.clone(),
cost_update_receiver,
);
let (drop_bank_sender, drop_bank_receiver) = channel();

View File

@ -8,6 +8,7 @@ use {
crate::{block_cost_limits::*, execute_cost_table::ExecuteCostTable},
log::*,
solana_sdk::{pubkey::Pubkey, transaction::SanitizedTransaction},
std::collections::HashMap,
};
const MAX_WRITABLE_ACCOUNTS: usize = 256;
@ -73,9 +74,28 @@ impl CostModel {
.map(|(key, cost)| (key, cost))
.chain(BUILT_IN_INSTRUCTION_COSTS.iter())
.for_each(|(program_id, cost)| {
self.instruction_execution_cost_table
.upsert(program_id, *cost);
match self
.instruction_execution_cost_table
.upsert(program_id, *cost)
{
Some(c) => {
debug!(
"initiating cost table, instruction {:?} has cost {}",
program_id, c
);
}
None => {
debug!(
"initiating cost table, failed for instruction {:?}",
program_id
);
}
}
});
debug!(
"restored cost model instruction cost table from blockstore, current values: {:?}",
self.get_instruction_cost_table()
);
}
pub fn calculate_cost(&self, transaction: &SanitizedTransaction) -> TransactionCost {
@ -90,20 +110,30 @@ impl CostModel {
tx_cost
}
// update-or-insert op is always successful. However the result of upsert, eg the aggregated
// value, requires additional calculation, which should only be envoked when needed.
pub fn upsert_instruction_cost(&mut self, program_key: &Pubkey, cost: u64) {
pub fn upsert_instruction_cost(
&mut self,
program_key: &Pubkey,
cost: u64,
) -> Result<u64, &'static str> {
self.instruction_execution_cost_table
.upsert(program_key, cost);
match self.instruction_execution_cost_table.get_cost(program_key) {
Some(cost) => Ok(*cost),
None => Err("failed to upsert to ExecuteCostTable"),
}
}
pub fn get_instruction_cost_table(&self) -> &HashMap<Pubkey, u64> {
self.instruction_execution_cost_table.get_cost_table()
}
pub fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 {
match self.instruction_execution_cost_table.get_cost(program_key) {
Some(cost) => cost,
Some(cost) => *cost,
None => {
let default_value = self.instruction_execution_cost_table.get_default();
let default_value = self.instruction_execution_cost_table.get_mode();
debug!(
"instruction {:?} does not have aggregated cost, using default {}",
"Program key {:?} does not have assigned cost, using mode {}",
program_key, default_value
);
default_value
@ -181,7 +211,6 @@ mod tests {
transaction::Transaction,
},
std::{
collections::HashMap,
str::FromStr,
sync::{Arc, RwLock},
thread::{self, JoinHandle},
@ -205,16 +234,18 @@ mod tests {
let mut testee = CostModel::default();
let known_key = Pubkey::from_str("known11111111111111111111111111111111111111").unwrap();
testee.upsert_instruction_cost(&known_key, 100);
testee.upsert_instruction_cost(&known_key, 100).unwrap();
// find cost for known programs
assert_eq!(100, testee.find_instruction_cost(&known_key));
testee.upsert_instruction_cost(&bpf_loader::id(), 1999);
testee
.upsert_instruction_cost(&bpf_loader::id(), 1999)
.unwrap();
assert_eq!(1999, testee.find_instruction_cost(&bpf_loader::id()));
// unknown program is assigned with default cost
assert_eq!(
testee.instruction_execution_cost_table.get_default(),
testee.instruction_execution_cost_table.get_mode(),
testee.find_instruction_cost(
&Pubkey::from_str("unknown111111111111111111111111111111111111").unwrap()
)
@ -232,7 +263,7 @@ mod tests {
});
test_key_and_cost.iter().for_each(|(key, cost)| {
testee.upsert_instruction_cost(key, *cost);
let _ = testee.upsert_instruction_cost(key, *cost).unwrap();
info!("key {:?} cost {}", key, cost);
});
@ -267,7 +298,9 @@ mod tests {
let expected_cost = 8;
let mut testee = CostModel::default();
testee.upsert_instruction_cost(&system_program::id(), expected_cost);
testee
.upsert_instruction_cost(&system_program::id(), expected_cost)
.unwrap();
assert_eq!(
expected_cost,
testee.get_transaction_cost(&simple_transaction)
@ -295,7 +328,9 @@ mod tests {
let expected_cost = program_cost * 2;
let mut testee = CostModel::default();
testee.upsert_instruction_cost(&system_program::id(), program_cost);
testee
.upsert_instruction_cost(&system_program::id(), program_cost)
.unwrap();
assert_eq!(expected_cost, testee.get_transaction_cost(&tx));
}
@ -327,7 +362,7 @@ mod tests {
let result = testee.get_transaction_cost(&tx);
// expected cost for two random/unknown program is
let expected_cost = testee.instruction_execution_cost_table.get_default() * 2;
let expected_cost = testee.instruction_execution_cost_table.get_mode() * 2;
assert_eq!(expected_cost, result);
}
@ -371,12 +406,12 @@ mod tests {
let mut cost_model = CostModel::default();
// Using default cost for unknown instruction
assert_eq!(
cost_model.instruction_execution_cost_table.get_default(),
cost_model.instruction_execution_cost_table.get_mode(),
cost_model.find_instruction_cost(&key1)
);
// insert instruction cost to table
cost_model.upsert_instruction_cost(&key1, cost1);
assert!(cost_model.upsert_instruction_cost(&key1, cost1).is_ok());
// now it is known insturction with known cost
assert_eq!(cost1, cost_model.find_instruction_cost(&key1));
@ -396,7 +431,9 @@ mod tests {
let expected_execution_cost = 8;
let mut cost_model = CostModel::default();
cost_model.upsert_instruction_cost(&system_program::id(), expected_execution_cost);
cost_model
.upsert_instruction_cost(&system_program::id(), expected_execution_cost)
.unwrap();
let tx_cost = cost_model.calculate_cost(&tx);
assert_eq!(expected_account_cost, tx_cost.write_lock_cost);
assert_eq!(expected_execution_cost, tx_cost.execution_cost);
@ -408,17 +445,16 @@ mod tests {
let key1 = Pubkey::new_unique();
let cost1 = 100;
let cost2 = 200;
// updated_cost = (mean + 2*std) of [100, 200] => 120.899
let updated_cost = 121;
let updated_cost = (cost1 + cost2) / 2;
let mut cost_model = CostModel::default();
// insert instruction cost to table
cost_model.upsert_instruction_cost(&key1, cost1);
assert!(cost_model.upsert_instruction_cost(&key1, cost1).is_ok());
assert_eq!(cost1, cost_model.find_instruction_cost(&key1));
// update instruction cost
cost_model.upsert_instruction_cost(&key1, cost2);
assert!(cost_model.upsert_instruction_cost(&key1, cost2).is_ok());
assert_eq!(updated_cost, cost_model.find_instruction_cost(&key1));
}
@ -460,8 +496,8 @@ mod tests {
if i == 5 {
thread::spawn(move || {
let mut cost_model = cost_model.write().unwrap();
cost_model.upsert_instruction_cost(&prog1, cost1);
cost_model.upsert_instruction_cost(&prog2, cost2);
assert!(cost_model.upsert_instruction_cost(&prog1, cost1).is_ok());
assert!(cost_model.upsert_instruction_cost(&prog2, cost2).is_ok());
})
} else {
thread::spawn(move || {

View File

@ -4,10 +4,7 @@
/// When its capacity limit is reached, it prunes old and less-used programs
/// to make room for new ones.
use log::*;
use {
solana_sdk::pubkey::Pubkey,
std::collections::{hash_map::Entry, HashMap},
};
use {solana_sdk::pubkey::Pubkey, std::collections::HashMap};
// prune is rather expensive op, free up bulk space in each operation
// would be more efficient. PRUNE_RATIO defines the after prune table
@ -18,22 +15,10 @@ const OCCURRENCES_WEIGHT: i64 = 100;
const DEFAULT_CAPACITY: usize = 1024;
// The coefficient represents the degree of weighting decrease in EMA,
// a constant smoothing factor between 0 and 1. A higher alpha
// discounts older observations faster.
// Setting it using `2/(N+1)` where N is 200 samples
const COEFFICIENT: f64 = 0.01;
#[derive(Debug, Default)]
struct AggregatedVarianceStats {
ema: f64,
ema_var: f64,
}
#[derive(Debug)]
#[derive(AbiExample, Debug)]
pub struct ExecuteCostTable {
capacity: usize,
table: HashMap<Pubkey, AggregatedVarianceStats>,
table: HashMap<Pubkey, u64>,
occurrences: HashMap<Pubkey, (usize, u128)>,
}
@ -52,59 +37,55 @@ impl ExecuteCostTable {
}
}
// number of programs in table
pub fn get_cost_table(&self) -> &HashMap<Pubkey, u64> {
&self.table
}
pub fn get_count(&self) -> usize {
self.table.len()
}
// default program cost to max
pub fn get_default(&self) -> u64 {
// default max compute units per program
200_000u64
// instead of assigning unknown program with a configured/hard-coded cost
// use average or mode function to make a educated guess.
pub fn get_average(&self) -> u64 {
if self.table.is_empty() {
0
} else {
self.table.iter().map(|(_, value)| value).sum::<u64>() / self.get_count() as u64
}
}
pub fn get_mode(&self) -> u64 {
if self.occurrences.is_empty() {
0
} else {
let key = self
.occurrences
.iter()
.max_by_key(|&(_, count)| count)
.map(|(key, _)| key)
.expect("cannot find mode from cost table");
*self.table.get(key).unwrap()
}
}
// returns None if program doesn't exist in table. In this case,
// it is advised to call `get_default()` for default program cost.
// Program cost is estimated as 2 standard deviations above mean, eg
// cost = (mean + 2 * std)
pub fn get_cost(&self, key: &Pubkey) -> Option<u64> {
let aggregated = self.table.get(key)?;
let cost_f64 = (aggregated.ema + 2.0 * aggregated.ema_var.sqrt()).ceil();
// check if cost:f64 can be losslessly convert to u64, otherwise return None
let cost_u64 = cost_f64 as u64;
if cost_f64 == cost_u64 as f64 {
Some(cost_u64)
} else {
None
}
// client is advised to call `get_average()` or `get_mode()` to
// assign a 'default' value for new program.
pub fn get_cost(&self, key: &Pubkey) -> Option<&u64> {
self.table.get(key)
}
pub fn upsert(&mut self, key: &Pubkey, value: u64) {
let need_to_add = !self.table.contains_key(key);
pub fn upsert(&mut self, key: &Pubkey, value: u64) -> Option<u64> {
let need_to_add = self.table.get(key).is_none();
let current_size = self.get_count();
if current_size == self.capacity && need_to_add {
self.prune_to(&((current_size as f64 * PRUNE_RATIO) as usize));
}
// exponential moving average algorithm
// https://en.wikipedia.org/wiki/Moving_average#Exponentially_weighted_moving_variance_and_standard_deviation
match self.table.entry(*key) {
Entry::Occupied(mut entry) => {
let aggregated = entry.get_mut();
let theta = value as f64 - aggregated.ema;
aggregated.ema += theta * COEFFICIENT;
aggregated.ema_var =
(1.0 - COEFFICIENT) * (aggregated.ema_var + COEFFICIENT * theta * theta);
}
Entry::Vacant(entry) => {
// the starting values
entry.insert(AggregatedVarianceStats {
ema: value as f64,
ema_var: 0.0,
});
}
}
let program_cost = self.table.entry(*key).or_insert(value);
*program_cost = (*program_cost + value) / 2;
let (count, timestamp) = self
.occurrences
@ -112,6 +93,8 @@ impl ExecuteCostTable {
.or_insert((0, Self::micros_since_epoch()));
*count += 1;
*timestamp = Self::micros_since_epoch();
Some(*program_cost)
}
pub fn get_program_keys(&self) -> Vec<&Pubkey> {
@ -205,9 +188,9 @@ mod tests {
let key2 = Pubkey::new_unique();
let key3 = Pubkey::new_unique();
// simulate a lot of occurrences to key1, so even there're longer than
// simulate a lot of occurences to key1, so even there're longer than
// usual delay between upsert(key1..) and upsert(key2, ..), test
// would still satisfy as key1 has enough occurrences to compensate
// would still satisfy as key1 has enough occurences to compensate
// its age.
for i in 0..1000 {
testee.upsert(&key1, i);
@ -240,21 +223,25 @@ mod tests {
// insert one record
testee.upsert(&key1, cost1);
assert_eq!(1, testee.get_count());
assert_eq!(cost1, testee.get_cost(&key1).unwrap());
assert_eq!(cost1, testee.get_average());
assert_eq!(cost1, testee.get_mode());
assert_eq!(&cost1, testee.get_cost(&key1).unwrap());
// insert 2nd record
testee.upsert(&key2, cost2);
assert_eq!(2, testee.get_count());
assert_eq!(cost1, testee.get_cost(&key1).unwrap());
assert_eq!(cost2, testee.get_cost(&key2).unwrap());
assert_eq!((cost1 + cost2) / 2_u64, testee.get_average());
assert_eq!(cost2, testee.get_mode());
assert_eq!(&cost1, testee.get_cost(&key1).unwrap());
assert_eq!(&cost2, testee.get_cost(&key2).unwrap());
// update 1st record
testee.upsert(&key1, cost2);
assert_eq!(2, testee.get_count());
// expected key1 cost is EMA of [100, 110] with alpha=0.01 => 103
let expected_cost = 103;
assert_eq!(expected_cost, testee.get_cost(&key1).unwrap());
assert_eq!(cost2, testee.get_cost(&key2).unwrap());
assert_eq!(((cost1 + cost2) / 2 + cost2) / 2, testee.get_average());
assert_eq!((cost1 + cost2) / 2, testee.get_mode());
assert_eq!(&((cost1 + cost2) / 2), testee.get_cost(&key1).unwrap());
assert_eq!(&cost2, testee.get_cost(&key2).unwrap());
}
#[test]
@ -275,50 +262,33 @@ mod tests {
// insert one record
testee.upsert(&key1, cost1);
assert_eq!(1, testee.get_count());
assert_eq!(cost1, testee.get_cost(&key1).unwrap());
assert_eq!(&cost1, testee.get_cost(&key1).unwrap());
// insert 2nd record
testee.upsert(&key2, cost2);
assert_eq!(2, testee.get_count());
assert_eq!(cost1, testee.get_cost(&key1).unwrap());
assert_eq!(cost2, testee.get_cost(&key2).unwrap());
assert_eq!(&cost1, testee.get_cost(&key1).unwrap());
assert_eq!(&cost2, testee.get_cost(&key2).unwrap());
// insert 3rd record, pushes out the oldest (eg 1st) record
testee.upsert(&key3, cost3);
assert_eq!(2, testee.get_count());
assert_eq!((cost2 + cost3) / 2_u64, testee.get_average());
assert_eq!(cost3, testee.get_mode());
assert!(testee.get_cost(&key1).is_none());
assert_eq!(cost2, testee.get_cost(&key2).unwrap());
assert_eq!(cost3, testee.get_cost(&key3).unwrap());
assert_eq!(&cost2, testee.get_cost(&key2).unwrap());
assert_eq!(&cost3, testee.get_cost(&key3).unwrap());
// update 2nd record, so the 3rd becomes the oldest
// add 4th record, pushes out 3rd key
testee.upsert(&key2, cost1);
testee.upsert(&key4, cost4);
assert_eq!(((cost1 + cost2) / 2 + cost4) / 2_u64, testee.get_average());
assert_eq!((cost1 + cost2) / 2, testee.get_mode());
assert_eq!(2, testee.get_count());
assert!(testee.get_cost(&key1).is_none());
// expected key2 cost = (mean + 2*std) of [110, 100] => 112
let expected_cost_2 = 112;
assert_eq!(expected_cost_2, testee.get_cost(&key2).unwrap());
assert_eq!(&((cost1 + cost2) / 2), testee.get_cost(&key2).unwrap());
assert!(testee.get_cost(&key3).is_none());
assert_eq!(cost4, testee.get_cost(&key4).unwrap());
}
#[test]
fn test_get_cost_overflow_u64() {
solana_logger::setup();
let mut testee = ExecuteCostTable::default();
let key1 = Pubkey::new_unique();
let cost1: u64 = f64::MAX as u64;
let cost2: u64 = u64::MAX / 2; // create large variance so the final result will overflow
// insert one record
testee.upsert(&key1, cost1);
assert_eq!(1, testee.get_count());
assert_eq!(cost1, testee.get_cost(&key1).unwrap());
// update cost
testee.upsert(&key1, cost2);
assert!(testee.get_cost(&key1).is_none());
assert_eq!(&cost4, testee.get_cost(&key4).unwrap());
}
}