diff --git a/runtime/src/accounts.rs b/runtime/src/accounts.rs index 371140886f..a7680589ab 100644 --- a/runtime/src/accounts.rs +++ b/runtime/src/accounts.rs @@ -152,7 +152,7 @@ impl Accounts { let rent_fix_enabled = feature_set.cumulative_rent_related_fixes_enabled(); for (i, key) in message.account_keys.iter().enumerate() { - let account = if Self::is_non_loader_key(message, key, i) { + let account = if message.is_non_loader_key(key, i) { if payer_index.is_none() { payer_index = Some(i); } @@ -793,10 +793,6 @@ impl Accounts { self.accounts_db.add_root(slot) } - pub fn is_non_loader_key(message: &Message, key: &Pubkey, key_index: usize) -> bool { - !message.program_ids().contains(&key) || message.is_key_passed_to_program(key_index) - } - fn collect_accounts_to_store<'a>( &self, txs: &'a [Transaction], @@ -843,7 +839,7 @@ impl Accounts { .iter() .enumerate() .zip(acc.0.iter_mut()) - .filter(|((i, key), _account)| Self::is_non_loader_key(message, key, *i)) + .filter(|((i, key), _account)| message.is_non_loader_key(key, *i)) { let is_nonce_account = prepare_if_nonce_account( account, diff --git a/runtime/src/bank.rs b/runtime/src/bank.rs index ad2a0a4597..06e8ca2392 100644 --- a/runtime/src/bank.rs +++ b/runtime/src/bank.rs @@ -515,7 +515,7 @@ impl NonceRollbackFull { .account_keys .iter() .enumerate() - .find(|(i, k)| Accounts::is_non_loader_key(message, k, *i)) + .find(|(i, k)| message.is_non_loader_key(k, *i)) .and_then(|(i, k)| accounts.get(i).cloned().map(|a| (*k, a))); if let Some((fee_pubkey, fee_account)) = fee_payer { if fee_pubkey == nonce_address { diff --git a/sdk/program/src/message.rs b/sdk/program/src/message.rs index ad4ba63f12..c58442500f 100644 --- a/sdk/program/src/message.rs +++ b/sdk/program/src/message.rs @@ -298,6 +298,10 @@ impl Message { false } + pub fn is_non_loader_key(&self, key: &Pubkey, key_index: usize) -> bool { + !self.program_ids().contains(&key) || self.is_key_passed_to_program(key_index) + } + pub fn program_position(&self, index: usize) -> Option { let program_ids = self.program_ids(); program_ids @@ -794,4 +798,61 @@ mod tests { ); } } + + #[test] + fn test_program_ids() { + let key0 = Pubkey::new_unique(); + let key1 = Pubkey::new_unique(); + let loader2 = Pubkey::new_unique(); + let instructions = vec![CompiledInstruction::new(2, &(), vec![0, 1])]; + let message = Message::new_with_compiled_instructions( + 1, + 0, + 2, + vec![key0, key1, loader2], + Hash::default(), + instructions, + ); + assert_eq!(message.program_ids(), vec![&loader2]); + } + + #[test] + fn test_is_key_passed_to_program() { + let key0 = Pubkey::new_unique(); + let key1 = Pubkey::new_unique(); + let loader2 = Pubkey::new_unique(); + let instructions = vec![CompiledInstruction::new(2, &(), vec![0, 1])]; + let message = Message::new_with_compiled_instructions( + 1, + 0, + 2, + vec![key0, key1, loader2], + Hash::default(), + instructions, + ); + + assert!(message.is_key_passed_to_program(0)); + assert!(message.is_key_passed_to_program(1)); + assert!(!message.is_key_passed_to_program(2)); + } + + #[test] + fn test_is_non_loader_key() { + let key0 = Pubkey::new_unique(); + let key1 = Pubkey::new_unique(); + let loader2 = Pubkey::new_unique(); + let instructions = vec![CompiledInstruction::new(2, &(), vec![0, 1])]; + let message = Message::new_with_compiled_instructions( + 1, + 0, + 2, + vec![key0, key1, loader2], + Hash::default(), + instructions, + ); + + assert!(message.is_non_loader_key(&key0, 0)); + assert!(message.is_non_loader_key(&key1, 1)); + assert!(!message.is_non_loader_key(&loader2, 2)); + } }