eth/fetcher: make tests thread safe
This commit is contained in:
		@@ -4,6 +4,7 @@ import (
 | 
				
			|||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"math/big"
 | 
						"math/big"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
	"sync/atomic"
 | 
						"sync/atomic"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -67,15 +68,17 @@ func createBlocksFromHashes(hashes []common.Hash) map[common.Hash]*types.Block {
 | 
				
			|||||||
type fetcherTester struct {
 | 
					type fetcherTester struct {
 | 
				
			||||||
	fetcher *Fetcher
 | 
						fetcher *Fetcher
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ownHashes []common.Hash                // Hash chain belonging to the tester
 | 
						hashes []common.Hash                // Hash chain belonging to the tester
 | 
				
			||||||
	ownBlocks map[common.Hash]*types.Block // Blocks belonging to the tester
 | 
						blocks map[common.Hash]*types.Block // Blocks belonging to the tester
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lock sync.RWMutex
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// newTester creates a new fetcher test mocker.
 | 
					// newTester creates a new fetcher test mocker.
 | 
				
			||||||
func newTester() *fetcherTester {
 | 
					func newTester() *fetcherTester {
 | 
				
			||||||
	tester := &fetcherTester{
 | 
						tester := &fetcherTester{
 | 
				
			||||||
		ownHashes: []common.Hash{knownHash},
 | 
							hashes: []common.Hash{knownHash},
 | 
				
			||||||
		ownBlocks: map[common.Hash]*types.Block{knownHash: genesis},
 | 
							blocks: map[common.Hash]*types.Block{knownHash: genesis},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	tester.fetcher = New(tester.hasBlock, tester.importBlock, tester.chainHeight)
 | 
						tester.fetcher = New(tester.hasBlock, tester.importBlock, tester.chainHeight)
 | 
				
			||||||
	tester.fetcher.Start()
 | 
						tester.fetcher.Start()
 | 
				
			||||||
@@ -85,29 +88,38 @@ func newTester() *fetcherTester {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// hasBlock checks if a block is pres	ent in the testers canonical chain.
 | 
					// hasBlock checks if a block is pres	ent in the testers canonical chain.
 | 
				
			||||||
func (f *fetcherTester) hasBlock(hash common.Hash) bool {
 | 
					func (f *fetcherTester) hasBlock(hash common.Hash) bool {
 | 
				
			||||||
	_, ok := f.ownBlocks[hash]
 | 
						f.lock.RLock()
 | 
				
			||||||
 | 
						defer f.lock.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						_, ok := f.blocks[hash]
 | 
				
			||||||
	return ok
 | 
						return ok
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// importBlock injects a new blocks into the simulated chain.
 | 
					// importBlock injects a new blocks into the simulated chain.
 | 
				
			||||||
func (f *fetcherTester) importBlock(peer string, block *types.Block) error {
 | 
					func (f *fetcherTester) importBlock(peer string, block *types.Block) error {
 | 
				
			||||||
 | 
						f.lock.Lock()
 | 
				
			||||||
 | 
						defer f.lock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Make sure the parent in known
 | 
						// Make sure the parent in known
 | 
				
			||||||
	if _, ok := f.ownBlocks[block.ParentHash()]; !ok {
 | 
						if _, ok := f.blocks[block.ParentHash()]; !ok {
 | 
				
			||||||
		return errors.New("unknown parent")
 | 
							return errors.New("unknown parent")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// Discard any new blocks if the same height already exists
 | 
						// Discard any new blocks if the same height already exists
 | 
				
			||||||
	if block.NumberU64() <= f.ownBlocks[f.ownHashes[len(f.ownHashes)-1]].NumberU64() {
 | 
						if block.NumberU64() <= f.blocks[f.hashes[len(f.hashes)-1]].NumberU64() {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// Otherwise build our current chain
 | 
						// Otherwise build our current chain
 | 
				
			||||||
	f.ownHashes = append(f.ownHashes, block.Hash())
 | 
						f.hashes = append(f.hashes, block.Hash())
 | 
				
			||||||
	f.ownBlocks[block.Hash()] = block
 | 
						f.blocks[block.Hash()] = block
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// chainHeight retrieves the current height (block number) of the chain.
 | 
					// chainHeight retrieves the current height (block number) of the chain.
 | 
				
			||||||
func (f *fetcherTester) chainHeight() uint64 {
 | 
					func (f *fetcherTester) chainHeight() uint64 {
 | 
				
			||||||
	return f.ownBlocks[f.ownHashes[len(f.ownHashes)-1]].NumberU64()
 | 
						f.lock.RLock()
 | 
				
			||||||
 | 
						defer f.lock.RUnlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return f.blocks[f.hashes[len(f.hashes)-1]].NumberU64()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// peerFetcher retrieves a fetcher associated with a simulated peer.
 | 
					// peerFetcher retrieves a fetcher associated with a simulated peer.
 | 
				
			||||||
@@ -149,7 +161,7 @@ func TestSequentialAnnouncements(t *testing.T) {
 | 
				
			|||||||
		tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher)
 | 
							tester.fetcher.Notify("valid", hashes[i], time.Now().Add(-arriveTimeout), fetcher)
 | 
				
			||||||
		time.Sleep(50 * time.Millisecond)
 | 
							time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
 | 
						if imported := len(tester.blocks); imported != targetBlocks+1 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -179,7 +191,7 @@ func TestConcurrentAnnouncements(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		time.Sleep(50 * time.Millisecond)
 | 
							time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
 | 
						if imported := len(tester.blocks); imported != targetBlocks+1 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// Make sure no blocks were retrieved twice
 | 
						// Make sure no blocks were retrieved twice
 | 
				
			||||||
@@ -207,7 +219,7 @@ func TestOverlappingAnnouncements(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	time.Sleep(overlap * delay)
 | 
						time.Sleep(overlap * delay)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
 | 
						if imported := len(tester.blocks); imported != targetBlocks+1 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -242,7 +254,7 @@ func TestPendingDeduplication(t *testing.T) {
 | 
				
			|||||||
	time.Sleep(delay)
 | 
						time.Sleep(delay)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Check that all blocks were imported and none fetched twice
 | 
						// Check that all blocks were imported and none fetched twice
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != 2 {
 | 
						if imported := len(tester.blocks); imported != 2 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 2)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if int(counter) != 1 {
 | 
						if int(counter) != 1 {
 | 
				
			||||||
@@ -273,7 +285,7 @@ func TestRandomArrivalImport(t *testing.T) {
 | 
				
			|||||||
	tester.fetcher.Notify("valid", hashes[skip], time.Now().Add(-arriveTimeout), fetcher)
 | 
						tester.fetcher.Notify("valid", hashes[skip], time.Now().Add(-arriveTimeout), fetcher)
 | 
				
			||||||
	time.Sleep(50 * time.Millisecond)
 | 
						time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
 | 
						if imported := len(tester.blocks); imported != targetBlocks+1 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -301,7 +313,7 @@ func TestQueueGapFill(t *testing.T) {
 | 
				
			|||||||
	tester.fetcher.Enqueue("valid", blocks[hashes[skip]])
 | 
						tester.fetcher.Enqueue("valid", blocks[hashes[skip]])
 | 
				
			||||||
	time.Sleep(50 * time.Millisecond)
 | 
						time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != targetBlocks+1 {
 | 
						if imported := len(tester.blocks); imported != targetBlocks+1 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, targetBlocks+1)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -334,7 +346,7 @@ func TestImportDeduplication(t *testing.T) {
 | 
				
			|||||||
	tester.fetcher.Enqueue("valid", blocks[hashes[1]])
 | 
						tester.fetcher.Enqueue("valid", blocks[hashes[1]])
 | 
				
			||||||
	time.Sleep(50 * time.Millisecond)
 | 
						time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if imported := len(tester.ownBlocks); imported != 3 {
 | 
						if imported := len(tester.blocks); imported != 3 {
 | 
				
			||||||
		t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 3)
 | 
							t.Fatalf("synchronised block mismatch: have %v, want %v", imported, 3)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if counter != 2 {
 | 
						if counter != 2 {
 | 
				
			||||||
@@ -353,8 +365,8 @@ func TestDistantDiscarding(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// Create a tester and simulate a head block being the middle of the above chain
 | 
						// Create a tester and simulate a head block being the middle of the above chain
 | 
				
			||||||
	tester := newTester()
 | 
						tester := newTester()
 | 
				
			||||||
	tester.ownHashes = []common.Hash{head}
 | 
						tester.hashes = []common.Hash{head}
 | 
				
			||||||
	tester.ownBlocks = map[common.Hash]*types.Block{head: blocks[head]}
 | 
						tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Ensure that a block with a lower number than the threshold is discarded
 | 
						// Ensure that a block with a lower number than the threshold is discarded
 | 
				
			||||||
	tester.fetcher.Enqueue("lower", blocks[hashes[0]])
 | 
						tester.fetcher.Enqueue("lower", blocks[hashes[0]])
 | 
				
			||||||
@@ -413,10 +425,10 @@ func TestCompetingImports(t *testing.T) {
 | 
				
			|||||||
	tester.fetcher.Enqueue("chain C", blocksC[hashesC[len(hashesC)-2]])
 | 
						tester.fetcher.Enqueue("chain C", blocksC[hashesC[len(hashesC)-2]])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	start := time.Now()
 | 
						start := time.Now()
 | 
				
			||||||
	for len(tester.ownHashes) != len(hashesA) && time.Since(start) < time.Second {
 | 
						for len(tester.hashes) != len(hashesA) && time.Since(start) < time.Second {
 | 
				
			||||||
		time.Sleep(50 * time.Millisecond)
 | 
							time.Sleep(50 * time.Millisecond)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(tester.ownHashes) != len(hashesA) {
 | 
						if len(tester.hashes) != len(hashesA) {
 | 
				
			||||||
		t.Fatalf("chain length mismatch: have %v, want %v", len(tester.ownHashes), len(hashesA))
 | 
							t.Fatalf("chain length mismatch: have %v, want %v", len(tester.hashes), len(hashesA))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user