diff --git a/src/accountant.rs b/src/accountant.rs index 683474efa7..49c5d2c883 100644 --- a/src/accountant.rs +++ b/src/accountant.rs @@ -15,6 +15,7 @@ use signature::{KeyPair, PublicKey, Signature}; use std::collections::hash_map::Entry::Occupied; use std::collections::{HashMap, HashSet, VecDeque}; use std::result; +use std::sync::atomic::{AtomicIsize, Ordering}; use std::sync::RwLock; use transaction::Transaction; @@ -30,18 +31,18 @@ pub enum AccountingError { pub type Result = result::Result; /// Commit funds to the 'to' party. -fn apply_payment(balances: &RwLock>>, payment: &Payment) { +fn apply_payment(balances: &RwLock>, payment: &Payment) { if balances.read().unwrap().contains_key(&payment.to) { let bals = balances.read().unwrap(); - *bals[&payment.to].write().unwrap() += payment.tokens; + bals[&payment.to].fetch_add(payment.tokens as isize, Ordering::Relaxed); } else { let mut bals = balances.write().unwrap(); - bals.insert(payment.to, RwLock::new(payment.tokens)); + bals.insert(payment.to, AtomicIsize::new(payment.tokens as isize)); } } pub struct Accountant { - balances: RwLock>>, + balances: RwLock>, pending: RwLock>, last_ids: RwLock>)>>, time_sources: RwLock>, @@ -127,27 +128,37 @@ impl Accountant { /// funds and isn't a duplicate. pub fn process_verified_transaction_debits(&self, tr: &Transaction) -> Result<()> { let bals = self.balances.read().unwrap(); - - // Hold a write lock before the condition check, so that a debit can't occur - // between checking the balance and the withdraw. let option = bals.get(&tr.from); + if option.is_none() { return Err(AccountingError::AccountNotFound); } - let mut bal = option.unwrap().write().unwrap(); if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) { return Err(AccountingError::InvalidTransferSignature); } - if *bal < tr.data.tokens { - self.forget_signature_with_last_id(&tr.sig, &tr.data.last_id); - return Err(AccountingError::InsufficientFunds); + loop { + let bal = option.unwrap(); + let current = bal.load(Ordering::Relaxed) as i64; + + if current < tr.data.tokens { + self.forget_signature_with_last_id(&tr.sig, &tr.data.last_id); + return Err(AccountingError::InsufficientFunds); + } + + let result = bal.compare_exchange( + current as isize, + (current - tr.data.tokens) as isize, + Ordering::Relaxed, + Ordering::Relaxed, + ); + + match result { + Ok(_) => return Ok(()), + Err(_) => continue, + }; } - - *bal -= tr.data.tokens; - - Ok(()) } pub fn process_verified_transaction_credits(&self, tr: &Transaction) { @@ -300,7 +311,7 @@ impl Accountant { pub fn get_balance(&self, pubkey: &PublicKey) -> Option { let bals = self.balances.read().unwrap(); - bals.get(pubkey).map(|x| *x.read().unwrap()) + bals.get(pubkey).map(|x| x.load(Ordering::Relaxed) as i64) } } diff --git a/src/accountant_skel.rs b/src/accountant_skel.rs index 9712c6c031..65a0ddd097 100644 --- a/src/accountant_skel.rs +++ b/src/accountant_skel.rs @@ -487,14 +487,14 @@ mod tests { use std::time::Duration; use transaction::Transaction; - use subscribers::{Node, Subscribers}; - use streamer; - use std::sync::mpsc::channel; - use std::collections::VecDeque; - use hash::{hash, Hash}; - use event::Event; - use entry; use chrono::prelude::*; + use entry; + use event::Event; + use hash::{hash, Hash}; + use std::collections::VecDeque; + use std::sync::mpsc::channel; + use streamer; + use subscribers::{Node, Subscribers}; #[test] fn test_layout() { diff --git a/src/packet.rs b/src/packet.rs index d97b261e9f..c4b09eb56e 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -4,9 +4,9 @@ use result::{Error, Result}; use std::collections::VecDeque; use std::fmt; use std::io; +use std::mem::size_of; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::sync::{Arc, Mutex, RwLock}; -use std::mem::size_of; pub type SharedPackets = Arc>; pub type SharedBlob = Arc>; diff --git a/src/result.rs b/src/result.rs index 01872dfbe1..532a64c3b2 100644 --- a/src/result.rs +++ b/src/result.rs @@ -1,10 +1,10 @@ //! The `result` module exposes a Result type that propagates one of many different Error types. +use accountant; use bincode; use serde_json; use std; use std::any::Any; -use accountant; #[derive(Debug)] pub enum Error { diff --git a/src/streamer.rs b/src/streamer.rs index 43e6f2ac35..7f0e7fbdba 100644 --- a/src/streamer.rs +++ b/src/streamer.rs @@ -382,8 +382,9 @@ mod test { use std::sync::mpsc::channel; use std::sync::{Arc, RwLock}; use std::time::Duration; - use streamer::{blob_receiver, receiver, responder, retransmitter, window, BlobReceiver, - PacketReceiver}; + use streamer::{ + blob_receiver, receiver, responder, retransmitter, window, BlobReceiver, PacketReceiver, + }; use subscribers::{Node, Subscribers}; fn get_msgs(r: PacketReceiver, num: &mut usize) {