core, eth: roll back uncertain headers in failed fast syncs
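In short: during fast and light sync only a fraction of the downloaded headers are fully verified until the sync nears the pivot point, so headers accepted in the meantime are merely "uncertain". This change keeps a bounded window of those uncertain headers and, if the sync fails, hands their hashes to a new rollback callback that unwinds the chain heads. Below is a minimal, self-contained sketch of that bookkeeping; the types and names (syncWindow, maxUncertain, rollbackFn) are illustrative only, not the actual downloader code.

// Sketch only: a bounded window of headers that were imported but not yet
// fully verified. On sync failure the window is handed to a rollback callback.
// All names here (syncWindow, maxUncertain, rollbackFn) are illustrative.
package main

import "fmt"

type header struct {
	number uint64
	hash   string
}

type syncWindow struct {
	uncertain    []header              // headers accepted but only spot-checked
	maxUncertain int                   // cap on how many headers may ever need rolling back
	pivot        uint64                // block number below which everything is fully verified
	rollbackFn   func(hashes []string) // unwinds the chain heads past these hashes
}

// accept records a freshly imported batch, trims the window to its cap and
// clears it entirely once the chain has passed the fully verified pivot.
func (w *syncWindow) accept(batch []header) {
	w.uncertain = append(w.uncertain, batch...)
	if len(w.uncertain) > w.maxUncertain {
		w.uncertain = w.uncertain[len(w.uncertain)-w.maxUncertain:]
	}
	if batch[len(batch)-1].number >= w.pivot {
		w.uncertain = w.uncertain[:0] // past the pivot, nothing left to distrust
	}
}

// fail is invoked when the sync aborts: all still-uncertain headers are undone.
func (w *syncWindow) fail() {
	if len(w.uncertain) == 0 {
		return
	}
	hashes := make([]string, len(w.uncertain))
	for i, h := range w.uncertain {
		hashes[i] = h.hash
	}
	w.rollbackFn(hashes)
}

func main() {
	w := &syncWindow{maxUncertain: 3, pivot: 100, rollbackFn: func(h []string) { fmt.Println("rolling back", h) }}
	w.accept([]header{{1, "a"}, {2, "b"}, {3, "c"}, {4, "d"}})
	w.fail() // only the last three uncertain hashes remain: [b c d]
}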

core/blockchain.go
@@ -245,7 +245,21 @@ func (bc *BlockChain) SetHead(head uint64) {
 	if bc.currentBlock == nil {
 		bc.currentBlock = bc.genesisBlock
 	}
-	bc.insert(bc.currentBlock)
+	if bc.currentHeader == nil {
+		bc.currentHeader = bc.genesisBlock.Header()
+	}
+	if bc.currentFastBlock == nil {
+		bc.currentFastBlock = bc.genesisBlock
+	}
+	if err := WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()); err != nil {
+		glog.Fatalf("failed to reset head block hash: %v", err)
+	}
+	if err := WriteHeadHeaderHash(bc.chainDb, bc.currentHeader.Hash()); err != nil {
+		glog.Fatalf("failed to reset head header hash: %v", err)
+	}
+	if err := WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()); err != nil {
+		glog.Fatalf("failed to reset head fast block hash: %v", err)
+	}
 	bc.loadLastState()
 }
@@ -790,6 +804,27 @@ func (self *BlockChain) InsertHeaderChain(chain []*types.Header, checkFreq int)
 	return 0, nil
 }
 
+// Rollback is designed to remove a chain of links from the database that aren't
+// certain enough to be valid.
+func (self *BlockChain) Rollback(chain []common.Hash) {
+	for i := len(chain) - 1; i >= 0; i-- {
+		hash := chain[i]
+
+		if self.currentHeader.Hash() == hash {
+			self.currentHeader = self.GetHeader(self.currentHeader.ParentHash)
+			WriteHeadHeaderHash(self.chainDb, self.currentHeader.Hash())
+		}
+		if self.currentFastBlock.Hash() == hash {
+			self.currentFastBlock = self.GetBlock(self.currentFastBlock.ParentHash())
+			WriteHeadFastBlockHash(self.chainDb, self.currentFastBlock.Hash())
+		}
+		if self.currentBlock.Hash() == hash {
+			self.currentBlock = self.GetBlock(self.currentBlock.ParentHash())
+			WriteHeadBlockHash(self.chainDb, self.currentBlock.Hash())
+		}
+	}
+}
+
 // InsertReceiptChain attempts to complete an already existing header chain with
 // transaction and receipt data.
 func (self *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain []types.Receipts) (int, error) {

core/blockchain_test.go
@@ -599,8 +599,8 @@ func testReorgBadHashes(t *testing.T, full bool) {
 			t.Errorf("last  block gasLimit mismatch: have: %x, want %x", ncm.GasLimit(), blocks[2].Header().GasLimit)
 		}
 	} else {
-		if ncm.CurrentHeader().Hash() != genesis.Hash() {
-			t.Errorf("last header hash mismatch: have: %x, want %x", ncm.CurrentHeader().Hash(), genesis.Hash())
+		if ncm.CurrentHeader().Hash() != headers[2].Hash() {
+			t.Errorf("last header hash mismatch: have: %x, want %x", ncm.CurrentHeader().Hash(), headers[2].Hash())
 		}
 	}
 }
@@ -775,6 +775,11 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 	height := uint64(1024)
 	blocks, receipts := GenerateChain(genesis, gendb, int(height), nil)
 
+	// Configure a subchain to roll back
+	remove := []common.Hash{}
+	for _, block := range blocks[height/2:] {
+		remove = append(remove, block.Hash())
+	}
 	// Create a small assertion method to check the three heads
 	assert := func(t *testing.T, kind string, chain *BlockChain, header uint64, fast uint64, block uint64) {
 		if num := chain.CurrentBlock().NumberU64(); num != block {
@@ -798,6 +803,8 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 		t.Fatalf("failed to process block %d: %v", n, err)
 	}
 	assert(t, "archive", archive, height, height, height)
+	archive.Rollback(remove)
+	assert(t, "archive", archive, height/2, height/2, height/2)
 
 	// Import the chain as a non-archive node and ensure all pointers are updated
 	fastDb, _ := ethdb.NewMemDatabase()
@@ -816,6 +823,8 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 		t.Fatalf("failed to insert receipt %d: %v", n, err)
 	}
 	assert(t, "fast", fast, height, height, 0)
+	fast.Rollback(remove)
+	assert(t, "fast", fast, height/2, height/2, 0)
 
 	// Import the chain as a light node and ensure all pointers are updated
 	lightDb, _ := ethdb.NewMemDatabase()
@@ -827,6 +836,8 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 		t.Fatalf("failed to insert header %d: %v", n, err)
 	}
 	assert(t, "light", light, height, 0, 0)
+	light.Rollback(remove)
+	assert(t, "light", light, height/2, 0, 0)
 }
 
 // Tests that chain reorganizations handle transaction removals and reinsertions.

eth/downloader/downloader.go
@@ -59,8 +59,8 @@ var (
 	maxQueuedStates   = 256 * 1024 // [eth/63] Maximum number of state requests to queue (DOS protection)
 	maxResultsProcess = 256        // Number of download results to import at once into the chain
 
-	headerCheckFrequency = 64   // Verification frequency of the downloaded headers during fast sync
-	minCheckedHeaders    = 1024 // Number of headers to verify fully when approaching the chain head
+	headerCheckFrequency = 100  // Verification frequency of the downloaded headers during fast sync
+	minCheckedHeaders    = 2048 // Number of headers to verify fully when approaching the chain head
 	minFullBlocks        = 1024 // Number of blocks to retrieve fully even in fast sync
 )
 
@@ -117,6 +117,7 @@ type Downloader struct {
 	insertHeaders   headerChainInsertFn      // Injects a batch of headers into the chain
 	insertBlocks    blockChainInsertFn       // Injects a batch of blocks into the chain
 	insertReceipts  receiptChainInsertFn     // Injects a batch of blocks and their receipts into the chain
+	rollback        chainRollbackFn          // Removes a batch of recently added chain links
 	dropPeer        peerDropFn               // Drops a peer for misbehaving
 
 	// Status
@@ -152,7 +153,7 @@ type Downloader struct {
 func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader headerCheckFn, hasBlock blockCheckFn, getHeader headerRetrievalFn,
 	getBlock blockRetrievalFn, headHeader headHeaderRetrievalFn, headBlock headBlockRetrievalFn, headFastBlock headFastBlockRetrievalFn,
 	commitHeadBlock headBlockCommitterFn, getTd tdRetrievalFn, insertHeaders headerChainInsertFn, insertBlocks blockChainInsertFn,
-	insertReceipts receiptChainInsertFn, dropPeer peerDropFn) *Downloader {
+	insertReceipts receiptChainInsertFn, rollback chainRollbackFn, dropPeer peerDropFn) *Downloader {
 
 	return &Downloader{
 		mode:            mode,
@@ -171,6 +172,7 @@ func New(mode SyncMode, stateDb ethdb.Database, mux *event.TypeMux, hasHeader he
 		insertHeaders:   insertHeaders,
 		insertBlocks:    insertBlocks,
 		insertReceipts:  insertReceipts,
+		rollback:        rollback,
 		dropPeer:        dropPeer,
 		newPeerCh:       make(chan *peer, 1),
 		hashCh:          make(chan dataPack, 1),
@@ -383,7 +385,7 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e
 		d.syncStatsChainHeight = latest
 		d.syncStatsLock.Unlock()
 
-		// Initiate the sync using a  concurrent header and content retrieval algorithm
+		// Initiate the sync using a concurrent header and content retrieval algorithm
 		pivot := uint64(0)
 		if latest > uint64(minFullBlocks) {
 			pivot = latest - uint64(minFullBlocks)
@@ -394,10 +396,10 @@ func (d *Downloader) syncWithPeer(p *peer, hash common.Hash, td *big.Int) (err e
 			d.syncInitHook(origin, latest)
 		}
 		errc := make(chan error, 4)
-		go func() { errc <- d.fetchHeaders(p, td, origin+1) }() // Headers are always retrieved
-		go func() { errc <- d.fetchBodies(origin + 1) }()       // Bodies are retrieved during normal and fast sync
-		go func() { errc <- d.fetchReceipts(origin + 1) }()     // Receipts are retrieved during fast sync
-		go func() { errc <- d.fetchNodeData() }()               // Node state data is retrieved during fast sync
+		go func() { errc <- d.fetchHeaders(p, td, origin+1, latest) }() // Headers are always retrieved
+		go func() { errc <- d.fetchBodies(origin + 1) }()               // Bodies are retrieved during normal and fast sync
+		go func() { errc <- d.fetchReceipts(origin + 1) }()             // Receipts are retrieved during fast sync
+		go func() { errc <- d.fetchNodeData() }()                       // Node state data is retrieved during fast sync
 
 		// If any fetcher fails, cancel the others
 		var fail error
@@ -1049,10 +1051,28 @@ func (d *Downloader) findAncestor(p *peer) (uint64, error) {
 //
 // The queue parameter can be used to switch between queuing headers for block
 // body download too, or directly import as pure header chains.
-func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error {
+func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from, latest uint64) error {
 	glog.V(logger.Debug).Infof("%v: downloading headers from #%d", p, from)
 	defer glog.V(logger.Debug).Infof("%v: header download terminated", p)
 
+	// Keep a count of uncertain headers to roll back
+	rollback := []*types.Header{}
+	defer func() {
+		if len(rollback) > 0 {
+			hashes := make([]common.Hash, len(rollback))
+			for i, header := range rollback {
+				hashes[i] = header.Hash()
+			}
+			d.rollback(hashes)
+		}
+	}()
+	// Calculate the pivoting point for switching from fast to slow sync
+	pivot := uint64(0)
+	if d.mode == FastSync && latest > uint64(minFullBlocks) {
+		pivot = latest - uint64(minFullBlocks)
+	} else if d.mode == LightSync {
+		pivot = latest
+	}
 	// Create a timeout timer, and the associated hash fetcher
 	request := time.Now()       // time of the last fetch request
 	timeout := time.NewTimer(0) // timer to dump a non-responsive active peer
@@ -1124,10 +1144,30 @@ func (d *Downloader) fetchHeaders(p *peer, td *big.Int, from uint64) error {
 			glog.V(logger.Detail).Infof("%v: schedule %d headers from #%d", p, len(headers), from)
 
 			if d.mode == FastSync || d.mode == LightSync {
-				if n, err := d.insertHeaders(headers, headerCheckFrequency); err != nil {
+				// Collect the yet unknown headers to mark them as uncertain
+				unknown := make([]*types.Header, 0, len(headers))
+				for _, header := range headers {
+					if !d.hasHeader(header.Hash()) {
+						unknown = append(unknown, header)
+					}
+				}
+				// If we're importing pure headers, verify based on their recentness
+				frequency := headerCheckFrequency
+				if headers[len(headers)-1].Number.Uint64()+uint64(minCheckedHeaders) > pivot {
+					frequency = 1
+				}
+				if n, err := d.insertHeaders(headers, frequency); err != nil {
 					glog.V(logger.Debug).Infof("%v: invalid header #%d [%x…]: %v", p, headers[n].Number, headers[n].Hash().Bytes()[:4], err)
 					return errInvalidChain
 				}
+				// All verifications passed, store newly found uncertain headers
+				rollback = append(rollback, unknown...)
+				if len(rollback) > minCheckedHeaders {
+					rollback = append(rollback[:0], rollback[len(rollback)-minCheckedHeaders:]...)
+				}
+				if headers[len(headers)-1].Number.Uint64() >= pivot {
+					rollback = rollback[:0]
+				}
 			}
 			if d.mode == FullSync || d.mode == FastSync {
 				inserts := d.queue.Schedule(headers, from)
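The subtle part above is the verification frequency: far below the pivot only every headerCheckFrequency-th header gets a full check, while batches that come within minCheckedHeaders of the pivot are checked one by one, which is what bounds how much can ever need rolling back. A small hedged sketch of that decision follows; the helper name and constants are illustrative, not the real downloader API.

// Sketch only: picking how thoroughly to verify a batch of downloaded headers.
// Far below the pivot only every checkFrequency-th header gets a full check;
// within fullCheckWindow of the pivot every header is verified.
package main

import "fmt"

func verificationFrequency(lastHeader, pivot uint64, fullCheckWindow, checkFrequency int) int {
	if lastHeader+uint64(fullCheckWindow) > pivot {
		return 1 // close to the pivot: verify every single header
	}
	return checkFrequency // far away: spot-check only
}

func main() {
	pivot := uint64(100000)
	fmt.Println(verificationFrequency(50000, pivot, 2048, 100)) // 100: sparse checks
	fmt.Println(verificationFrequency(98500, pivot, 2048, 100)) // 1: full checks near the pivot
}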

eth/downloader/downloader_test.go
@@ -152,7 +152,7 @@ func newTester(mode SyncMode) *downloadTester {
 	tester.stateDb, _ = ethdb.NewMemDatabase()
 	tester.downloader = New(mode, tester.stateDb, new(event.TypeMux), tester.hasHeader, tester.hasBlock, tester.getHeader,
 		tester.getBlock, tester.headHeader, tester.headBlock, tester.headFastBlock, tester.commitHeadBlock, tester.getTd,
-		tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.dropPeer)
+		tester.insertHeaders, tester.insertBlocks, tester.insertReceipts, tester.rollback, tester.dropPeer)
 
 	return tester
 }
@@ -272,6 +272,16 @@ func (dl *downloadTester) insertHeaders(headers []*types.Header, checkFreq int)
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
 
+	// Do a quick check, as the blockchain.InsertHeaderChain doesn't insert anthing in case of errors
+	if _, ok := dl.ownHeaders[headers[0].ParentHash]; !ok {
+		return 0, errors.New("unknown parent")
+	}
+	for i := 1; i < len(headers); i++ {
+		if headers[i].ParentHash != headers[i-1].Hash() {
+			return i, errors.New("unknown parent")
+		}
+	}
+	// Do a full insert if pre-checks passed
 	for i, header := range headers {
 		if _, ok := dl.ownHeaders[header.Hash()]; ok {
 			continue
@@ -322,6 +332,22 @@ func (dl *downloadTester) insertReceipts(blocks types.Blocks, receipts []types.R
 	return len(blocks), nil
 }
 
+// rollback removes some recently added elements from the chain.
+func (dl *downloadTester) rollback(hashes []common.Hash) {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+
+	for i := len(hashes) - 1; i >= 0; i-- {
+		if dl.ownHashes[len(dl.ownHashes)-1] == hashes[i] {
+			dl.ownHashes = dl.ownHashes[:len(dl.ownHashes)-1]
+		}
+		delete(dl.ownChainTd, hashes[i])
+		delete(dl.ownHeaders, hashes[i])
+		delete(dl.ownReceipts, hashes[i])
+		delete(dl.ownBlocks, hashes[i])
+	}
+}
+
 // newPeer registers a new block download source into the downloader.
 func (dl *downloadTester) newPeer(id string, version int, hashes []common.Hash, headers map[common.Hash]*types.Header, blocks map[common.Hash]*types.Block, receipts map[common.Hash]types.Receipts) error {
 	return dl.newSlowPeer(id, version, hashes, headers, blocks, receipts, 0)
@@ -1031,6 +1057,56 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
 	assertOwnChain(t, tester, targetBlocks+1)
 }
 
+// Tests that upon detecting an invalid header, the recent ones are rolled back
+func TestInvalidHeaderRollback63Fast(t *testing.T)  { testInvalidHeaderRollback(t, 63, FastSync) }
+func TestInvalidHeaderRollback64Fast(t *testing.T)  { testInvalidHeaderRollback(t, 64, FastSync) }
+func TestInvalidHeaderRollback64Light(t *testing.T) { testInvalidHeaderRollback(t, 64, LightSync) }
+
+func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
+	// Create a small enough block chain to download
+	targetBlocks := 3*minCheckedHeaders + minFullBlocks
+	hashes, headers, blocks, receipts := makeChain(targetBlocks, 0, genesis, nil)
+
+	tester := newTester(mode)
+
+	// Attempt to sync with an attacker that feeds junk during the fast sync phase
+	tester.newPeer("fast-attack", protocol, hashes, headers, blocks, receipts)
+	missing := minCheckedHeaders + MaxHeaderFetch + 1
+	delete(tester.peerHeaders["fast-attack"], hashes[len(hashes)-missing])
+
+	if err := tester.sync("fast-attack", nil); err == nil {
+		t.Fatalf("succeeded fast attacker synchronisation")
+	}
+	if head := tester.headHeader().Number.Int64(); int(head) > MaxHeaderFetch {
+		t.Fatalf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
+	}
+	// Attempt to sync with an attacker that feeds junk during the block import phase
+	tester.newPeer("block-attack", protocol, hashes, headers, blocks, receipts)
+	missing = 3*minCheckedHeaders + MaxHeaderFetch + 1
+	delete(tester.peerHeaders["block-attack"], hashes[len(hashes)-missing])
+
+	if err := tester.sync("block-attack", nil); err == nil {
+		t.Fatalf("succeeded block attacker synchronisation")
+	}
+	if mode == FastSync {
+		// Fast sync should not discard anything below the verified pivot point
+		if head := tester.headHeader().Number.Int64(); int(head) < 3*minCheckedHeaders {
+			t.Fatalf("rollback head mismatch: have %v, want at least %v", head, 3*minCheckedHeaders)
+		}
+	} else if mode == LightSync {
+		// Light sync should still discard data as before
+		if head := tester.headHeader().Number.Int64(); int(head) > 3*minCheckedHeaders {
+			t.Fatalf("rollback head mismatch: have %v, want at most %v", head, 3*minCheckedHeaders)
+		}
+	}
+	// Synchronise with the valid peer and make sure sync succeeds
+	tester.newPeer("valid", protocol, hashes, headers, blocks, receipts)
+	if err := tester.sync("valid", nil); err != nil {
+		t.Fatalf("failed to synchronise blocks: %v", err)
+	}
+	assertOwnChain(t, tester, targetBlocks+1)
+}
+
 // Tests that if a peer sends an invalid block piece (body or receipt) for a
 // requested block, it gets dropped immediately by the downloader.
 func TestInvalidContentAttack62(t *testing.T)      { testInvalidContentAttack(t, 62, FullSync) }

eth/downloader/types.go
@@ -60,6 +60,9 @@ type blockChainInsertFn func(types.Blocks) (int, error)
 // receiptChainInsertFn is a callback type to insert a batch of receipts into the local chain.
 type receiptChainInsertFn func(types.Blocks, []types.Receipts) (int, error)
 
+// chainRollbackFn is a callback type to remove a few recently added elements from the local chain.
+type chainRollbackFn func([]common.Hash)
+
 // peerDropFn is a callback type for dropping a peer detected as malicious.
 type peerDropFn func(id string)
 

eth/handler.go
@@ -131,7 +131,7 @@ func NewProtocolManager(mode Mode, networkId int, mux *event.TypeMux, txpool txP
 	}
 	manager.downloader = downloader.New(syncMode, chaindb, manager.eventMux, blockchain.HasHeader, blockchain.HasBlock, blockchain.GetHeader,
 		blockchain.GetBlock, blockchain.CurrentHeader, blockchain.CurrentBlock, blockchain.CurrentFastBlock, blockchain.FastSyncCommitHead,
-		blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, manager.removePeer)
+		blockchain.GetTd, blockchain.InsertHeaderChain, blockchain.InsertChain, blockchain.InsertReceiptChain, blockchain.Rollback, manager.removePeer)
 
 	validator := func(block *types.Block, parent *types.Block) error {
 		return core.ValidateHeader(pow, block.Header(), parent.Header(), true, false)