Compare commits

...

7 Commits

10 changed files with 433 additions and 109 deletions

View File

@ -1 +1 @@
1.4.0 1.3.3

View File

@ -48,10 +48,10 @@ import (
const ( const (
ClientIdentifier = "Geth" ClientIdentifier = "Geth"
Version = "1.4.0-unstable" Version = "1.3.3"
VersionMajor = 1 VersionMajor = 1
VersionMinor = 4 VersionMinor = 3
VersionPatch = 0 VersionPatch = 3
) )
var ( var (

View File

@ -591,6 +591,19 @@ func (bc *BlockChain) HasBlock(hash common.Hash) bool {
return bc.GetBlock(hash) != nil return bc.GetBlock(hash) != nil
} }
// HasBlockAndState checks if a block and associated state trie is fully present
// in the database or not, caching it if present.
func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
// Check first that the block itself is known
block := bc.GetBlock(hash)
if block == nil {
return false
}
// Ensure the associated state is also present
_, err := state.New(block.Root(), bc.chainDb)
return err == nil
}
// GetBlock retrieves a block from the database by hash, caching it if found. // GetBlock retrieves a block from the database by hash, caching it if found.
func (self *BlockChain) GetBlock(hash common.Hash) *types.Block { func (self *BlockChain) GetBlock(hash common.Hash) *types.Block {
// Short circuit if the block's already in the cache, retrieve otherwise // Short circuit if the block's already in the cache, retrieve otherwise

View File

@ -55,7 +55,6 @@ type StateDB struct {
func New(root common.Hash, db ethdb.Database) (*StateDB, error) { func New(root common.Hash, db ethdb.Database) (*StateDB, error) {
tr, err := trie.NewSecure(root, db) tr, err := trie.NewSecure(root, db)
if err != nil { if err != nil {
glog.Errorf("can't create state trie with root %x: %v", root[:], err)
return nil, err return nil, err
} }
return &StateDB{ return &StateDB{

View File

@ -138,7 +138,6 @@ func (pool *TxPool) resetState() {
} }
} }
} }
// Check the queue and move transactions over to the pending if possible // Check the queue and move transactions over to the pending if possible
// or remove those that have become invalid // or remove those that have become invalid
pool.checkQueue() pool.checkQueue()
@ -290,17 +289,15 @@ func (pool *TxPool) addTx(hash common.Hash, addr common.Address, tx *types.Trans
} }
// Add queues a single transaction in the pool if it is valid. // Add queues a single transaction in the pool if it is valid.
func (self *TxPool) Add(tx *types.Transaction) (err error) { func (self *TxPool) Add(tx *types.Transaction) error {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
err = self.add(tx) if err := self.add(tx); err != nil {
if err == nil { return err
// check and validate the queueue
self.checkQueue()
} }
self.checkQueue()
return return nil
} }
// AddTransactions attempts to queue all valid transactions in txs. // AddTransactions attempts to queue all valid transactions in txs.
@ -406,51 +403,55 @@ func (pool *TxPool) checkQueue() {
pool.resetState() pool.resetState()
} }
var addq txQueue var promote txQueue
for address, txs := range pool.queue { for address, txs := range pool.queue {
// guessed nonce is the nonce currently kept by the tx pool (pending state)
guessedNonce := pool.pendingState.GetNonce(address)
// true nonce is the nonce known by the last state
currentState, err := pool.currentState() currentState, err := pool.currentState()
if err != nil { if err != nil {
glog.Errorf("could not get current state: %v", err) glog.Errorf("could not get current state: %v", err)
return return
} }
trueNonce := currentState.GetNonce(address) balance := currentState.GetBalance(address)
addq := addq[:0]
var (
guessedNonce = pool.pendingState.GetNonce(address) // nonce currently kept by the tx pool (pending state)
trueNonce = currentState.GetNonce(address) // nonce known by the last state
)
promote = promote[:0]
for hash, tx := range txs { for hash, tx := range txs {
if tx.Nonce() < trueNonce { // Drop processed or out of fund transactions
// Drop queued transactions whose nonce is lower than if tx.Nonce() < trueNonce || balance.Cmp(tx.Cost()) < 0 {
// the account nonce because they have been processed. if glog.V(logger.Core) {
glog.Infof("removed tx (%v) from pool queue: low tx nonce or out of funds\n", tx)
}
delete(txs, hash) delete(txs, hash)
} else {
// Collect the remaining transactions for the next pass.
addq = append(addq, txQueueEntry{hash, address, tx})
}
}
// Find the next consecutive nonce range starting at the
// current account nonce.
sort.Sort(addq)
for i, e := range addq {
// start deleting the transactions from the queue if they exceed the limit
if i > maxQueued {
delete(pool.queue[address], e.hash)
continue continue
} }
// Collect the remaining transactions for the next pass.
if e.Nonce() > guessedNonce { promote = append(promote, txQueueEntry{hash, address, tx})
if len(addq)-i > maxQueued { }
// Find the next consecutive nonce range starting at the current account nonce,
// pushing the guessed nonce forward if we add consecutive transactions.
sort.Sort(promote)
for i, entry := range promote {
// If we reached a gap in the nonces, enforce transaction limit and stop
if entry.Nonce() > guessedNonce {
if len(promote)-i > maxQueued {
if glog.V(logger.Debug) { if glog.V(logger.Debug) {
glog.Infof("Queued tx limit exceeded for %s. Tx %s removed\n", common.PP(address[:]), common.PP(e.hash[:])) glog.Infof("Queued tx limit exceeded for %s. Tx %s removed\n", common.PP(address[:]), common.PP(entry.hash[:]))
} }
for j := i + maxQueued; j < len(addq); j++ { for _, drop := range promote[i+maxQueued:] {
delete(txs, addq[j].hash) delete(txs, drop.hash)
} }
} }
break break
} }
delete(txs, e.hash) // Otherwise promote the transaction and move the guess nonce if needed
pool.addTx(e.hash, address, e.Transaction) pool.addTx(entry.hash, address, entry.Transaction)
delete(txs, entry.hash)
if entry.Nonce() == guessedNonce {
guessedNonce++
}
} }
// Delete the entire queue entry if it became empty. // Delete the entire queue entry if it became empty.
if len(txs) == 0 { if len(txs) == 0 {
@ -460,20 +461,56 @@ func (pool *TxPool) checkQueue() {
} }
// validatePool removes invalid and processed transactions from the main pool. // validatePool removes invalid and processed transactions from the main pool.
// If a transaction is removed for being invalid (e.g. out of funds), all sub-
// sequent (Still valid) transactions are moved back into the future queue. This
// is important to prevent a drained account from DOSing the network with non
// executable transactions.
func (pool *TxPool) validatePool() { func (pool *TxPool) validatePool() {
state, err := pool.currentState() state, err := pool.currentState()
if err != nil { if err != nil {
glog.V(logger.Info).Infoln("failed to get current state: %v", err) glog.V(logger.Info).Infoln("failed to get current state: %v", err)
return return
} }
balanceCache := make(map[common.Address]*big.Int)
// Clean up the pending pool, accumulating invalid nonces
gaps := make(map[common.Address]uint64)
for hash, tx := range pool.pending { for hash, tx := range pool.pending {
from, _ := tx.From() // err already checked sender, _ := tx.From() // err already checked
// perform light nonce validation
if state.GetNonce(from) > tx.Nonce() { // Perform light nonce and balance validation
balance := balanceCache[sender]
if balance == nil {
balance = state.GetBalance(sender)
balanceCache[sender] = balance
}
if past := state.GetNonce(sender) > tx.Nonce(); past || balance.Cmp(tx.Cost()) < 0 {
// Remove an already past it invalidated transaction
if glog.V(logger.Core) { if glog.V(logger.Core) {
glog.Infof("removed tx (%x) from pool: low tx nonce\n", hash[:4]) glog.Infof("removed tx (%v) from pool: low tx nonce or out of funds\n", tx)
} }
delete(pool.pending, hash) delete(pool.pending, hash)
// Track the smallest invalid nonce to postpone subsequent transactions
if !past {
if prev, ok := gaps[sender]; !ok || tx.Nonce() < prev {
gaps[sender] = tx.Nonce()
}
}
}
}
// Move all transactions after a gap back to the future queue
if len(gaps) > 0 {
for hash, tx := range pool.pending {
sender, _ := tx.From()
if gap, ok := gaps[sender]; ok && tx.Nonce() >= gap {
if glog.V(logger.Core) {
glog.Infof("postponed tx (%v) due to introduced gap\n", tx)
}
pool.queueTx(hash, tx)
delete(pool.pending, hash)
}
} }
} }
} }

View File

@ -79,7 +79,7 @@ func TestTransactionQueue(t *testing.T) {
tx := transaction(0, big.NewInt(100), key) tx := transaction(0, big.NewInt(100), key)
from, _ := tx.From() from, _ := tx.From()
currentState, _ := pool.currentState() currentState, _ := pool.currentState()
currentState.AddBalance(from, big.NewInt(1)) currentState.AddBalance(from, big.NewInt(1000))
pool.queueTx(tx.Hash(), tx) pool.queueTx(tx.Hash(), tx)
pool.checkQueue() pool.checkQueue()
@ -104,15 +104,17 @@ func TestTransactionQueue(t *testing.T) {
tx1 := transaction(0, big.NewInt(100), key) tx1 := transaction(0, big.NewInt(100), key)
tx2 := transaction(10, big.NewInt(100), key) tx2 := transaction(10, big.NewInt(100), key)
tx3 := transaction(11, big.NewInt(100), key) tx3 := transaction(11, big.NewInt(100), key)
from, _ = tx1.From()
currentState, _ = pool.currentState()
currentState.AddBalance(from, big.NewInt(1000))
pool.queueTx(tx1.Hash(), tx1) pool.queueTx(tx1.Hash(), tx1)
pool.queueTx(tx2.Hash(), tx2) pool.queueTx(tx2.Hash(), tx2)
pool.queueTx(tx3.Hash(), tx3) pool.queueTx(tx3.Hash(), tx3)
from, _ = tx1.From()
pool.checkQueue() pool.checkQueue()
if len(pool.pending) != 1 { if len(pool.pending) != 1 {
t.Error("expected tx pool to be 1 =") t.Error("expected tx pool to be 1, got", len(pool.pending))
} }
if len(pool.queue[from]) != 2 { if len(pool.queue[from]) != 2 {
t.Error("expected len(queue) == 2, got", len(pool.queue[from])) t.Error("expected len(queue) == 2, got", len(pool.queue[from]))
@ -261,3 +263,264 @@ func TestRemovedTxEvent(t *testing.T) {
t.Error("expected 1 pending tx, got", len(pool.pending)) t.Error("expected 1 pending tx, got", len(pool.pending))
} }
} }
// Tests that if an account runs out of funds, any pending and queued transactions
// are dropped.
func TestTransactionDropping(t *testing.T) {
// Create a test account and fund it
pool, key := setupTxPool()
account, _ := transaction(0, big.NewInt(0), key).From()
state, _ := pool.currentState()
state.AddBalance(account, big.NewInt(1000))
// Add some pending and some queued transactions
var (
tx0 = transaction(0, big.NewInt(100), key)
tx1 = transaction(1, big.NewInt(200), key)
tx10 = transaction(10, big.NewInt(100), key)
tx11 = transaction(11, big.NewInt(200), key)
)
pool.addTx(tx0.Hash(), account, tx0)
pool.addTx(tx1.Hash(), account, tx1)
pool.queueTx(tx10.Hash(), tx10)
pool.queueTx(tx11.Hash(), tx11)
// Check that pre and post validations leave the pool as is
if len(pool.pending) != 2 {
t.Errorf("pending transaction mismatch: have %d, want %d", len(pool.pending), 2)
}
if len(pool.queue[account]) != 2 {
t.Errorf("queued transaction mismatch: have %d, want %d", len(pool.queue), 2)
}
pool.resetState()
if len(pool.pending) != 2 {
t.Errorf("pending transaction mismatch: have %d, want %d", len(pool.pending), 2)
}
if len(pool.queue[account]) != 2 {
t.Errorf("queued transaction mismatch: have %d, want %d", len(pool.queue), 2)
}
// Reduce the balance of the account, and check that invalidated transactions are dropped
state.AddBalance(account, big.NewInt(-750))
pool.resetState()
if _, ok := pool.pending[tx0.Hash()]; !ok {
t.Errorf("funded pending transaction missing: %v", tx0)
}
if _, ok := pool.pending[tx1.Hash()]; ok {
t.Errorf("out-of-fund pending transaction present: %v", tx1)
}
if _, ok := pool.queue[account][tx10.Hash()]; !ok {
t.Errorf("funded queued transaction missing: %v", tx10)
}
if _, ok := pool.queue[account][tx11.Hash()]; ok {
t.Errorf("out-of-fund queued transaction present: %v", tx11)
}
}
// Tests that if a transaction is dropped from the current pending pool (e.g. out
// of fund), all consecutive (still valid, but not executable) transactions are
// postponed back into the future queue to prevent broadcating them.
func TestTransactionPostponing(t *testing.T) {
// Create a test account and fund it
pool, key := setupTxPool()
account, _ := transaction(0, big.NewInt(0), key).From()
state, _ := pool.currentState()
state.AddBalance(account, big.NewInt(1000))
// Add a batch consecutive pending transactions for validation
txns := []*types.Transaction{}
for i := 0; i < 100; i++ {
var tx *types.Transaction
if i%2 == 0 {
tx = transaction(uint64(i), big.NewInt(100), key)
} else {
tx = transaction(uint64(i), big.NewInt(500), key)
}
pool.addTx(tx.Hash(), account, tx)
txns = append(txns, tx)
}
// Check that pre and post validations leave the pool as is
if len(pool.pending) != len(txns) {
t.Errorf("pending transaction mismatch: have %d, want %d", len(pool.pending), len(txns))
}
if len(pool.queue[account]) != 0 {
t.Errorf("queued transaction mismatch: have %d, want %d", len(pool.queue), 0)
}
pool.resetState()
if len(pool.pending) != len(txns) {
t.Errorf("pending transaction mismatch: have %d, want %d", len(pool.pending), len(txns))
}
if len(pool.queue[account]) != 0 {
t.Errorf("queued transaction mismatch: have %d, want %d", len(pool.queue), 0)
}
// Reduce the balance of the account, and check that transactions are reorganized
state.AddBalance(account, big.NewInt(-750))
pool.resetState()
if _, ok := pool.pending[txns[0].Hash()]; !ok {
t.Errorf("tx %d: valid and funded transaction missing from pending pool: %v", 0, txns[0])
}
if _, ok := pool.queue[account][txns[0].Hash()]; ok {
t.Errorf("tx %d: valid and funded transaction present in future queue: %v", 0, txns[0])
}
for i, tx := range txns[1:] {
if i%2 == 1 {
if _, ok := pool.pending[tx.Hash()]; ok {
t.Errorf("tx %d: valid but future transaction present in pending pool: %v", i+1, tx)
}
if _, ok := pool.queue[account][tx.Hash()]; !ok {
t.Errorf("tx %d: valid but future transaction missing from future queue: %v", i+1, tx)
}
} else {
if _, ok := pool.pending[tx.Hash()]; ok {
t.Errorf("tx %d: out-of-fund transaction present in pending pool: %v", i+1, tx)
}
if _, ok := pool.queue[account][tx.Hash()]; ok {
t.Errorf("tx %d: out-of-fund transaction present in future queue: %v", i+1, tx)
}
}
}
}
// Tests that if the transaction count belonging to a single account goes above
// some threshold, the higher transactions are dropped to prevent DOS attacks.
func TestTransactionQueueLimiting(t *testing.T) {
// Create a test account and fund it
pool, key := setupTxPool()
account, _ := transaction(0, big.NewInt(0), key).From()
state, _ := pool.currentState()
state.AddBalance(account, big.NewInt(1000000))
// Keep queuing up transactions and make sure all above a limit are dropped
for i := uint64(1); i <= maxQueued+5; i++ {
if err := pool.Add(transaction(i, big.NewInt(100000), key)); err != nil {
t.Fatalf("tx %d: failed to add transaction: %v", i, err)
}
if len(pool.pending) != 0 {
t.Errorf("tx %d: pending pool size mismatch: have %d, want %d", i, len(pool.pending), 0)
}
if i <= maxQueued {
if len(pool.queue[account]) != int(i) {
t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, len(pool.queue[account]), i)
}
} else {
if len(pool.queue[account]) != maxQueued {
t.Errorf("tx %d: queue limit mismatch: have %d, want %d", i, len(pool.queue[account]), maxQueued)
}
}
}
}
// Tests that even if the transaction count belonging to a single account goes
// above some threshold, as long as the transactions are executable, they are
// accepted.
func TestTransactionPendingLimiting(t *testing.T) {
// Create a test account and fund it
pool, key := setupTxPool()
account, _ := transaction(0, big.NewInt(0), key).From()
state, _ := pool.currentState()
state.AddBalance(account, big.NewInt(1000000))
// Keep queuing up transactions and make sure all above a limit are dropped
for i := uint64(0); i < maxQueued+5; i++ {
if err := pool.Add(transaction(i, big.NewInt(100000), key)); err != nil {
t.Fatalf("tx %d: failed to add transaction: %v", i, err)
}
if len(pool.pending) != int(i)+1 {
t.Errorf("tx %d: pending pool size mismatch: have %d, want %d", i, len(pool.pending), i+1)
}
if len(pool.queue[account]) != 0 {
t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, len(pool.queue[account]), 0)
}
}
}
// Tests that the transaction limits are enforced the same way irrelevant whether
// the transactions are added one by one or in batches.
func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) }
func TestTransactionPendingLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 0) }
func testTransactionLimitingEquivalency(t *testing.T, origin uint64) {
// Add a batch of transactions to a pool one by one
pool1, key1 := setupTxPool()
account1, _ := transaction(0, big.NewInt(0), key1).From()
state1, _ := pool1.currentState()
state1.AddBalance(account1, big.NewInt(1000000))
for i := uint64(0); i < maxQueued+5; i++ {
if err := pool1.Add(transaction(origin+i, big.NewInt(100000), key1)); err != nil {
t.Fatalf("tx %d: failed to add transaction: %v", i, err)
}
}
// Add a batch of transactions to a pool in one bit batch
pool2, key2 := setupTxPool()
account2, _ := transaction(0, big.NewInt(0), key2).From()
state2, _ := pool2.currentState()
state2.AddBalance(account2, big.NewInt(1000000))
txns := []*types.Transaction{}
for i := uint64(0); i < maxQueued+5; i++ {
txns = append(txns, transaction(origin+i, big.NewInt(100000), key2))
}
pool2.AddTransactions(txns)
// Ensure the batch optimization honors the same pool mechanics
if len(pool1.pending) != len(pool2.pending) {
t.Errorf("pending transaction count mismatch: one-by-one algo: %d, batch algo: %d", len(pool1.pending), len(pool2.pending))
}
if len(pool1.queue[account1]) != len(pool2.queue[account2]) {
t.Errorf("queued transaction count mismatch: one-by-one algo: %d, batch algo: %d", len(pool1.queue[account1]), len(pool2.queue[account2]))
}
}
// Benchmarks the speed of validating the contents of the pending queue of the
// transaction pool.
func BenchmarkValidatePool100(b *testing.B) { benchmarkValidatePool(b, 100) }
func BenchmarkValidatePool1000(b *testing.B) { benchmarkValidatePool(b, 1000) }
func BenchmarkValidatePool10000(b *testing.B) { benchmarkValidatePool(b, 10000) }
func benchmarkValidatePool(b *testing.B, size int) {
// Add a batch of transactions to a pool one by one
pool, key := setupTxPool()
account, _ := transaction(0, big.NewInt(0), key).From()
state, _ := pool.currentState()
state.AddBalance(account, big.NewInt(1000000))
for i := 0; i < size; i++ {
tx := transaction(uint64(i), big.NewInt(100000), key)
pool.addTx(tx.Hash(), account, tx)
}
// Benchmark the speed of pool validation
b.ResetTimer()
for i := 0; i < b.N; i++ {
pool.validatePool()
}
}
// Benchmarks the speed of scheduling the contents of the future queue of the
// transaction pool.
func BenchmarkCheckQueue100(b *testing.B) { benchmarkCheckQueue(b, 100) }
func BenchmarkCheckQueue1000(b *testing.B) { benchmarkCheckQueue(b, 1000) }
func BenchmarkCheckQueue10000(b *testing.B) { benchmarkCheckQueue(b, 10000) }
func benchmarkCheckQueue(b *testing.B, size int) {
// Add a batch of transactions to a pool one by one
pool, key := setupTxPool()
account, _ := transaction(0, big.NewInt(0), key).From()
state, _ := pool.currentState()
state.AddBalance(account, big.NewInt(1000000))
for i := 0; i < size; i++ {
tx := transaction(uint64(1+i), big.NewInt(100000), key)
pool.queueTx(tx.Hash(), tx)
}
// Benchmark the speed of pool validation
b.ResetTimer()
for i := 0; i < b.N; i++ {
pool.checkQueue()
}
}

View File

@ -112,20 +112,20 @@ type Downloader struct {
syncStatsLock sync.RWMutex // Lock protecting the sync stats fields syncStatsLock sync.RWMutex // Lock protecting the sync stats fields
// Callbacks // Callbacks
hasHeader headerCheckFn // Checks if a header is present in the chain hasHeader headerCheckFn // Checks if a header is present in the chain
hasBlock blockCheckFn // Checks if a block is present in the chain hasBlockAndState blockAndStateCheckFn // Checks if a block and associated state is present in the chain
getHeader headerRetrievalFn // Retrieves a header from the chain getHeader headerRetrievalFn // Retrieves a header from the chain
getBlock blockRetrievalFn // Retrieves a block from the chain getBlock blockRetrievalFn // Retrieves a block from the chain
headHeader headHeaderRetrievalFn // Retrieves the head header from the chain headHeader headHeaderRetrievalFn // Retrieves the head header from the chain
headBlock headBlockRetrievalFn // Retrieves the head block from the chain headBlock headBlockRetrievalFn // Retrieves the head block from the chain
headFastBlock headFastBlockRetrievalFn // Retrieves the head fast-sync block from the chain headFastBlock headFastBlockRetrievalFn // Retrieves the head fast-sync block from the chain
commitHeadBlock headBlockCommitterFn // Commits a manually assembled block as the chain head commitHeadBlock headBlockCommitterFn // Commits a manually assembled block as the chain head
getTd tdRetrievalFn // Retrieves the TD of a block from the chain getTd tdRetrievalFn // Retrieves the TD of a block from the chain
insertHeaders headerChainInsertFn // Injects a batch of headers into the chain insertHeaders headerChainInsertFn // Injects a batch of headers into the chain
insertBlocks blockChainInsertFn // Injects a batch of blocks into the chain insertBlocks blockChainInsertFn // Injects a batch of blocks into the chain
insertReceipts receiptChainInsertFn // Injects a batch of blocks and their receipts into the chain insertReceipts receiptChainInsertFn // Injects a batch of blocks and their receipts into the chain
rollback chainRollbackFn // Removes a batch of recently added chain links rollback chainRollbackFn // Removes a batch of recently added chain links
dropPeer peerDropFn // Drops a peer for misbehaving dropPeer peerDropFn // Drops a peer for misbehaving
// Status // Status
synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
@ -156,41 +156,41 @@ type Downloader struct {
} }
// New creates a new downloader to fetch hashes and blocks from remote peers. // New creates a new downloader to fetch hashes and blocks from remote peers.
func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn, func New(stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlockAndState blockAndStateCheckFn,
getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn, getHeader headerRetrievalFn, getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn,
commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn, headFastBlock headFastBlockRetrievalFn, commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn,
insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader { insertBlocks blockChainInsertFn, insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
return &Downloader{ return &Downloader{
mode: FullSync, mode: FullSync,
mux: mux, mux: mux,
queue: newQueue(stateDb), queue: newQueue(stateDb),
peers: newPeerSet(), peers: newPeerSet(),
hasHeader: hasHeader, hasHeader: hasHeader,
hasBlock: hasBlock, hasBlockAndState: hasBlockAndState,
getHeader: getHeader, getHeader: getHeader,
getBlock: getBlock, getBlock: getBlock,
headHeader: headHeader, headHeader: headHeader,
headBlock: headBlock, headBlock: headBlock,
headFastBlock: headFastBlock, headFastBlock: headFastBlock,
commitHeadBlock: commitHeadBlock, commitHeadBlock: commitHeadBlock,
getTd: getTd, getTd: getTd,
insertHeaders: insertHeaders, insertHeaders: insertHeaders,
insertBlocks: insertBlocks, insertBlocks: insertBlocks,
insertReceipts: insertReceipts, insertReceipts: insertReceipts,
rollback: rollback, rollback: rollback,
dropPeer: dropPeer, dropPeer: dropPeer,
newPeerCh: make(chan *peer, 1), newPeerCh: make(chan *peer, 1),
hashCh: make(chan dataPack, 1), hashCh: make(chan dataPack, 1),
blockCh: make(chan dataPack, 1), blockCh: make(chan dataPack, 1),
headerCh: make(chan dataPack, 1), headerCh: make(chan dataPack, 1),
bodyCh: make(chan dataPack, 1), bodyCh: make(chan dataPack, 1),
receiptCh: make(chan dataPack, 1), receiptCh: make(chan dataPack, 1),
stateCh: make(chan dataPack, 1), stateCh: make(chan dataPack, 1),
blockWakeCh: make(chan bool, 1), blockWakeCh: make(chan bool, 1),
bodyWakeCh: make(chan bool, 1), bodyWakeCh: make(chan bool, 1),
receiptWakeCh: make(chan bool, 1), receiptWakeCh: make(chan bool, 1),
stateWakeCh: make(chan bool, 1), stateWakeCh: make(chan bool, 1),
} }
} }
@ -564,7 +564,7 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) {
// Check if a common ancestor was found // Check if a common ancestor was found
finished = true finished = true
for i := len(hashes) - 1; i >= 0; i-- { for i := len(hashes) - 1; i >= 0; i-- {
if d.hasBlock(hashes[i]) { if d.hasBlockAndState(hashes[i]) {
number, hash = uint64(from)+uint64(i), hashes[i] number, hash = uint64(from)+uint64(i), hashes[i]
break break
} }
@ -620,11 +620,11 @@ func (d *Downloader) findAncestor61(p *peer) (uint64, error) {
arrived = true arrived = true
// Modify the search interval based on the response // Modify the search interval based on the response
block := d.getBlock(hashes[0]) if !d.hasBlockAndState(hashes[0]) {
if block == nil {
end = check end = check
break break
} }
block := d.getBlock(hashes[0]) // this doesn't check state, hence the above explicit check
if block.NumberU64() != check { if block.NumberU64() != check {
glog.V(logger.Debug).Infof("%v: non requested hash #%d [%x…], instead of #%d", p, block.NumberU64(), block.Hash().Bytes()[:4], check) glog.V(logger.Debug).Infof("%v: non requested hash #%d [%x…], instead of #%d", p, block.NumberU64(), block.Hash().Bytes()[:4], check)
return 0, errBadPeer return 0, errBadPeer
@ -989,7 +989,7 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
// Check if a common ancestor was found // Check if a common ancestor was found
finished = true finished = true
for i := len(headers) - 1; i >= 0; i-- { for i := len(headers) - 1; i >= 0; i-- {
if (d.mode != LightSync && d.hasBlock(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) { if (d.mode != LightSync && d.hasBlockAndState(headers[i].Hash())) || (d.mode == LightSync && d.hasHeader(headers[i].Hash())) {
number, hash = headers[i].Number.Uint64(), headers[i].Hash() number, hash = headers[i].Number.Uint64(), headers[i].Hash()
break break
} }
@ -1045,7 +1045,7 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
arrived = true arrived = true
// Modify the search interval based on the response // Modify the search interval based on the response
if (d.mode == FullSync && !d.hasBlock(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) { if (d.mode == FullSync && !d.hasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.hasHeader(headers[0].Hash())) {
end = check end = check
break break
} }

View File

@ -150,6 +150,8 @@ func newTester() *downloadTester {
peerChainTds: make(map[string]map[common.Hash]*big.Int), peerChainTds: make(map[string]map[common.Hash]*big.Int),
} }
tester.stateDb, _ = ethdb.NewMemDatabase() tester.stateDb, _ = ethdb.NewMemDatabase()
tester.stateDb.Put(genesis.Root().Bytes(), []byte{0x00})
tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader, tester.downloader = New(tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd, tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer) tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
@ -177,9 +179,14 @@ func (dl *downloadTester) hasHeader(hash common.Hash) bool {
return dl.getHeader(hash) != nil return dl.getHeader(hash) != nil
} }
// hasBlock checks if a block is present in the testers canonical chain. // hasBlock checks if a block and associated state is present in the testers canonical chain.
func (dl *downloadTester) hasBlock(hash common.Hash) bool { func (dl *downloadTester) hasBlock(hash common.Hash) bool {
return dl.getBlock(hash) != nil block := dl.getBlock(hash)
if block == nil {
return false
}
_, err := dl.stateDb.Get(block.Root().Bytes())
return err == nil
} }
// getHeader retrieves a header from the testers canonical chain. // getHeader retrieves a header from the testers canonical chain.
@ -292,8 +299,10 @@ func (dl *downloadTester) insertBlocks(blocks types.Blocks) (int, error) {
defer dl.lock.Unlock() defer dl.lock.Unlock()
for i, block := range blocks { for i, block := range blocks {
if _, ok := dl.ownBlocks[block.ParentHash()]; !ok { if parent, ok := dl.ownBlocks[block.ParentHash()]; !ok {
return i, errors.New("unknown parent") return i, errors.New("unknown parent")
} else if _, err := dl.stateDb.Get(parent.Root().Bytes()); err != nil {
return i, fmt.Errorf("unknown parent state %x: %v", parent.Root(), err)
} }
if _, ok := dl.ownHeaders[block.Hash()]; !ok { if _, ok := dl.ownHeaders[block.Hash()]; !ok {
dl.ownHashes = append(dl.ownHashes, block.Hash()) dl.ownHashes = append(dl.ownHashes, block.Hash())
@ -1102,6 +1111,8 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
} }
// Tests that upon detecting an invalid header, the recent ones are rolled back // Tests that upon detecting an invalid header, the recent ones are rolled back
// for various failure scenarios. Afterwards a full sync is attempted to make
// sure no state was corrupted.
func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) } func TestInvalidHeaderRollback63Fast(t *testing.T) { testInvalidHeaderRollback(t, 63, FastSync) }
func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) } func TestInvalidHeaderRollback64Fast(t *testing.T) { testInvalidHeaderRollback(t, 64, FastSync) }
func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) } func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) }

View File

@ -27,8 +27,8 @@ import (
// headerCheckFn is a callback type for verifying a header's presence in the local chain. // headerCheckFn is a callback type for verifying a header's presence in the local chain.
type headerCheckFn func(common.Hash) bool type headerCheckFn func(common.Hash) bool
// blockCheckFn is a callback type for verifying a block's presence in the local chain. // blockAndStateCheckFn is a callback type for verifying block and associated states' presence in the local chain.
type blockCheckFn func(common.Hash) bool type blockAndStateCheckFn func(common.Hash) bool
// headerRetrievalFn is a callback type for retrieving a header from the local chain. // headerRetrievalFn is a callback type for retrieving a header from the local chain.
type headerRetrievalFn func(common.Hash) *types.Header type headerRetrievalFn func(common.Hash) *types.Header

View File

@ -138,9 +138,10 @@ func NewProtocolManager(fastSync bool, networkId int, mux *event.TypeMux, txpool
return nil, errIncompatibleConfig return nil, errIncompatibleConfig
} }
// Construct the different synchronisation mechanisms // Construct the different synchronisation mechanisms
manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader, blockchain.GetBlock, manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlockAndState, blockchain.GetHeader,
blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead, blockchain.GetTd, blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer) blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback,
manager.removePeer)
validator := func(block *types.Block, parent *types.Block) error { validator := func(block *types.Block, parent *types.Block) error {
return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false) return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)