Merge pull request #3153 from fjl/trie-unload-fix
trie: improve cache unloading mechanism
This commit is contained in:
		@@ -39,12 +39,12 @@ import (
 | 
				
			|||||||
var StartingNonce uint64
 | 
					var StartingNonce uint64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	// Number of past tries to keep. The arbitrarily chosen value here
 | 
						// Number of past tries to keep. This value is chosen such that
 | 
				
			||||||
	// is max uncle depth + 1.
 | 
						// reasonable chain reorg depths will hit an existing trie.
 | 
				
			||||||
	maxPastTries = 8
 | 
						maxPastTries = 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Trie cache generation limit.
 | 
						// Trie cache generation limit.
 | 
				
			||||||
	maxTrieCacheGen = 100
 | 
						maxTrieCacheGen = 120
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Number of codehash->size associations to keep.
 | 
						// Number of codehash->size associations to keep.
 | 
				
			||||||
	codeSizeCacheSize = 100000
 | 
						codeSizeCacheSize = 100000
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -58,7 +58,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
 | 
				
			|||||||
			return hash, n, nil
 | 
								return hash, n, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if n.canUnload(h.cachegen, h.cachelimit) {
 | 
							if n.canUnload(h.cachegen, h.cachelimit) {
 | 
				
			||||||
			// Evict the node from cache. All of its subnodes will have a lower or equal
 | 
								// Unload the node from cache. All of its subnodes will have a lower or equal
 | 
				
			||||||
			// cache generation number.
 | 
								// cache generation number.
 | 
				
			||||||
			return hash, hash, nil
 | 
								return hash, hash, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return hashNode{}, n, err
 | 
							return hashNode{}, n, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// Cache the hash of the ndoe for later reuse.
 | 
						// Cache the hash of the ndoe for later reuse and remove
 | 
				
			||||||
	if hash, ok := hashed.(hashNode); ok && !force {
 | 
						// the dirty flag in commit mode. It's fine to assign these values directly
 | 
				
			||||||
		switch cached := cached.(type) {
 | 
						// without copying the node first because hashChildren copies it.
 | 
				
			||||||
		case *shortNode:
 | 
						cachedHash, _ := hashed.(hashNode)
 | 
				
			||||||
			cached = cached.copy()
 | 
						switch cn := cached.(type) {
 | 
				
			||||||
			cached.flags.hash = hash
 | 
						case *shortNode:
 | 
				
			||||||
			if db != nil {
 | 
							cn.flags.hash = cachedHash
 | 
				
			||||||
				cached.flags.dirty = false
 | 
							if db != nil {
 | 
				
			||||||
			}
 | 
								cn.flags.dirty = false
 | 
				
			||||||
			return hashed, cached, nil
 | 
							}
 | 
				
			||||||
		case *fullNode:
 | 
						case *fullNode:
 | 
				
			||||||
			cached = cached.copy()
 | 
							cn.flags.hash = cachedHash
 | 
				
			||||||
			cached.flags.hash = hash
 | 
							if db != nil {
 | 
				
			||||||
			if db != nil {
 | 
								cn.flags.dirty = false
 | 
				
			||||||
				cached.flags.dirty = false
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			return hashed, cached, nil
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return hashed, cached, nil
 | 
						return hashed, cached, nil
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										26
									
								
								trie/node.go
									
									
									
									
									
								
							
							
						
						
									
										26
									
								
								trie/node.go
									
									
									
									
									
								
							@@ -104,8 +104,8 @@ func (n valueNode) fstring(ind string) string {
 | 
				
			|||||||
	return fmt.Sprintf("%x ", []byte(n))
 | 
						return fmt.Sprintf("%x ", []byte(n))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func mustDecodeNode(hash, buf []byte) node {
 | 
					func mustDecodeNode(hash, buf []byte, cachegen uint16) node {
 | 
				
			||||||
	n, err := decodeNode(hash, buf)
 | 
						n, err := decodeNode(hash, buf, cachegen)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		panic(fmt.Sprintf("node %x: %v", hash, err))
 | 
							panic(fmt.Sprintf("node %x: %v", hash, err))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -113,7 +113,7 @@ func mustDecodeNode(hash, buf []byte) node {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// decodeNode parses the RLP encoding of a trie node.
 | 
					// decodeNode parses the RLP encoding of a trie node.
 | 
				
			||||||
func decodeNode(hash, buf []byte) (node, error) {
 | 
					func decodeNode(hash, buf []byte, cachegen uint16) (node, error) {
 | 
				
			||||||
	if len(buf) == 0 {
 | 
						if len(buf) == 0 {
 | 
				
			||||||
		return nil, io.ErrUnexpectedEOF
 | 
							return nil, io.ErrUnexpectedEOF
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -123,22 +123,22 @@ func decodeNode(hash, buf []byte) (node, error) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	switch c, _ := rlp.CountValues(elems); c {
 | 
						switch c, _ := rlp.CountValues(elems); c {
 | 
				
			||||||
	case 2:
 | 
						case 2:
 | 
				
			||||||
		n, err := decodeShort(hash, buf, elems)
 | 
							n, err := decodeShort(hash, buf, elems, cachegen)
 | 
				
			||||||
		return n, wrapError(err, "short")
 | 
							return n, wrapError(err, "short")
 | 
				
			||||||
	case 17:
 | 
						case 17:
 | 
				
			||||||
		n, err := decodeFull(hash, buf, elems)
 | 
							n, err := decodeFull(hash, buf, elems, cachegen)
 | 
				
			||||||
		return n, wrapError(err, "full")
 | 
							return n, wrapError(err, "full")
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		return nil, fmt.Errorf("invalid number of list elements: %v", c)
 | 
							return nil, fmt.Errorf("invalid number of list elements: %v", c)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func decodeShort(hash, buf, elems []byte) (node, error) {
 | 
					func decodeShort(hash, buf, elems []byte, cachegen uint16) (node, error) {
 | 
				
			||||||
	kbuf, rest, err := rlp.SplitString(elems)
 | 
						kbuf, rest, err := rlp.SplitString(elems)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	flag := nodeFlag{hash: hash}
 | 
						flag := nodeFlag{hash: hash, gen: cachegen}
 | 
				
			||||||
	key := compactDecode(kbuf)
 | 
						key := compactDecode(kbuf)
 | 
				
			||||||
	if key[len(key)-1] == 16 {
 | 
						if key[len(key)-1] == 16 {
 | 
				
			||||||
		// value node
 | 
							// value node
 | 
				
			||||||
@@ -148,17 +148,17 @@ func decodeShort(hash, buf, elems []byte) (node, error) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		return &shortNode{key, append(valueNode{}, val...), flag}, nil
 | 
							return &shortNode{key, append(valueNode{}, val...), flag}, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	r, _, err := decodeRef(rest)
 | 
						r, _, err := decodeRef(rest, cachegen)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, wrapError(err, "val")
 | 
							return nil, wrapError(err, "val")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &shortNode{key, r, flag}, nil
 | 
						return &shortNode{key, r, flag}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
 | 
					func decodeFull(hash, buf, elems []byte, cachegen uint16) (*fullNode, error) {
 | 
				
			||||||
	n := &fullNode{flags: nodeFlag{hash: hash}}
 | 
						n := &fullNode{flags: nodeFlag{hash: hash, gen: cachegen}}
 | 
				
			||||||
	for i := 0; i < 16; i++ {
 | 
						for i := 0; i < 16; i++ {
 | 
				
			||||||
		cld, rest, err := decodeRef(elems)
 | 
							cld, rest, err := decodeRef(elems, cachegen)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return n, wrapError(err, fmt.Sprintf("[%d]", i))
 | 
								return n, wrapError(err, fmt.Sprintf("[%d]", i))
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -176,7 +176,7 @@ func decodeFull(hash, buf, elems []byte) (*fullNode, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const hashLen = len(common.Hash{})
 | 
					const hashLen = len(common.Hash{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func decodeRef(buf []byte) (node, []byte, error) {
 | 
					func decodeRef(buf []byte, cachegen uint16) (node, []byte, error) {
 | 
				
			||||||
	kind, val, rest, err := rlp.Split(buf)
 | 
						kind, val, rest, err := rlp.Split(buf)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, buf, err
 | 
							return nil, buf, err
 | 
				
			||||||
@@ -189,7 +189,7 @@ func decodeRef(buf []byte) (node, []byte, error) {
 | 
				
			|||||||
			err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
 | 
								err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
 | 
				
			||||||
			return nil, buf, err
 | 
								return nil, buf, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		n, err := decodeNode(nil, buf)
 | 
							n, err := decodeNode(nil, buf, cachegen)
 | 
				
			||||||
		return n, rest, err
 | 
							return n, rest, err
 | 
				
			||||||
	case kind == rlp.String && len(val) == 0:
 | 
						case kind == rlp.String && len(val) == 0:
 | 
				
			||||||
		// empty node
 | 
							// empty node
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -101,7 +101,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value
 | 
				
			|||||||
		if !bytes.Equal(sha.Sum(nil), wantHash) {
 | 
							if !bytes.Equal(sha.Sum(nil), wantHash) {
 | 
				
			||||||
			return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
 | 
								return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		n, err := decodeNode(wantHash, buf)
 | 
							n, err := decodeNode(wantHash, buf, 0)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, fmt.Errorf("bad proof node %d: %v", i, err)
 | 
								return nil, fmt.Errorf("bad proof node %d: %v", i, err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -82,7 +82,7 @@ func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, c
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	key := root.Bytes()
 | 
						key := root.Bytes()
 | 
				
			||||||
	blob, _ := s.database.Get(key)
 | 
						blob, _ := s.database.Get(key)
 | 
				
			||||||
	if local, err := decodeNode(key, blob); local != nil && err == nil {
 | 
						if local, err := decodeNode(key, blob, 0); local != nil && err == nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	// Assemble the new sub-trie sync request
 | 
						// Assemble the new sub-trie sync request
 | 
				
			||||||
@@ -158,7 +158,7 @@ func (s *TrieSync) Process(results []SyncResult) (int, error) {
 | 
				
			|||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		// Decode the node data content and update the request
 | 
							// Decode the node data content and update the request
 | 
				
			||||||
		node, err := decodeNode(item.Hash[:], item.Data)
 | 
							node, err := decodeNode(item.Hash[:], item.Data, 0)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return i, err
 | 
								return i, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -246,7 +246,7 @@ func (s *TrieSync) children(req *request) ([]*request, error) {
 | 
				
			|||||||
		if node, ok := (*child.node).(hashNode); ok {
 | 
							if node, ok := (*child.node).(hashNode); ok {
 | 
				
			||||||
			// Try to resolve the node from the local database
 | 
								// Try to resolve the node from the local database
 | 
				
			||||||
			blob, _ := s.database.Get(node)
 | 
								blob, _ := s.database.Get(node)
 | 
				
			||||||
			if local, err := decodeNode(node[:], blob); local != nil && err == nil {
 | 
								if local, err := decodeNode(node[:], blob, 0); local != nil && err == nil {
 | 
				
			||||||
				*child.node = local
 | 
									*child.node = local
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										21
									
								
								trie/trie.go
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								trie/trie.go
									
									
									
									
									
								
							@@ -105,13 +105,11 @@ func New(root common.Hash, db Database) (*Trie, error) {
 | 
				
			|||||||
		if db == nil {
 | 
							if db == nil {
 | 
				
			||||||
			panic("trie.New: cannot use existing root without a database")
 | 
								panic("trie.New: cannot use existing root without a database")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if v, _ := trie.db.Get(root[:]); len(v) == 0 {
 | 
							rootnode, err := trie.resolveHash(root[:], nil, nil)
 | 
				
			||||||
			return nil, &MissingNodeError{
 | 
							if err != nil {
 | 
				
			||||||
				RootHash: root,
 | 
								return nil, err
 | 
				
			||||||
				NodeHash: root,
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		trie.root = hashNode(root.Bytes())
 | 
							trie.root = rootnode
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return trie, nil
 | 
						return trie, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -158,14 +156,15 @@ func (t *Trie) tryGet(origNode node, key []byte, pos int) (value []byte, newnode
 | 
				
			|||||||
		if err == nil && didResolve {
 | 
							if err == nil && didResolve {
 | 
				
			||||||
			n = n.copy()
 | 
								n = n.copy()
 | 
				
			||||||
			n.Val = newnode
 | 
								n.Val = newnode
 | 
				
			||||||
 | 
								n.flags.gen = t.cachegen
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return value, n, didResolve, err
 | 
							return value, n, didResolve, err
 | 
				
			||||||
	case *fullNode:
 | 
						case *fullNode:
 | 
				
			||||||
		value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
 | 
							value, newnode, didResolve, err = t.tryGet(n.Children[key[pos]], key, pos+1)
 | 
				
			||||||
		if err == nil && didResolve {
 | 
							if err == nil && didResolve {
 | 
				
			||||||
			n = n.copy()
 | 
								n = n.copy()
 | 
				
			||||||
 | 
								n.flags.gen = t.cachegen
 | 
				
			||||||
			n.Children[key[pos]] = newnode
 | 
								n.Children[key[pos]] = newnode
 | 
				
			||||||
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return value, n, didResolve, err
 | 
							return value, n, didResolve, err
 | 
				
			||||||
	case hashNode:
 | 
						case hashNode:
 | 
				
			||||||
@@ -261,7 +260,8 @@ func (t *Trie) insert(n node, prefix, key []byte, value node) (bool, node, error
 | 
				
			|||||||
			return false, n, err
 | 
								return false, n, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		n = n.copy()
 | 
							n = n.copy()
 | 
				
			||||||
		n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
 | 
							n.flags = t.newFlag()
 | 
				
			||||||
 | 
							n.Children[key[0]] = nn
 | 
				
			||||||
		return true, n, nil
 | 
							return true, n, nil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	case nil:
 | 
						case nil:
 | 
				
			||||||
@@ -345,7 +345,8 @@ func (t *Trie) delete(n node, prefix, key []byte) (bool, node, error) {
 | 
				
			|||||||
			return false, n, err
 | 
								return false, n, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		n = n.copy()
 | 
							n = n.copy()
 | 
				
			||||||
		n.Children[key[0]], n.flags.hash, n.flags.dirty = nn, nil, true
 | 
							n.flags = t.newFlag()
 | 
				
			||||||
 | 
							n.Children[key[0]] = nn
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Check how many non-nil entries are left after deleting and
 | 
							// Check how many non-nil entries are left after deleting and
 | 
				
			||||||
		// reduce the full node to a short node if only one entry is
 | 
							// reduce the full node to a short node if only one entry is
 | 
				
			||||||
@@ -443,7 +444,7 @@ func (t *Trie) resolveHash(n hashNode, prefix, suffix []byte) (node, error) {
 | 
				
			|||||||
			SuffixLen: len(suffix),
 | 
								SuffixLen: len(suffix),
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	dec := mustDecodeNode(n, enc)
 | 
						dec := mustDecodeNode(n, enc, t.cachegen)
 | 
				
			||||||
	return dec, nil
 | 
						return dec, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -300,25 +300,6 @@ func TestReplication(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Not an actual test
 | 
					 | 
				
			||||||
func TestOutput(t *testing.T) {
 | 
					 | 
				
			||||||
	t.Skip()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	base := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
 | 
					 | 
				
			||||||
	trie := newEmpty()
 | 
					 | 
				
			||||||
	for i := 0; i < 50; i++ {
 | 
					 | 
				
			||||||
		updateString(trie, fmt.Sprintf("%s%d", base, i), "valueeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	fmt.Println("############################## FULL ################################")
 | 
					 | 
				
			||||||
	fmt.Println(trie.root)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	trie.Commit()
 | 
					 | 
				
			||||||
	fmt.Println("############################## SMALL ################################")
 | 
					 | 
				
			||||||
	trie2, _ := New(trie.Hash(), trie.db)
 | 
					 | 
				
			||||||
	getString(trie2, base+"20")
 | 
					 | 
				
			||||||
	fmt.Println(trie2.root)
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func TestLargeValue(t *testing.T) {
 | 
					func TestLargeValue(t *testing.T) {
 | 
				
			||||||
	trie := newEmpty()
 | 
						trie := newEmpty()
 | 
				
			||||||
	trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
 | 
						trie.Update([]byte("key1"), []byte{99, 99, 99, 99})
 | 
				
			||||||
@@ -326,14 +307,56 @@ func TestLargeValue(t *testing.T) {
 | 
				
			|||||||
	trie.Hash()
 | 
						trie.Hash()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type countingDB struct {
 | 
				
			||||||
 | 
						Database
 | 
				
			||||||
 | 
						gets map[string]int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (db *countingDB) Get(key []byte) ([]byte, error) {
 | 
				
			||||||
 | 
						db.gets[string(key)]++
 | 
				
			||||||
 | 
						return db.Database.Get(key)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// TestCacheUnload checks that decoded nodes are unloaded after a
 | 
				
			||||||
 | 
					// certain number of commit operations.
 | 
				
			||||||
 | 
					func TestCacheUnload(t *testing.T) {
 | 
				
			||||||
 | 
						// Create test trie with two branches.
 | 
				
			||||||
 | 
						trie := newEmpty()
 | 
				
			||||||
 | 
						key1 := "---------------------------------"
 | 
				
			||||||
 | 
						key2 := "---some other branch"
 | 
				
			||||||
 | 
						updateString(trie, key1, "this is the branch of key1.")
 | 
				
			||||||
 | 
						updateString(trie, key2, "this is the branch of key2.")
 | 
				
			||||||
 | 
						root, _ := trie.Commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Commit the trie repeatedly and access key1.
 | 
				
			||||||
 | 
						// The branch containing it is loaded from DB exactly two times:
 | 
				
			||||||
 | 
						// in the 0th and 6th iteration.
 | 
				
			||||||
 | 
						db := &countingDB{Database: trie.db, gets: make(map[string]int)}
 | 
				
			||||||
 | 
						trie, _ = New(root, db)
 | 
				
			||||||
 | 
						trie.SetCacheLimit(5)
 | 
				
			||||||
 | 
						for i := 0; i < 12; i++ {
 | 
				
			||||||
 | 
							getString(trie, key1)
 | 
				
			||||||
 | 
							trie.Commit()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that it got loaded two times.
 | 
				
			||||||
 | 
						for dbkey, count := range db.gets {
 | 
				
			||||||
 | 
							if count != 2 {
 | 
				
			||||||
 | 
								t.Errorf("db key %x loaded %d times, want %d times", []byte(dbkey), count, 2)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// randTest performs random trie operations.
 | 
				
			||||||
 | 
					// Instances of this test are created by Generate.
 | 
				
			||||||
 | 
					type randTest []randTestStep
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type randTestStep struct {
 | 
					type randTestStep struct {
 | 
				
			||||||
	op    int
 | 
						op    int
 | 
				
			||||||
	key   []byte // for opUpdate, opDelete, opGet
 | 
						key   []byte // for opUpdate, opDelete, opGet
 | 
				
			||||||
	value []byte // for opUpdate
 | 
						value []byte // for opUpdate
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type randTest []randTestStep
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	opUpdate = iota
 | 
						opUpdate = iota
 | 
				
			||||||
	opDelete
 | 
						opDelete
 | 
				
			||||||
@@ -342,6 +365,7 @@ const (
 | 
				
			|||||||
	opHash
 | 
						opHash
 | 
				
			||||||
	opReset
 | 
						opReset
 | 
				
			||||||
	opItercheckhash
 | 
						opItercheckhash
 | 
				
			||||||
 | 
						opCheckCacheInvariant
 | 
				
			||||||
	opMax // boundary value, not an actual op
 | 
						opMax // boundary value, not an actual op
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -437,6 +461,44 @@ func runRandTest(rt randTest) bool {
 | 
				
			|||||||
				fmt.Println("hashes not equal")
 | 
									fmt.Println("hashes not equal")
 | 
				
			||||||
				return false
 | 
									return false
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
							case opCheckCacheInvariant:
 | 
				
			||||||
 | 
								return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return true
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool {
 | 
				
			||||||
 | 
						var children []node
 | 
				
			||||||
 | 
						var flag nodeFlag
 | 
				
			||||||
 | 
						switch n := n.(type) {
 | 
				
			||||||
 | 
						case *shortNode:
 | 
				
			||||||
 | 
							flag = n.flags
 | 
				
			||||||
 | 
							children = []node{n.Val}
 | 
				
			||||||
 | 
						case *fullNode:
 | 
				
			||||||
 | 
							flag = n.flags
 | 
				
			||||||
 | 
							children = n.Children[:]
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						showerror := func() {
 | 
				
			||||||
 | 
							fmt.Printf("at depth %d node %s", depth, spew.Sdump(n))
 | 
				
			||||||
 | 
							fmt.Printf("parent: %s", spew.Sdump(parent))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if flag.gen > parentCachegen {
 | 
				
			||||||
 | 
							fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
 | 
				
			||||||
 | 
							showerror()
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if depth > 0 && !parentDirty && flag.dirty {
 | 
				
			||||||
 | 
							fmt.Printf("cache invariant violation: child is dirty but parent isn't\n")
 | 
				
			||||||
 | 
							showerror()
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, child := range children {
 | 
				
			||||||
 | 
							if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) {
 | 
				
			||||||
 | 
								return false
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return true
 | 
						return true
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user