diff --git a/programs/bpf/rust/invoke/src/lib.rs b/programs/bpf/rust/invoke/src/lib.rs index fae606e7e9..e4fa21c39a 100644 --- a/programs/bpf/rust/invoke/src/lib.rs +++ b/programs/bpf/rust/invoke/src/lib.rs @@ -144,6 +144,92 @@ fn process_instruction( ); } + info!("Test refcell usage"); + { + let writable = INVOKED_ARGUMENT_INDEX; + let readable = INVOKED_PROGRAM_INDEX; + + let instruction = create_instruction( + *accounts[INVOKED_PROGRAM_INDEX].key, + &[ + (accounts[writable].key, true, true), + (accounts[readable].key, false, false), + ], + vec![TEST_RETURN_ERROR, 1, 2, 3, 4, 5], + ); + + // success with this account configuration as a check + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::Custom(42)) + ); + + { + // writable but lamports borrow_mut'd + let _ref_mut = accounts[writable].try_borrow_mut_lamports()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::AccountBorrowFailed) + ); + } + { + // writable but data borrow_mut'd + let _ref_mut = accounts[writable].try_borrow_mut_data()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::AccountBorrowFailed) + ); + } + { + // writable but lamports borrow'd + let _ref_mut = accounts[writable].try_borrow_lamports()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::AccountBorrowFailed) + ); + } + { + // writable but data borrow'd + let _ref_mut = accounts[writable].try_borrow_data()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::AccountBorrowFailed) + ); + } + { + // readable but lamports borrow_mut'd + let _ref_mut = accounts[readable].try_borrow_mut_lamports()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::AccountBorrowFailed) + ); + } + { + // readable but data borrow_mut'd + let _ref_mut = accounts[readable].try_borrow_mut_data()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::AccountBorrowFailed) + ); + } + { + // readable but lamports borrow'd + let _ref_mut = accounts[readable].try_borrow_lamports()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::Custom(42)) + ); + } + { + // readable but data borrow'd + let _ref_mut = accounts[readable].try_borrow_data()?; + assert_eq!( + invoke(&instruction, accounts), + Err(ProgramError::Custom(42)) + ); + } + } + info!("Test create_program_address"); { assert_eq!( diff --git a/programs/bpf/tests/programs.rs b/programs/bpf/tests/programs.rs index 27e92584a6..3f3fe4e9a4 100644 --- a/programs/bpf/tests/programs.rs +++ b/programs/bpf/tests/programs.rs @@ -440,16 +440,25 @@ fn test_program_bpf_invoke() { const TEST_PRIVILEGE_ESCALATION_WRITABLE: u8 = 3; const TEST_PPROGRAM_NOT_EXECUTABLE: u8 = 4; + #[allow(dead_code)] + #[derive(Debug)] + enum Languages { + C, + Rust, + } let mut programs = Vec::new(); #[cfg(feature = "bpf_c")] { - programs.extend_from_slice(&[("invoke", "invoked")]); + programs.push((Languages::C, "invoke", "invoked")); } #[cfg(feature = "bpf_rust")] { - programs.extend_from_slice(&[("solana_bpf_rust_invoke", "solana_bpf_rust_invoked")]); + programs.push(( + Languages::Rust, + "solana_bpf_rust_invoke", + "solana_bpf_rust_invoked", + )); } - for program in programs.iter() { println!("Test program: {:?}", program); @@ -465,9 +474,9 @@ fn test_program_bpf_invoke() { let bank_client = BankClient::new_shared(&bank); let invoke_program_id = - load_bpf_program(&bank_client, &bpf_loader::id(), &mint_keypair, program.0); - let invoked_program_id = load_bpf_program(&bank_client, &bpf_loader::id(), &mint_keypair, program.1); + let invoked_program_id = + load_bpf_program(&bank_client, &bpf_loader::id(), &mint_keypair, program.2); let argument_keypair = Keypair::new(); let account = Account::new(42, 100, &invoke_program_id); @@ -527,9 +536,9 @@ fn test_program_bpf_invoke() { .iter() .map(|ix| message.account_keys[ix.program_id_index as usize].clone()) .collect(); - assert_eq!( - invoked_programs, - vec![ + + let expected_invoked_programs = match program.0 { + Languages::C => vec![ solana_sdk::system_program::id(), solana_sdk::system_program::id(), invoked_program_id.clone(), @@ -542,8 +551,26 @@ fn test_program_bpf_invoke() { invoked_program_id.clone(), invoked_program_id.clone(), invoked_program_id.clone(), - ] - ); + ], + Languages::Rust => vec![ + solana_sdk::system_program::id(), + solana_sdk::system_program::id(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + invoked_program_id.clone(), + ], + }; + assert_eq!(invoked_programs, expected_invoked_programs); // failure cases diff --git a/sdk/src/program.rs b/sdk/src/program.rs index d2a2fbd3e9..6058175544 100644 --- a/sdk/src/program.rs +++ b/sdk/src/program.rs @@ -16,6 +16,22 @@ pub fn invoke_signed( account_infos: &[AccountInfo], signers_seeds: &[&[&[u8]]], ) -> ProgramResult { + // Check that the account RefCells are consistent with the request + for account_meta in instruction.accounts.iter() { + for account_info in account_infos.iter() { + if account_meta.pubkey == *account_info.key { + if account_meta.is_writable { + let _ = account_info.try_borrow_mut_lamports()?; + let _ = account_info.try_borrow_mut_data()?; + } else { + let _ = account_info.try_borrow_lamports()?; + let _ = account_info.try_borrow_data()?; + } + break; + } + } + } + let result = unsafe { sol_invoke_signed_rust( instruction as *const _ as *const u8,