diff --git a/core/src/cost_update_service.rs b/core/src/cost_update_service.rs index b79641733e..b1e03871a5 100644 --- a/core/src/cost_update_service.rs +++ b/core/src/cost_update_service.rs @@ -127,8 +127,10 @@ impl CostUpdateService { CostUpdate::FrozenBank { bank } => { bank.read_cost_tracker().unwrap().report_stats(bank.slot()); } - CostUpdate::ExecuteTiming { execute_timings } => { - dirty |= Self::update_cost_model(&cost_model, &execute_timings); + CostUpdate::ExecuteTiming { + mut execute_timings, + } => { + dirty |= Self::update_cost_model(&cost_model, &mut execute_timings); update_count += 1; } } @@ -151,16 +153,27 @@ impl CostUpdateService { } } - fn update_cost_model(cost_model: &RwLock, execute_timings: &ExecuteTimings) -> bool { + fn update_cost_model( + cost_model: &RwLock, + execute_timings: &mut ExecuteTimings, + ) -> bool { let mut dirty = false; { - let mut cost_model_mutable = cost_model.write().unwrap(); - for (program_id, timing) in &execute_timings.details.per_program_timings { - if timing.count < 1 { + for (program_id, program_timings) in &mut execute_timings.details.per_program_timings { + let current_estimated_program_cost = + cost_model.read().unwrap().find_instruction_cost(program_id); + program_timings.coalesce_error_timings(current_estimated_program_cost); + + if program_timings.count < 1 { continue; } - let units = timing.accumulated_units / timing.count as u64; - match cost_model_mutable.upsert_instruction_cost(program_id, units) { + + let units = program_timings.accumulated_units / program_timings.count as u64; + match cost_model + .write() + .unwrap() + .upsert_instruction_cost(program_id, units) + { Ok(c) => { debug!( "after replayed into bank, instruction {:?} has averaged cost {}", @@ -213,8 +226,8 @@ mod tests { #[test] fn test_update_cost_model_with_empty_execute_timings() { let cost_model = Arc::new(RwLock::new(CostModel::default())); - let empty_execute_timings = ExecuteTimings::default(); - CostUpdateService::update_cost_model(&cost_model, &empty_execute_timings); + let mut empty_execute_timings = ExecuteTimings::default(); + CostUpdateService::update_cost_model(&cost_model, &mut empty_execute_timings); assert_eq!( 0, @@ -247,9 +260,10 @@ mod tests { accumulated_us, accumulated_units, count, + errored_txs_compute_consumed: vec![], }, ); - CostUpdateService::update_cost_model(&cost_model, &execute_timings); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); assert_eq!( 1, cost_model @@ -282,9 +296,10 @@ mod tests { accumulated_us, accumulated_units, count, + errored_txs_compute_consumed: vec![], }, ); - CostUpdateService::update_cost_model(&cost_model, &execute_timings); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); assert_eq!( 1, cost_model @@ -303,4 +318,99 @@ mod tests { ); } } + + #[test] + fn test_update_cost_model_with_error_execute_timings() { + let cost_model = Arc::new(RwLock::new(CostModel::default())); + let mut execute_timings = ExecuteTimings::default(); + let program_key_1 = Pubkey::new_unique(); + + // Test updating cost model with a `ProgramTiming` with no compute units accumulated, i.e. + // `accumulated_units` == 0 + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: 0, + count: 0, + errored_txs_compute_consumed: vec![], + }, + ); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + // If both the `errored_txs_compute_consumed` is empty and `count == 0`, then + // nothing should be inserted into the cost model + assert!(cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .is_empty()); + } + + // Test updating cost model with only erroring compute costs where the `cost_per_error` is + // greater than the current instruction cost for the program. Should update with the + // new erroring compute costs + let cost_per_error = 1000; + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: 0, + count: 0, + errored_txs_compute_consumed: vec![cost_per_error; 3], + }, + ); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!( + 1, + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .len() + ); + assert_eq!( + Some(&cost_per_error), + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .get(&program_key_1) + ); + } + + // Test updating cost model with only erroring compute costs where the error cost is + // `smaller_cost_per_error`, less than the current instruction cost for the program. + // The cost should not decrease for these new lesser errors + let smaller_cost_per_error = cost_per_error - 10; + { + execute_timings.details.per_program_timings.insert( + program_key_1, + ProgramTiming { + accumulated_us: 1000, + accumulated_units: 0, + count: 0, + errored_txs_compute_consumed: vec![smaller_cost_per_error; 3], + }, + ); + CostUpdateService::update_cost_model(&cost_model, &mut execute_timings); + assert_eq!( + 1, + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .len() + ); + assert_eq!( + Some(&cost_per_error), + cost_model + .read() + .unwrap() + .get_instruction_cost_table() + .get(&program_key_1) + ); + } + } } diff --git a/runtime/src/cost_model.rs b/runtime/src/cost_model.rs index 075497ab76..302b20db0f 100644 --- a/runtime/src/cost_model.rs +++ b/runtime/src/cost_model.rs @@ -123,6 +123,20 @@ impl CostModel { transaction.signatures.len() as u64 * SIGNATURE_COST } + pub fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 { + match self.instruction_execution_cost_table.get_cost(program_key) { + Some(cost) => *cost, + None => { + let default_value = self.instruction_execution_cost_table.get_mode(); + debug!( + "Program key {:?} does not have assigned cost, using mode {}", + program_key, default_value + ); + default_value + } + } + } + fn get_write_lock_cost( &self, tx_cost: &mut TransactionCost, @@ -164,20 +178,6 @@ impl CostModel { } cost } - - fn find_instruction_cost(&self, program_key: &Pubkey) -> u64 { - match self.instruction_execution_cost_table.get_cost(program_key) { - Some(cost) => *cost, - None => { - let default_value = self.instruction_execution_cost_table.get_mode(); - debug!( - "Program key {:?} does not have assigned cost, using mode {}", - program_key, default_value - ); - default_value - } - } - } } #[cfg(test)] diff --git a/runtime/src/message_processor.rs b/runtime/src/message_processor.rs index 9d8b124c82..bdcb4e03f0 100644 --- a/runtime/src/message_processor.rs +++ b/runtime/src/message_processor.rs @@ -66,6 +66,18 @@ pub struct ProgramTiming { pub accumulated_us: u64, pub accumulated_units: u64, pub count: u32, + pub errored_txs_compute_consumed: Vec, +} + +impl ProgramTiming { + pub fn coalesce_error_timings(&mut self, current_estimated_program_cost: u64) { + for tx_error_compute_consumed in self.errored_txs_compute_consumed.drain(..) { + let compute_units_update = + std::cmp::max(current_estimated_program_cost, tx_error_compute_consumed); + self.accumulated_units = self.accumulated_units.saturating_add(compute_units_update); + self.count = self.count.saturating_add(1); + } + } } #[derive(Default, Debug)] @@ -1264,31 +1276,43 @@ impl MessageProcessor { let pre_remaining_units = invoke_context.get_compute_meter().borrow().get_remaining(); let mut time = Measure::start("execute_instruction"); - self.process_instruction(program_id, &instruction.data, &mut invoke_context)?; - Self::verify( - message, - instruction, - &invoke_context.pre_accounts, - executable_accounts, - accounts, - &rent_collector.rent, - timings, - invoke_context.get_logger(), - invoke_context.is_feature_active(&updated_verify_policy::id()), - invoke_context.is_feature_active(&demote_program_write_locks::id()), - )?; + let execute_result = + self.process_instruction(program_id, &instruction.data, &mut invoke_context); + let execute_or_verify_result = execute_result.and_then(|_| { + Self::verify( + message, + instruction, + &invoke_context.pre_accounts, + executable_accounts, + accounts, + &rent_collector.rent, + timings, + invoke_context.get_logger(), + invoke_context.is_feature_active(&updated_verify_policy::id()), + invoke_context.is_feature_active(&demote_program_write_locks::id()), + ) + }); time.stop(); let post_remaining_units = invoke_context.get_compute_meter().borrow().get_remaining(); let program_timing = timings.per_program_timings.entry(*program_id).or_default(); program_timing.accumulated_us += time.as_us(); - program_timing.accumulated_units += pre_remaining_units - post_remaining_units; - program_timing.count += 1; + let compute_units_consumed = pre_remaining_units.saturating_sub(post_remaining_units); + if execute_or_verify_result.is_err() { + program_timing + .errored_txs_compute_consumed + .push(compute_units_consumed); + } else { + program_timing.accumulated_units = program_timing + .accumulated_units + .saturating_add(compute_units_consumed); + program_timing.count = program_timing.count.saturating_add(1); + } timings.accumulate(&invoke_context.timings); - Ok(()) + execute_or_verify_result } /// Process a message.