Merge pull request #165 from rlkelly/126__atomic_balances

126  atomic balances
This commit is contained in:
Greg Fitzgerald
2018-05-02 10:43:31 -06:00
committed by GitHub
5 changed files with 39 additions and 27 deletions

View File

@ -15,6 +15,7 @@ use signature::{KeyPair, PublicKey, Signature};
use std::collections::hash_map::Entry::Occupied; use std::collections::hash_map::Entry::Occupied;
use std::collections::{HashMap, HashSet, VecDeque}; use std::collections::{HashMap, HashSet, VecDeque};
use std::result; use std::result;
use std::sync::atomic::{AtomicIsize, Ordering};
use std::sync::RwLock; use std::sync::RwLock;
use transaction::Transaction; use transaction::Transaction;
@ -30,18 +31,18 @@ pub enum AccountingError {
pub type Result<T> = result::Result<T, AccountingError>; pub type Result<T> = result::Result<T, AccountingError>;
/// Commit funds to the 'to' party. /// Commit funds to the 'to' party.
fn apply_payment(balances: &RwLock<HashMap<PublicKey, RwLock<i64>>>, payment: &Payment) { fn apply_payment(balances: &RwLock<HashMap<PublicKey, AtomicIsize>>, payment: &Payment) {
if balances.read().unwrap().contains_key(&payment.to) { if balances.read().unwrap().contains_key(&payment.to) {
let bals = balances.read().unwrap(); let bals = balances.read().unwrap();
*bals[&payment.to].write().unwrap() += payment.tokens; bals[&payment.to].fetch_add(payment.tokens as isize, Ordering::Relaxed);
} else { } else {
let mut bals = balances.write().unwrap(); let mut bals = balances.write().unwrap();
bals.insert(payment.to, RwLock::new(payment.tokens)); bals.insert(payment.to, AtomicIsize::new(payment.tokens as isize));
} }
} }
pub struct Accountant { pub struct Accountant {
balances: RwLock<HashMap<PublicKey, RwLock<i64>>>, balances: RwLock<HashMap<PublicKey, AtomicIsize>>,
pending: RwLock<HashMap<Signature, Plan>>, pending: RwLock<HashMap<Signature, Plan>>,
last_ids: RwLock<VecDeque<(Hash, RwLock<HashSet<Signature>>)>>, last_ids: RwLock<VecDeque<(Hash, RwLock<HashSet<Signature>>)>>,
time_sources: RwLock<HashSet<PublicKey>>, time_sources: RwLock<HashSet<PublicKey>>,
@ -127,27 +128,37 @@ impl Accountant {
/// funds and isn't a duplicate. /// funds and isn't a duplicate.
pub fn process_verified_transaction_debits(&self, tr: &Transaction) -> Result<()> { pub fn process_verified_transaction_debits(&self, tr: &Transaction) -> Result<()> {
let bals = self.balances.read().unwrap(); let bals = self.balances.read().unwrap();
// Hold a write lock before the condition check, so that a debit can't occur
// between checking the balance and the withdraw.
let option = bals.get(&tr.from); let option = bals.get(&tr.from);
if option.is_none() { if option.is_none() {
return Err(AccountingError::AccountNotFound); return Err(AccountingError::AccountNotFound);
} }
let mut bal = option.unwrap().write().unwrap();
if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) { if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) {
return Err(AccountingError::InvalidTransferSignature); return Err(AccountingError::InvalidTransferSignature);
} }
if *bal < tr.data.tokens { loop {
let bal = option.unwrap();
let current = bal.load(Ordering::Relaxed) as i64;
if current < tr.data.tokens {
self.forget_signature_with_last_id(&tr.sig, &tr.data.last_id); self.forget_signature_with_last_id(&tr.sig, &tr.data.last_id);
return Err(AccountingError::InsufficientFunds); return Err(AccountingError::InsufficientFunds);
} }
*bal -= tr.data.tokens; let result = bal.compare_exchange(
current as isize,
(current - tr.data.tokens) as isize,
Ordering::Relaxed,
Ordering::Relaxed,
);
Ok(()) match result {
Ok(_) => return Ok(()),
Err(_) => continue,
};
}
} }
pub fn process_verified_transaction_credits(&self, tr: &Transaction) { pub fn process_verified_transaction_credits(&self, tr: &Transaction) {
@ -300,7 +311,7 @@ impl Accountant {
pub fn get_balance(&self, pubkey: &PublicKey) -> Option<i64> { pub fn get_balance(&self, pubkey: &PublicKey) -> Option<i64> {
let bals = self.balances.read().unwrap(); let bals = self.balances.read().unwrap();
bals.get(pubkey).map(|x| *x.read().unwrap()) bals.get(pubkey).map(|x| x.load(Ordering::Relaxed) as i64)
} }
} }

View File

@ -487,14 +487,14 @@ mod tests {
use std::time::Duration; use std::time::Duration;
use transaction::Transaction; use transaction::Transaction;
use subscribers::{Node, Subscribers};
use streamer;
use std::sync::mpsc::channel;
use std::collections::VecDeque;
use hash::{hash, Hash};
use event::Event;
use entry;
use chrono::prelude::*; use chrono::prelude::*;
use entry;
use event::Event;
use hash::{hash, Hash};
use std::collections::VecDeque;
use std::sync::mpsc::channel;
use streamer;
use subscribers::{Node, Subscribers};
#[test] #[test]
fn test_layout() { fn test_layout() {

View File

@ -4,9 +4,9 @@ use result::{Error, Result};
use std::collections::VecDeque; use std::collections::VecDeque;
use std::fmt; use std::fmt;
use std::io; use std::io;
use std::mem::size_of;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use std::mem::size_of;
pub type SharedPackets = Arc<RwLock<Packets>>; pub type SharedPackets = Arc<RwLock<Packets>>;
pub type SharedBlob = Arc<RwLock<Blob>>; pub type SharedBlob = Arc<RwLock<Blob>>;

View File

@ -1,10 +1,10 @@
//! The `result` module exposes a Result type that propagates one of many different Error types. //! The `result` module exposes a Result type that propagates one of many different Error types.
use accountant;
use bincode; use bincode;
use serde_json; use serde_json;
use std; use std;
use std::any::Any; use std::any::Any;
use accountant;
#[derive(Debug)] #[derive(Debug)]
pub enum Error { pub enum Error {

View File

@ -382,8 +382,9 @@ mod test {
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use streamer::{blob_receiver, receiver, responder, retransmitter, window, BlobReceiver, use streamer::{
PacketReceiver}; blob_receiver, receiver, responder, retransmitter, window, BlobReceiver, PacketReceiver,
};
use subscribers::{Node, Subscribers}; use subscribers::{Node, Subscribers};
fn get_msgs(r: PacketReceiver, num: &mut usize) { fn get_msgs(r: PacketReceiver, num: &mut usize) {