diff --git a/programs/storage_api/src/storage_contract.rs b/programs/storage_api/src/storage_contract.rs index da0813d088..6ff80e3e47 100644 --- a/programs/storage_api/src/storage_contract.rs +++ b/programs/storage_api/src/storage_contract.rs @@ -374,18 +374,10 @@ impl<'a> StorageAccount<'a> { StorageError::InvalidOwner as u32, ))? } - - let pending = *pending_lamports; - if rewards_pool.account.lamports < pending { - println!("reward pool has {}", rewards_pool.account.lamports); - Err(InstructionError::CustomError( - StorageError::RewardPoolDepleted as u32, - ))? - } - rewards_pool.account.lamports -= pending; - owner.account.lamports += pending; + redeem(*pending_lamports, rewards_pool, owner)?; //clear pending_lamports *pending_lamports = 0; + self.account.set_state(storage_contract) } else if let StorageContract::ReplicatorStorage { owner: account_owner, @@ -409,11 +401,10 @@ impl<'a> StorageAccount<'a> { }) .collect::>(); reward_validations.clear(); - let total_proofs = checked_proofs.len() as u64; let num_validations = count_valid_proofs(&checked_proofs); - let reward = num_validations * REPLICATOR_REWARD * (num_validations / total_proofs); - rewards_pool.account.lamports -= reward; - owner.account.lamports += reward; + let reward = num_validations * REPLICATOR_REWARD; + redeem(reward, rewards_pool, owner)?; + self.account.set_state(storage_contract) } else { Err(InstructionError::InvalidArgument)? @@ -421,6 +412,21 @@ impl<'a> StorageAccount<'a> { } } +fn redeem( + rewards: u64, + rewards_pool: &mut KeyedAccount, + owner: &mut StorageAccount, +) -> Result<(), InstructionError> { + if rewards_pool.account.lamports < rewards { + Err(InstructionError::CustomError( + StorageError::RewardPoolDepleted as u32, + ))? + } + rewards_pool.account.lamports -= rewards; + owner.account.lamports += rewards; + Ok(()) +} + pub fn create_rewards_pool() -> Account { Account::new_data(std::u64::MAX, &StorageContract::RewardsPool, &crate::id()).unwrap() } @@ -470,7 +476,7 @@ fn count_valid_proofs(proofs: &[ProofStatus]) -> u64 { #[cfg(test)] mod tests { use super::*; - use crate::id; + use crate::{id, rewards_pools}; use std::collections::BTreeMap; #[test] @@ -570,4 +576,35 @@ mod tests { ) .unwrap(); } + + #[test] + fn test_redeem() { + let reward = 100; + let mut owner_account = Account { + lamports: 1, + ..Account::default() + }; + let mut rewards_pool = create_rewards_pool(); + let pool_id = rewards_pools::id(); + let mut keyed_pool_account = KeyedAccount::new(&pool_id, false, &mut rewards_pool); + let mut owner = StorageAccount { + id: Pubkey::default(), + account: &mut owner_account, + }; + + // check that redeeming from depleted pools fails + keyed_pool_account.account.lamports = 0; + assert_eq!( + redeem(reward, &mut keyed_pool_account, &mut owner), + Err(InstructionError::CustomError( + StorageError::RewardPoolDepleted as u32, + )) + ); + assert_eq!(owner.account.lamports, 1); + + keyed_pool_account.account.lamports = 200; + assert_eq!(redeem(reward, &mut keyed_pool_account, &mut owner), Ok(())); + // check that the owner's balance increases + assert_eq!(owner.account.lamports, 101); + } }