[release/1.4.16] core/state: implement reverts by journaling all changes
This commit replaces the deep-copy based state revert mechanism with a
linear complexity journal. This commit also hides several internal
StateDB methods to limit the number of ways in which calling code can
use the journal incorrectly.
As usual consultation and bug fixes to the initial implementation were
provided by @karalabe, @obscuren and @Arachnid. Thank you!
(cherry picked from commit 1f1ea18b54)
			
			
This commit is contained in:
		| @@ -97,7 +97,8 @@ func (b *SimulatedBackend) ContractCall(contract common.Address, data []byte, pe | |||||||
| 		statedb *state.StateDB | 		statedb *state.StateDB | ||||||
| 	) | 	) | ||||||
| 	if pending { | 	if pending { | ||||||
| 		block, statedb = b.pendingBlock, b.pendingState.Copy() | 		block, statedb = b.pendingBlock, b.pendingState | ||||||
|  | 		defer statedb.RevertToSnapshot(statedb.Snapshot()) | ||||||
| 	} else { | 	} else { | ||||||
| 		block = b.blockchain.CurrentBlock() | 		block = b.blockchain.CurrentBlock() | ||||||
| 		statedb, _ = b.blockchain.State() | 		statedb, _ = b.blockchain.State() | ||||||
| @@ -119,6 +120,7 @@ func (b *SimulatedBackend) ContractCall(contract common.Address, data []byte, pe | |||||||
| 		value:    new(big.Int), | 		value:    new(big.Int), | ||||||
| 		data:     data, | 		data:     data, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Execute the call and return | 	// Execute the call and return | ||||||
| 	vmenv := core.NewEnv(statedb, chainConfig, b.blockchain, msg, block.Header(), vm.Config{}) | 	vmenv := core.NewEnv(statedb, chainConfig, b.blockchain, msg, block.Header(), vm.Config{}) | ||||||
| 	gaspool := new(core.GasPool).AddGas(common.MaxBig) | 	gaspool := new(core.GasPool).AddGas(common.MaxBig) | ||||||
| @@ -146,8 +148,10 @@ func (b *SimulatedBackend) EstimateGasLimit(sender common.Address, contract *com | |||||||
| 	// Create a copy of the currently pending state db to screw around with | 	// Create a copy of the currently pending state db to screw around with | ||||||
| 	var ( | 	var ( | ||||||
| 		block   = b.pendingBlock | 		block   = b.pendingBlock | ||||||
| 		statedb = b.pendingState.Copy() | 		statedb = b.pendingState | ||||||
| 	) | 	) | ||||||
|  | 	defer statedb.RevertToSnapshot(statedb.Snapshot()) | ||||||
|  |  | ||||||
| 	// If there's no code to interact with, respond with an appropriate error | 	// If there's no code to interact with, respond with an appropriate error | ||||||
| 	if contract != nil { | 	if contract != nil { | ||||||
| 		if code := statedb.GetCode(*contract); len(code) == 0 { | 		if code := statedb.GetCode(*contract); len(code) == 0 { | ||||||
|   | |||||||
| @@ -226,8 +226,8 @@ func (ruleSet) IsHomestead(*big.Int) bool { return true } | |||||||
| func (self *VMEnv) RuleSet() vm.RuleSet       { return ruleSet{} } | func (self *VMEnv) RuleSet() vm.RuleSet       { return ruleSet{} } | ||||||
| func (self *VMEnv) Vm() vm.Vm                 { return self.evm } | func (self *VMEnv) Vm() vm.Vm                 { return self.evm } | ||||||
| func (self *VMEnv) Db() vm.Database           { return self.state } | func (self *VMEnv) Db() vm.Database           { return self.state } | ||||||
| func (self *VMEnv) MakeSnapshot() vm.Database  { return self.state.Copy() } | func (self *VMEnv) SnapshotDatabase() int     { return self.state.Snapshot() } | ||||||
| func (self *VMEnv) SetSnapshot(db vm.Database) { self.state.Set(db.(*state.StateDB)) } | func (self *VMEnv) RevertToSnapshot(snap int) { self.state.RevertToSnapshot(snap) } | ||||||
| func (self *VMEnv) Origin() common.Address    { return *self.transactor } | func (self *VMEnv) Origin() common.Address    { return *self.transactor } | ||||||
| func (self *VMEnv) BlockNumber() *big.Int     { return common.Big0 } | func (self *VMEnv) BlockNumber() *big.Int     { return common.Big0 } | ||||||
| func (self *VMEnv) Coinbase() common.Address  { return *self.transactor } | func (self *VMEnv) Coinbase() common.Address  { return *self.transactor } | ||||||
|   | |||||||
| @@ -131,7 +131,7 @@ func (b *BlockGen) AddUncheckedReceipt(receipt *types.Receipt) { | |||||||
| // TxNonce returns the next valid transaction nonce for the | // TxNonce returns the next valid transaction nonce for the | ||||||
| // account at addr. It panics if the account does not exist. | // account at addr. It panics if the account does not exist. | ||||||
| func (b *BlockGen) TxNonce(addr common.Address) uint64 { | func (b *BlockGen) TxNonce(addr common.Address) uint64 { | ||||||
| 	if !b.statedb.HasAccount(addr) { | 	if !b.statedb.Exist(addr) { | ||||||
| 		panic("account does not exist") | 		panic("account does not exist") | ||||||
| 	} | 	} | ||||||
| 	return b.statedb.GetNonce(addr) | 	return b.statedb.GetNonce(addr) | ||||||
|   | |||||||
| @@ -85,7 +85,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A | |||||||
| 		createAccount = true | 		createAccount = true | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	snapshotPreTransfer := env.MakeSnapshot() | 	snapshotPreTransfer := env.SnapshotDatabase() | ||||||
| 	var ( | 	var ( | ||||||
| 		from = env.Db().GetAccount(caller.Address()) | 		from = env.Db().GetAccount(caller.Address()) | ||||||
| 		to   vm.Account | 		to   vm.Account | ||||||
| @@ -129,7 +129,7 @@ func exec(env vm.Environment, caller vm.ContractRef, address, codeAddr *common.A | |||||||
| 	if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { | 	if err != nil && (env.RuleSet().IsHomestead(env.BlockNumber()) || err != vm.CodeStoreOutOfGasError) { | ||||||
| 		contract.UseGas(contract.Gas) | 		contract.UseGas(contract.Gas) | ||||||
|  |  | ||||||
| 		env.SetSnapshot(snapshotPreTransfer) | 		env.RevertToSnapshot(snapshotPreTransfer) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return ret, addr, err | 	return ret, addr, err | ||||||
| @@ -144,7 +144,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA | |||||||
| 		return nil, common.Address{}, vm.DepthError | 		return nil, common.Address{}, vm.DepthError | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	snapshot := env.MakeSnapshot() | 	snapshot := env.SnapshotDatabase() | ||||||
|  |  | ||||||
| 	var to vm.Account | 	var to vm.Account | ||||||
| 	if !env.Db().Exist(*toAddr) { | 	if !env.Db().Exist(*toAddr) { | ||||||
| @@ -162,7 +162,7 @@ func execDelegateCall(env vm.Environment, caller vm.ContractRef, originAddr, toA | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		contract.UseGas(contract.Gas) | 		contract.UseGas(contract.Gas) | ||||||
|  |  | ||||||
| 		env.SetSnapshot(snapshot) | 		env.RevertToSnapshot(snapshot) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return ret, addr, err | 	return ret, addr, err | ||||||
|   | |||||||
| @@ -52,7 +52,7 @@ func (self *StateDB) RawDump() Dump { | |||||||
| 			panic(err) | 			panic(err) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		obj := NewObject(common.BytesToAddress(addr), data, nil) | 		obj := newObject(nil, common.BytesToAddress(addr), data, nil) | ||||||
| 		account := DumpAccount{ | 		account := DumpAccount{ | ||||||
| 			Balance:  data.Balance.String(), | 			Balance:  data.Balance.String(), | ||||||
| 			Nonce:    data.Nonce, | 			Nonce:    data.Nonce, | ||||||
|   | |||||||
							
								
								
									
										117
									
								
								core/state/journal.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										117
									
								
								core/state/journal.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,117 @@ | |||||||
|  | // Copyright 2016 The go-ethereum Authors | ||||||
|  | // This file is part of the go-ethereum library. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Lesser General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||||
|  | // GNU Lesser General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Lesser General Public License | ||||||
|  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | ||||||
|  |  | ||||||
|  | package state | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"math/big" | ||||||
|  |  | ||||||
|  | 	"github.com/ethereum/go-ethereum/common" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type journalEntry interface { | ||||||
|  | 	undo(*StateDB) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type journal []journalEntry | ||||||
|  |  | ||||||
|  | type ( | ||||||
|  | 	// Changes to the account trie. | ||||||
|  | 	createObjectChange struct { | ||||||
|  | 		account *common.Address | ||||||
|  | 	} | ||||||
|  | 	resetObjectChange struct { | ||||||
|  | 		prev *StateObject | ||||||
|  | 	} | ||||||
|  | 	deleteAccountChange struct { | ||||||
|  | 		account     *common.Address | ||||||
|  | 		prev        bool // whether account had already suicided | ||||||
|  | 		prevbalance *big.Int | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Changes to individual accounts. | ||||||
|  | 	balanceChange struct { | ||||||
|  | 		account *common.Address | ||||||
|  | 		prev    *big.Int | ||||||
|  | 	} | ||||||
|  | 	nonceChange struct { | ||||||
|  | 		account *common.Address | ||||||
|  | 		prev    uint64 | ||||||
|  | 	} | ||||||
|  | 	storageChange struct { | ||||||
|  | 		account       *common.Address | ||||||
|  | 		key, prevalue common.Hash | ||||||
|  | 	} | ||||||
|  | 	codeChange struct { | ||||||
|  | 		account            *common.Address | ||||||
|  | 		prevcode, prevhash []byte | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Changes to other state values. | ||||||
|  | 	refundChange struct { | ||||||
|  | 		prev *big.Int | ||||||
|  | 	} | ||||||
|  | 	addLogChange struct { | ||||||
|  | 		txhash common.Hash | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (ch createObjectChange) undo(s *StateDB) { | ||||||
|  | 	s.GetStateObject(*ch.account).deleted = true | ||||||
|  | 	delete(s.stateObjects, *ch.account) | ||||||
|  | 	delete(s.stateObjectsDirty, *ch.account) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch resetObjectChange) undo(s *StateDB) { | ||||||
|  | 	s.setStateObject(ch.prev) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch deleteAccountChange) undo(s *StateDB) { | ||||||
|  | 	obj := s.GetStateObject(*ch.account) | ||||||
|  | 	if obj != nil { | ||||||
|  | 		obj.remove = ch.prev | ||||||
|  | 		obj.setBalance(ch.prevbalance) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch balanceChange) undo(s *StateDB) { | ||||||
|  | 	s.GetStateObject(*ch.account).setBalance(ch.prev) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch nonceChange) undo(s *StateDB) { | ||||||
|  | 	s.GetStateObject(*ch.account).setNonce(ch.prev) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch codeChange) undo(s *StateDB) { | ||||||
|  | 	s.GetStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch storageChange) undo(s *StateDB) { | ||||||
|  | 	s.GetStateObject(*ch.account).setState(ch.key, ch.prevalue) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch refundChange) undo(s *StateDB) { | ||||||
|  | 	s.refund = ch.prev | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (ch addLogChange) undo(s *StateDB) { | ||||||
|  | 	logs := s.logs[ch.txhash] | ||||||
|  | 	if len(logs) == 1 { | ||||||
|  | 		delete(s.logs, ch.txhash) | ||||||
|  | 	} else { | ||||||
|  | 		s.logs[ch.txhash] = logs[:len(logs)-1] | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -29,11 +29,8 @@ func create() (*ManagedState, *account) { | |||||||
| 	db, _ := ethdb.NewMemDatabase() | 	db, _ := ethdb.NewMemDatabase() | ||||||
| 	statedb, _ := New(common.Hash{}, db) | 	statedb, _ := New(common.Hash{}, db) | ||||||
| 	ms := ManageState(statedb) | 	ms := ManageState(statedb) | ||||||
| 	so := &StateObject{address: addr} | 	ms.StateDB.SetNonce(addr, 100) | ||||||
| 	so.SetNonce(100) | 	ms.accounts[addr] = newAccount(ms.StateDB.GetStateObject(addr)) | ||||||
| 	ms.StateDB.stateObjects[addr] = so |  | ||||||
| 	ms.accounts[addr] = newAccount(so) |  | ||||||
|  |  | ||||||
| 	return ms, ms.accounts[addr] | 	return ms, ms.accounts[addr] | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -66,6 +66,7 @@ func (self Storage) Copy() Storage { | |||||||
| type StateObject struct { | type StateObject struct { | ||||||
| 	address common.Address // Ethereum address of this account | 	address common.Address // Ethereum address of this account | ||||||
| 	data    Account | 	data    Account | ||||||
|  | 	db      *StateDB | ||||||
|  |  | ||||||
| 	// DB error. | 	// DB error. | ||||||
| 	// State objects are used by the consensus core and VM which are | 	// State objects are used by the consensus core and VM which are | ||||||
| @@ -99,15 +100,15 @@ type Account struct { | |||||||
| 	CodeHash []byte | 	CodeHash []byte | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewObject creates a state object. | // newObject creates a state object. | ||||||
| func NewObject(address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { | func newObject(db *StateDB, address common.Address, data Account, onDirty func(addr common.Address)) *StateObject { | ||||||
| 	if data.Balance == nil { | 	if data.Balance == nil { | ||||||
| 		data.Balance = new(big.Int) | 		data.Balance = new(big.Int) | ||||||
| 	} | 	} | ||||||
| 	if data.CodeHash == nil { | 	if data.CodeHash == nil { | ||||||
| 		data.CodeHash = emptyCodeHash | 		data.CodeHash = emptyCodeHash | ||||||
| 	} | 	} | ||||||
| 	return &StateObject{address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} | 	return &StateObject{db: db, address: address, data: data, cachedStorage: make(Storage), dirtyStorage: make(Storage), onDirty: onDirty} | ||||||
| } | } | ||||||
|  |  | ||||||
| // EncodeRLP implements rlp.Encoder. | // EncodeRLP implements rlp.Encoder. | ||||||
| @@ -122,7 +123,7 @@ func (self *StateObject) setError(err error) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateObject) MarkForDeletion() { | func (self *StateObject) markForDeletion() { | ||||||
| 	self.remove = true | 	self.remove = true | ||||||
| 	if self.onDirty != nil { | 	if self.onDirty != nil { | ||||||
| 		self.onDirty(self.Address()) | 		self.onDirty(self.Address()) | ||||||
| @@ -163,7 +164,16 @@ func (self *StateObject) GetState(db trie.Database, key common.Hash) common.Hash | |||||||
| } | } | ||||||
|  |  | ||||||
| // SetState updates a value in account storage. | // SetState updates a value in account storage. | ||||||
| func (self *StateObject) SetState(key, value common.Hash) { | func (self *StateObject) SetState(db trie.Database, key, value common.Hash) { | ||||||
|  | 	self.db.journal = append(self.db.journal, storageChange{ | ||||||
|  | 		account:  &self.address, | ||||||
|  | 		key:      key, | ||||||
|  | 		prevalue: self.GetState(db, key), | ||||||
|  | 	}) | ||||||
|  | 	self.setState(key, value) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (self *StateObject) setState(key, value common.Hash) { | ||||||
| 	self.cachedStorage[key] = value | 	self.cachedStorage[key] = value | ||||||
| 	self.dirtyStorage[key] = value | 	self.dirtyStorage[key] = value | ||||||
|  |  | ||||||
| @@ -189,7 +199,7 @@ func (self *StateObject) updateTrie(db trie.Database) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // UpdateRoot sets the trie root to the current root hash of | // UpdateRoot sets the trie root to the current root hash of | ||||||
| func (self *StateObject) UpdateRoot(db trie.Database) { | func (self *StateObject) updateRoot(db trie.Database) { | ||||||
| 	self.updateTrie(db) | 	self.updateTrie(db) | ||||||
| 	self.data.Root = self.trie.Hash() | 	self.data.Root = self.trie.Hash() | ||||||
| } | } | ||||||
| @@ -232,6 +242,14 @@ func (c *StateObject) SubBalance(amount *big.Int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateObject) SetBalance(amount *big.Int) { | func (self *StateObject) SetBalance(amount *big.Int) { | ||||||
|  | 	self.db.journal = append(self.db.journal, balanceChange{ | ||||||
|  | 		account: &self.address, | ||||||
|  | 		prev:    new(big.Int).Set(self.data.Balance), | ||||||
|  | 	}) | ||||||
|  | 	self.setBalance(amount) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (self *StateObject) setBalance(amount *big.Int) { | ||||||
| 	self.data.Balance = amount | 	self.data.Balance = amount | ||||||
| 	if self.onDirty != nil { | 	if self.onDirty != nil { | ||||||
| 		self.onDirty(self.Address()) | 		self.onDirty(self.Address()) | ||||||
| @@ -242,8 +260,8 @@ func (self *StateObject) SetBalance(amount *big.Int) { | |||||||
| // Return the gas back to the origin. Used by the Virtual machine or Closures | // Return the gas back to the origin. Used by the Virtual machine or Closures | ||||||
| func (c *StateObject) ReturnGas(gas, price *big.Int) {} | func (c *StateObject) ReturnGas(gas, price *big.Int) {} | ||||||
|  |  | ||||||
| func (self *StateObject) Copy(db trie.Database, onDirty func(addr common.Address)) *StateObject { | func (self *StateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *StateObject { | ||||||
| 	stateObject := NewObject(self.address, self.data, onDirty) | 	stateObject := newObject(db, self.address, self.data, onDirty) | ||||||
| 	stateObject.trie = self.trie | 	stateObject.trie = self.trie | ||||||
| 	stateObject.code = self.code | 	stateObject.code = self.code | ||||||
| 	stateObject.dirtyStorage = self.dirtyStorage.Copy() | 	stateObject.dirtyStorage = self.dirtyStorage.Copy() | ||||||
| @@ -280,6 +298,16 @@ func (self *StateObject) Code(db trie.Database) []byte { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { | func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { | ||||||
|  | 	prevcode := self.Code(self.db.db) | ||||||
|  | 	self.db.journal = append(self.db.journal, codeChange{ | ||||||
|  | 		account:  &self.address, | ||||||
|  | 		prevhash: self.CodeHash(), | ||||||
|  | 		prevcode: prevcode, | ||||||
|  | 	}) | ||||||
|  | 	self.setCode(codeHash, code) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (self *StateObject) setCode(codeHash common.Hash, code []byte) { | ||||||
| 	self.code = code | 	self.code = code | ||||||
| 	self.data.CodeHash = codeHash[:] | 	self.data.CodeHash = codeHash[:] | ||||||
| 	self.dirtyCode = true | 	self.dirtyCode = true | ||||||
| @@ -290,6 +318,14 @@ func (self *StateObject) SetCode(codeHash common.Hash, code []byte) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateObject) SetNonce(nonce uint64) { | func (self *StateObject) SetNonce(nonce uint64) { | ||||||
|  | 	self.db.journal = append(self.db.journal, nonceChange{ | ||||||
|  | 		account: &self.address, | ||||||
|  | 		prev:    self.data.Nonce, | ||||||
|  | 	}) | ||||||
|  | 	self.setNonce(nonce) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (self *StateObject) setNonce(nonce uint64) { | ||||||
| 	self.data.Nonce = nonce | 	self.data.Nonce = nonce | ||||||
| 	if self.onDirty != nil { | 	if self.onDirty != nil { | ||||||
| 		self.onDirty(self.Address()) | 		self.onDirty(self.Address()) | ||||||
| @@ -322,7 +358,7 @@ func (self *StateObject) ForEachStorage(cb func(key, value common.Hash) bool) { | |||||||
| 		cb(h, value) | 		cb(h, value) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	it := self.trie.Iterator() | 	it := self.getTrie(self.db.db).Iterator() | ||||||
| 	for it.Next() { | 	for it.Next() { | ||||||
| 		// ignore cached values | 		// ignore cached values | ||||||
| 		key := common.BytesToHash(self.trie.GetKey(it.Key)) | 		key := common.BytesToHash(self.trie.GetKey(it.Key)) | ||||||
|   | |||||||
| @@ -46,8 +46,8 @@ func (s *StateSuite) TestDump(c *checker.C) { | |||||||
| 	obj3.SetBalance(big.NewInt(44)) | 	obj3.SetBalance(big.NewInt(44)) | ||||||
|  |  | ||||||
| 	// write some of them to the trie | 	// write some of them to the trie | ||||||
| 	s.state.UpdateStateObject(obj1) | 	s.state.updateStateObject(obj1) | ||||||
| 	s.state.UpdateStateObject(obj2) | 	s.state.updateStateObject(obj2) | ||||||
| 	s.state.Commit() | 	s.state.Commit() | ||||||
|  |  | ||||||
| 	// check that dump contains the state objects that are in trie | 	// check that dump contains the state objects that are in trie | ||||||
| @@ -116,12 +116,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { | |||||||
| 	// set initial state object value | 	// set initial state object value | ||||||
| 	s.state.SetState(stateobjaddr, storageaddr, data1) | 	s.state.SetState(stateobjaddr, storageaddr, data1) | ||||||
| 	// get snapshot of current state | 	// get snapshot of current state | ||||||
| 	snapshot := s.state.Copy() | 	snapshot := s.state.Snapshot() | ||||||
|  |  | ||||||
| 	// set new state object value | 	// set new state object value | ||||||
| 	s.state.SetState(stateobjaddr, storageaddr, data2) | 	s.state.SetState(stateobjaddr, storageaddr, data2) | ||||||
| 	// restore snapshot | 	// restore snapshot | ||||||
| 	s.state.Set(snapshot) | 	s.state.RevertToSnapshot(snapshot) | ||||||
|  |  | ||||||
| 	// get state storage value | 	// get state storage value | ||||||
| 	res := s.state.GetState(stateobjaddr, storageaddr) | 	res := s.state.GetState(stateobjaddr, storageaddr) | ||||||
| @@ -129,6 +129,12 @@ func (s *StateSuite) TestSnapshot(c *checker.C) { | |||||||
| 	c.Assert(data1, checker.DeepEquals, res) | 	c.Assert(data1, checker.DeepEquals, res) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestSnapshotEmpty(t *testing.T) { | ||||||
|  | 	db, _ := ethdb.NewMemDatabase() | ||||||
|  | 	state, _ := New(common.Hash{}, db) | ||||||
|  | 	state.RevertToSnapshot(state.Snapshot()) | ||||||
|  | } | ||||||
|  |  | ||||||
| // use testing instead of checker because checker does not support | // use testing instead of checker because checker does not support | ||||||
| // printing/logging in tests (-check.vv does not work) | // printing/logging in tests (-check.vv does not work) | ||||||
| func TestSnapshot2(t *testing.T) { | func TestSnapshot2(t *testing.T) { | ||||||
| @@ -152,7 +158,7 @@ func TestSnapshot2(t *testing.T) { | |||||||
| 	so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) | 	so0.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e'}), []byte{'c', 'a', 'f', 'e'}) | ||||||
| 	so0.remove = false | 	so0.remove = false | ||||||
| 	so0.deleted = false | 	so0.deleted = false | ||||||
| 	state.SetStateObject(so0) | 	state.setStateObject(so0) | ||||||
|  |  | ||||||
| 	root, _ := state.Commit() | 	root, _ := state.Commit() | ||||||
| 	state.Reset(root) | 	state.Reset(root) | ||||||
| @@ -164,15 +170,15 @@ func TestSnapshot2(t *testing.T) { | |||||||
| 	so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) | 	so1.SetCode(crypto.Keccak256Hash([]byte{'c', 'a', 'f', 'e', '2'}), []byte{'c', 'a', 'f', 'e', '2'}) | ||||||
| 	so1.remove = true | 	so1.remove = true | ||||||
| 	so1.deleted = true | 	so1.deleted = true | ||||||
| 	state.SetStateObject(so1) | 	state.setStateObject(so1) | ||||||
|  |  | ||||||
| 	so1 = state.GetStateObject(stateobjaddr1) | 	so1 = state.GetStateObject(stateobjaddr1) | ||||||
| 	if so1 != nil { | 	if so1 != nil { | ||||||
| 		t.Fatalf("deleted object not nil when getting") | 		t.Fatalf("deleted object not nil when getting") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	snapshot := state.Copy() | 	snapshot := state.Snapshot() | ||||||
| 	state.Set(snapshot) | 	state.RevertToSnapshot(snapshot) | ||||||
|  |  | ||||||
| 	so0Restored := state.GetStateObject(stateobjaddr0) | 	so0Restored := state.GetStateObject(stateobjaddr0) | ||||||
| 	// Update lazily-loaded values before comparing. | 	// Update lazily-loaded values before comparing. | ||||||
|   | |||||||
| @@ -20,6 +20,7 @@ package state | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/big" | 	"math/big" | ||||||
|  | 	"sort" | ||||||
| 	"sync" | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/ethereum/go-ethereum/common" | 	"github.com/ethereum/go-ethereum/common" | ||||||
| @@ -40,12 +41,17 @@ var StartingNonce uint64 | |||||||
| const ( | const ( | ||||||
| 	// Number of past tries to keep. The arbitrarily chosen value here | 	// Number of past tries to keep. The arbitrarily chosen value here | ||||||
| 	// is max uncle depth + 1. | 	// is max uncle depth + 1. | ||||||
| 	maxJournalLength = 8 | 	maxTrieCacheLength = 8 | ||||||
|  |  | ||||||
| 	// Number of codehash->size associations to keep. | 	// Number of codehash->size associations to keep. | ||||||
| 	codeSizeCacheSize = 100000 | 	codeSizeCacheSize = 100000 | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type revision struct { | ||||||
|  | 	id           int | ||||||
|  | 	journalIndex int | ||||||
|  | } | ||||||
|  |  | ||||||
| // StateDBs within the ethereum protocol are used to store anything | // StateDBs within the ethereum protocol are used to store anything | ||||||
| // within the merkle trie. StateDBs take care of caching and storing | // within the merkle trie. StateDBs take care of caching and storing | ||||||
| // nested states. It's the general query interface to retrieve: | // nested states. It's the general query interface to retrieve: | ||||||
| @@ -69,6 +75,12 @@ type StateDB struct { | |||||||
| 	logs         map[common.Hash]vm.Logs | 	logs         map[common.Hash]vm.Logs | ||||||
| 	logSize      uint | 	logSize      uint | ||||||
|  |  | ||||||
|  | 	// Journal of state modifications. This is the backbone of | ||||||
|  | 	// Snapshot and RevertToSnapshot. | ||||||
|  | 	journal        journal | ||||||
|  | 	validRevisions []revision | ||||||
|  | 	nextRevisionId int | ||||||
|  |  | ||||||
| 	lock sync.Mutex | 	lock sync.Mutex | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -124,12 +136,12 @@ func (self *StateDB) Reset(root common.Hash) error { | |||||||
| 	self.trie = tr | 	self.trie = tr | ||||||
| 	self.stateObjects = make(map[common.Address]*StateObject) | 	self.stateObjects = make(map[common.Address]*StateObject) | ||||||
| 	self.stateObjectsDirty = make(map[common.Address]struct{}) | 	self.stateObjectsDirty = make(map[common.Address]struct{}) | ||||||
| 	self.refund = new(big.Int) |  | ||||||
| 	self.thash = common.Hash{} | 	self.thash = common.Hash{} | ||||||
| 	self.bhash = common.Hash{} | 	self.bhash = common.Hash{} | ||||||
| 	self.txIndex = 0 | 	self.txIndex = 0 | ||||||
| 	self.logs = make(map[common.Hash]vm.Logs) | 	self.logs = make(map[common.Hash]vm.Logs) | ||||||
| 	self.logSize = 0 | 	self.logSize = 0 | ||||||
|  | 	self.clearJournalAndRefund() | ||||||
|  |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| @@ -150,7 +162,7 @@ func (self *StateDB) pushTrie(t *trie.SecureTrie) { | |||||||
| 	self.lock.Lock() | 	self.lock.Lock() | ||||||
| 	defer self.lock.Unlock() | 	defer self.lock.Unlock() | ||||||
|  |  | ||||||
| 	if len(self.pastTries) >= maxJournalLength { | 	if len(self.pastTries) >= maxTrieCacheLength { | ||||||
| 		copy(self.pastTries, self.pastTries[1:]) | 		copy(self.pastTries, self.pastTries[1:]) | ||||||
| 		self.pastTries[len(self.pastTries)-1] = t | 		self.pastTries[len(self.pastTries)-1] = t | ||||||
| 	} else { | 	} else { | ||||||
| @@ -165,6 +177,8 @@ func (self *StateDB) StartRecord(thash, bhash common.Hash, ti int) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateDB) AddLog(log *vm.Log) { | func (self *StateDB) AddLog(log *vm.Log) { | ||||||
|  | 	self.journal = append(self.journal, addLogChange{txhash: self.thash}) | ||||||
|  |  | ||||||
| 	log.TxHash = self.thash | 	log.TxHash = self.thash | ||||||
| 	log.BlockHash = self.bhash | 	log.BlockHash = self.bhash | ||||||
| 	log.TxIndex = uint(self.txIndex) | 	log.TxIndex = uint(self.txIndex) | ||||||
| @@ -186,13 +200,12 @@ func (self *StateDB) Logs() vm.Logs { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateDB) AddRefund(gas *big.Int) { | func (self *StateDB) AddRefund(gas *big.Int) { | ||||||
|  | 	self.journal = append(self.journal, refundChange{prev: new(big.Int).Set(self.refund)}) | ||||||
| 	self.refund.Add(self.refund, gas) | 	self.refund.Add(self.refund, gas) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateDB) HasAccount(addr common.Address) bool { | // Exist reports whether the given account address exists in the state. | ||||||
| 	return self.GetStateObject(addr) != nil | // Notably this also returns true for suicided accounts. | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *StateDB) Exist(addr common.Address) bool { | func (self *StateDB) Exist(addr common.Address) bool { | ||||||
| 	return self.GetStateObject(addr) != nil | 	return self.GetStateObject(addr) != nil | ||||||
| } | } | ||||||
| @@ -207,7 +220,6 @@ func (self *StateDB) GetBalance(addr common.Address) *big.Int { | |||||||
| 	if stateObject != nil { | 	if stateObject != nil { | ||||||
| 		return stateObject.Balance() | 		return stateObject.Balance() | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return common.Big0 | 	return common.Big0 | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -282,6 +294,13 @@ func (self *StateDB) AddBalance(addr common.Address, amount *big.Int) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (self *StateDB) SetBalance(addr common.Address, amount *big.Int) { | ||||||
|  | 	stateObject := self.GetOrNewStateObject(addr) | ||||||
|  | 	if stateObject != nil { | ||||||
|  | 		stateObject.SetBalance(amount) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { | func (self *StateDB) SetNonce(addr common.Address, nonce uint64) { | ||||||
| 	stateObject := self.GetOrNewStateObject(addr) | 	stateObject := self.GetOrNewStateObject(addr) | ||||||
| 	if stateObject != nil { | 	if stateObject != nil { | ||||||
| @@ -299,27 +318,36 @@ func (self *StateDB) SetCode(addr common.Address, code []byte) { | |||||||
| func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { | func (self *StateDB) SetState(addr common.Address, key common.Hash, value common.Hash) { | ||||||
| 	stateObject := self.GetOrNewStateObject(addr) | 	stateObject := self.GetOrNewStateObject(addr) | ||||||
| 	if stateObject != nil { | 	if stateObject != nil { | ||||||
| 		stateObject.SetState(key, value) | 		stateObject.SetState(self.db, key, value) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Delete marks the given account as suicided. | ||||||
|  | // This clears the account balance. | ||||||
|  | // | ||||||
|  | // The account's state object is still available until the state is committed, | ||||||
|  | // GetStateObject will return a non-nil account after Delete. | ||||||
| func (self *StateDB) Delete(addr common.Address) bool { | func (self *StateDB) Delete(addr common.Address) bool { | ||||||
| 	stateObject := self.GetStateObject(addr) | 	stateObject := self.GetStateObject(addr) | ||||||
| 	if stateObject != nil { | 	if stateObject == nil { | ||||||
| 		stateObject.MarkForDeletion() | 		return false | ||||||
|  | 	} | ||||||
|  | 	self.journal = append(self.journal, deleteAccountChange{ | ||||||
|  | 		account:     &addr, | ||||||
|  | 		prev:        stateObject.remove, | ||||||
|  | 		prevbalance: new(big.Int).Set(stateObject.Balance()), | ||||||
|  | 	}) | ||||||
|  | 	stateObject.markForDeletion() | ||||||
| 	stateObject.data.Balance = new(big.Int) | 	stateObject.data.Balance = new(big.Int) | ||||||
| 	return true | 	return true | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return false |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // | // | ||||||
| // Setting, updating & deleting state object methods | // Setting, updating & deleting state object methods | ||||||
| // | // | ||||||
|  |  | ||||||
| // Update the given state object and apply it to state trie | // updateStateObject writes the given object to the trie. | ||||||
| func (self *StateDB) UpdateStateObject(stateObject *StateObject) { | func (self *StateDB) updateStateObject(stateObject *StateObject) { | ||||||
| 	addr := stateObject.Address() | 	addr := stateObject.Address() | ||||||
| 	data, err := rlp.EncodeToBytes(stateObject) | 	data, err := rlp.EncodeToBytes(stateObject) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -328,10 +356,9 @@ func (self *StateDB) UpdateStateObject(stateObject *StateObject) { | |||||||
| 	self.trie.Update(addr[:], data) | 	self.trie.Update(addr[:], data) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Delete the given state object and delete it from the state trie | // deleteStateObject removes the given object from the state trie. | ||||||
| func (self *StateDB) DeleteStateObject(stateObject *StateObject) { | func (self *StateDB) deleteStateObject(stateObject *StateObject) { | ||||||
| 	stateObject.deleted = true | 	stateObject.deleted = true | ||||||
|  |  | ||||||
| 	addr := stateObject.Address() | 	addr := stateObject.Address() | ||||||
| 	self.trie.Delete(addr[:]) | 	self.trie.Delete(addr[:]) | ||||||
| } | } | ||||||
| @@ -357,12 +384,12 @@ func (self *StateDB) GetStateObject(addr common.Address) (stateObject *StateObje | |||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
| 	// Insert into the live set. | 	// Insert into the live set. | ||||||
| 	obj := NewObject(addr, data, self.MarkStateObjectDirty) | 	obj := newObject(self, addr, data, self.MarkStateObjectDirty) | ||||||
| 	self.SetStateObject(obj) | 	self.setStateObject(obj) | ||||||
| 	return obj | 	return obj | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateDB) SetStateObject(object *StateObject) { | func (self *StateDB) setStateObject(object *StateObject) { | ||||||
| 	self.stateObjects[object.Address()] = object | 	self.stateObjects[object.Address()] = object | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -370,52 +397,55 @@ func (self *StateDB) SetStateObject(object *StateObject) { | |||||||
| func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { | func (self *StateDB) GetOrNewStateObject(addr common.Address) *StateObject { | ||||||
| 	stateObject := self.GetStateObject(addr) | 	stateObject := self.GetStateObject(addr) | ||||||
| 	if stateObject == nil || stateObject.deleted { | 	if stateObject == nil || stateObject.deleted { | ||||||
| 		stateObject = self.CreateStateObject(addr) | 		stateObject, _ = self.createObject(addr) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return stateObject | 	return stateObject | ||||||
| } | } | ||||||
|  |  | ||||||
| // NewStateObject create a state object whether it exist in the trie or not |  | ||||||
| func (self *StateDB) newStateObject(addr common.Address) *StateObject { |  | ||||||
| 	if glog.V(logger.Core) { |  | ||||||
| 		glog.Infof("(+) %x\n", addr) |  | ||||||
| 	} |  | ||||||
| 	obj := NewObject(addr, Account{}, self.MarkStateObjectDirty) |  | ||||||
| 	obj.SetNonce(StartingNonce) // sets the object to dirty |  | ||||||
| 	self.stateObjects[addr] = obj |  | ||||||
| 	return obj |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly | // MarkStateObjectDirty adds the specified object to the dirty map to avoid costly | ||||||
| // state object cache iteration to find a handful of modified ones. | // state object cache iteration to find a handful of modified ones. | ||||||
| func (self *StateDB) MarkStateObjectDirty(addr common.Address) { | func (self *StateDB) MarkStateObjectDirty(addr common.Address) { | ||||||
| 	self.stateObjectsDirty[addr] = struct{}{} | 	self.stateObjectsDirty[addr] = struct{}{} | ||||||
| } | } | ||||||
|  |  | ||||||
| // Creates creates a new state object and takes ownership. | // createObject creates a new state object. If there is an existing account with | ||||||
| func (self *StateDB) CreateStateObject(addr common.Address) *StateObject { | // the given address, it is overwritten and returned as the second return value. | ||||||
| 	// Get previous (if any) | func (self *StateDB) createObject(addr common.Address) (newobj, prev *StateObject) { | ||||||
| 	so := self.GetStateObject(addr) | 	prev = self.GetStateObject(addr) | ||||||
| 	// Create a new one | 	newobj = newObject(self, addr, Account{}, self.MarkStateObjectDirty) | ||||||
| 	newSo := self.newStateObject(addr) | 	newobj.setNonce(StartingNonce) // sets the object to dirty | ||||||
|  | 	if prev == nil { | ||||||
| 	// If it existed set the balance to the new account | 		if glog.V(logger.Core) { | ||||||
| 	if so != nil { | 			glog.Infof("(+) %x\n", addr) | ||||||
| 		newSo.data.Balance = so.data.Balance |  | ||||||
| 		} | 		} | ||||||
|  | 		self.journal = append(self.journal, createObjectChange{account: &addr}) | ||||||
| 	return newSo | 	} else { | ||||||
|  | 		self.journal = append(self.journal, resetObjectChange{prev: prev}) | ||||||
|  | 	} | ||||||
|  | 	self.setStateObject(newobj) | ||||||
|  | 	return newobj, prev | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // CreateAccount explicitly creates a state object. If a state object with the address | ||||||
|  | // already exists the balance is carried over to the new account. | ||||||
|  | // | ||||||
|  | // CreateAccount is called during the EVM CREATE operation. The situation might arise that | ||||||
|  | // a contract does the following: | ||||||
|  | // | ||||||
|  | //   1. sends funds to sha(account ++ (nonce + 1)) | ||||||
|  | //   2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) | ||||||
|  | // | ||||||
|  | // Carrying over the balance ensures that Ether doesn't disappear. | ||||||
| func (self *StateDB) CreateAccount(addr common.Address) vm.Account { | func (self *StateDB) CreateAccount(addr common.Address) vm.Account { | ||||||
| 	return self.CreateStateObject(addr) | 	new, prev := self.createObject(addr) | ||||||
|  | 	if prev != nil { | ||||||
|  | 		new.setBalance(prev.data.Balance) | ||||||
|  | 	} | ||||||
|  | 	return new | ||||||
| } | } | ||||||
|  |  | ||||||
| // | // Copy creates a deep, independent copy of the state. | ||||||
| // Setting, copying of the state methods | // Snapshots of the copied state cannot be applied to the copy. | ||||||
| // |  | ||||||
|  |  | ||||||
| func (self *StateDB) Copy() *StateDB { | func (self *StateDB) Copy() *StateDB { | ||||||
| 	self.lock.Lock() | 	self.lock.Lock() | ||||||
| 	defer self.lock.Unlock() | 	defer self.lock.Unlock() | ||||||
| @@ -434,7 +464,7 @@ func (self *StateDB) Copy() *StateDB { | |||||||
| 	} | 	} | ||||||
| 	// Copy the dirty states and logs | 	// Copy the dirty states and logs | ||||||
| 	for addr, _ := range self.stateObjectsDirty { | 	for addr, _ := range self.stateObjectsDirty { | ||||||
| 		state.stateObjects[addr] = self.stateObjects[addr].Copy(self.db, state.MarkStateObjectDirty) | 		state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) | ||||||
| 		state.stateObjectsDirty[addr] = struct{}{} | 		state.stateObjectsDirty[addr] = struct{}{} | ||||||
| 	} | 	} | ||||||
| 	for hash, logs := range self.logs { | 	for hash, logs := range self.logs { | ||||||
| @@ -444,21 +474,38 @@ func (self *StateDB) Copy() *StateDB { | |||||||
| 	return state | 	return state | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateDB) Set(state *StateDB) { | // Snapshot returns an identifier for the current revision of the state. | ||||||
| 	self.lock.Lock() | func (self *StateDB) Snapshot() int { | ||||||
| 	defer self.lock.Unlock() | 	id := self.nextRevisionId | ||||||
|  | 	self.nextRevisionId++ | ||||||
| 	self.db = state.db | 	self.validRevisions = append(self.validRevisions, revision{id, len(self.journal)}) | ||||||
| 	self.trie = state.trie | 	return id | ||||||
| 	self.pastTries = state.pastTries |  | ||||||
| 	self.stateObjects = state.stateObjects |  | ||||||
| 	self.stateObjectsDirty = state.stateObjectsDirty |  | ||||||
| 	self.codeSizeCache = state.codeSizeCache |  | ||||||
| 	self.refund = state.refund |  | ||||||
| 	self.logs = state.logs |  | ||||||
| 	self.logSize = state.logSize |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // RevertToSnapshot reverts all state changes made since the given revision. | ||||||
|  | func (self *StateDB) RevertToSnapshot(revid int) { | ||||||
|  | 	// Find the snapshot in the stack of valid snapshots. | ||||||
|  | 	idx := sort.Search(len(self.validRevisions), func(i int) bool { | ||||||
|  | 		return self.validRevisions[i].id >= revid | ||||||
|  | 	}) | ||||||
|  | 	if idx == len(self.validRevisions) || self.validRevisions[idx].id != revid { | ||||||
|  | 		panic(fmt.Errorf("revision id %v cannot be reverted", revid)) | ||||||
|  | 	} | ||||||
|  | 	snapshot := self.validRevisions[idx].journalIndex | ||||||
|  |  | ||||||
|  | 	// Replay the journal to undo changes. | ||||||
|  | 	for i := len(self.journal) - 1; i >= snapshot; i-- { | ||||||
|  | 		self.journal[i].undo(self) | ||||||
|  | 	} | ||||||
|  | 	self.journal = self.journal[:snapshot] | ||||||
|  |  | ||||||
|  | 	// Remove invalidated snapshots from the stack. | ||||||
|  | 	self.validRevisions = self.validRevisions[:idx] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GetRefund returns the current value of the refund counter. | ||||||
|  | // The return value must not be modified by the caller and will become | ||||||
|  | // invalid at the next call to AddRefund. | ||||||
| func (self *StateDB) GetRefund() *big.Int { | func (self *StateDB) GetRefund() *big.Int { | ||||||
| 	return self.refund | 	return self.refund | ||||||
| } | } | ||||||
| @@ -467,16 +514,17 @@ func (self *StateDB) GetRefund() *big.Int { | |||||||
| // It is called in between transactions to get the root hash that | // It is called in between transactions to get the root hash that | ||||||
| // goes into transaction receipts. | // goes into transaction receipts. | ||||||
| func (s *StateDB) IntermediateRoot() common.Hash { | func (s *StateDB) IntermediateRoot() common.Hash { | ||||||
| 	s.refund = new(big.Int) |  | ||||||
| 	for addr, _ := range s.stateObjectsDirty { | 	for addr, _ := range s.stateObjectsDirty { | ||||||
| 		stateObject := s.stateObjects[addr] | 		stateObject := s.stateObjects[addr] | ||||||
| 		if stateObject.remove { | 		if stateObject.remove { | ||||||
| 			s.DeleteStateObject(stateObject) | 			s.deleteStateObject(stateObject) | ||||||
| 		} else { | 		} else { | ||||||
| 			stateObject.UpdateRoot(s.db) | 			stateObject.updateRoot(s.db) | ||||||
| 			s.UpdateStateObject(stateObject) | 			s.updateStateObject(stateObject) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 	// Invalidate journal because reverting across transactions is not allowed. | ||||||
|  | 	s.clearJournalAndRefund() | ||||||
| 	return s.trie.Hash() | 	return s.trie.Hash() | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -486,9 +534,9 @@ func (s *StateDB) IntermediateRoot() common.Hash { | |||||||
| // DeleteSuicides should not be used for consensus related updates | // DeleteSuicides should not be used for consensus related updates | ||||||
| // under any circumstances. | // under any circumstances. | ||||||
| func (s *StateDB) DeleteSuicides() { | func (s *StateDB) DeleteSuicides() { | ||||||
| 	// Reset refund so that any used-gas calculations can use | 	// Reset refund so that any used-gas calculations can use this method. | ||||||
| 	// this method. | 	s.clearJournalAndRefund() | ||||||
| 	s.refund = new(big.Int) |  | ||||||
| 	for addr, _ := range s.stateObjectsDirty { | 	for addr, _ := range s.stateObjectsDirty { | ||||||
| 		stateObject := s.stateObjects[addr] | 		stateObject := s.stateObjects[addr] | ||||||
|  |  | ||||||
| @@ -516,15 +564,21 @@ func (s *StateDB) CommitBatch() (root common.Hash, batch ethdb.Batch) { | |||||||
| 	return root, batch | 	return root, batch | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { | func (s *StateDB) clearJournalAndRefund() { | ||||||
|  | 	s.journal = nil | ||||||
|  | 	s.validRevisions = s.validRevisions[:0] | ||||||
| 	s.refund = new(big.Int) | 	s.refund = new(big.Int) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) { | ||||||
|  | 	defer s.clearJournalAndRefund() | ||||||
|  |  | ||||||
| 	// Commit objects to the trie. | 	// Commit objects to the trie. | ||||||
| 	for addr, stateObject := range s.stateObjects { | 	for addr, stateObject := range s.stateObjects { | ||||||
| 		if stateObject.remove { | 		if stateObject.remove { | ||||||
| 			// If the object has been removed, don't bother syncing it | 			// If the object has been removed, don't bother syncing it | ||||||
| 			// and just mark it for deletion in the trie. | 			// and just mark it for deletion in the trie. | ||||||
| 			s.DeleteStateObject(stateObject) | 			s.deleteStateObject(stateObject) | ||||||
| 		} else if _, ok := s.stateObjectsDirty[addr]; ok { | 		} else if _, ok := s.stateObjectsDirty[addr]; ok { | ||||||
| 			// Write any contract code associated with the state object | 			// Write any contract code associated with the state object | ||||||
| 			if stateObject.code != nil && stateObject.dirtyCode { | 			if stateObject.code != nil && stateObject.dirtyCode { | ||||||
| @@ -538,7 +592,7 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) | |||||||
| 				return common.Hash{}, err | 				return common.Hash{}, err | ||||||
| 			} | 			} | ||||||
| 			// Update the object in the main account trie. | 			// Update the object in the main account trie. | ||||||
| 			s.UpdateStateObject(stateObject) | 			s.updateStateObject(stateObject) | ||||||
| 		} | 		} | ||||||
| 		delete(s.stateObjectsDirty, addr) | 		delete(s.stateObjectsDirty, addr) | ||||||
| 	} | 	} | ||||||
| @@ -549,7 +603,3 @@ func (s *StateDB) commit(dbw trie.DatabaseWriter) (root common.Hash, err error) | |||||||
| 	} | 	} | ||||||
| 	return root, err | 	return root, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *StateDB) Refunds() *big.Int { |  | ||||||
| 	return self.refund |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -17,11 +17,19 @@ | |||||||
| package state | package state | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/binary" | ||||||
|  | 	"fmt" | ||||||
|  | 	"math" | ||||||
| 	"math/big" | 	"math/big" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"testing/quick" | ||||||
|  |  | ||||||
| 	"github.com/ethereum/go-ethereum/common" | 	"github.com/ethereum/go-ethereum/common" | ||||||
| 	"github.com/ethereum/go-ethereum/crypto" | 	"github.com/ethereum/go-ethereum/core/vm" | ||||||
| 	"github.com/ethereum/go-ethereum/ethdb" | 	"github.com/ethereum/go-ethereum/ethdb" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -34,16 +42,16 @@ func TestUpdateLeaks(t *testing.T) { | |||||||
|  |  | ||||||
| 	// Update it with some accounts | 	// Update it with some accounts | ||||||
| 	for i := byte(0); i < 255; i++ { | 	for i := byte(0); i < 255; i++ { | ||||||
| 		obj := state.GetOrNewStateObject(common.BytesToAddress([]byte{i})) | 		addr := common.BytesToAddress([]byte{i}) | ||||||
| 		obj.AddBalance(big.NewInt(int64(11 * i))) | 		state.AddBalance(addr, big.NewInt(int64(11*i))) | ||||||
| 		obj.SetNonce(uint64(42 * i)) | 		state.SetNonce(addr, uint64(42*i)) | ||||||
| 		if i%2 == 0 { | 		if i%2 == 0 { | ||||||
| 			obj.SetState(common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) | 			state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) | ||||||
| 		} | 		} | ||||||
| 		if i%3 == 0 { | 		if i%3 == 0 { | ||||||
| 			obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) | 			state.SetCode(addr, []byte{i, i, i, i, i}) | ||||||
| 		} | 		} | ||||||
| 		state.UpdateStateObject(obj) | 		state.IntermediateRoot() | ||||||
| 	} | 	} | ||||||
| 	// Ensure that no data was leaked into the database | 	// Ensure that no data was leaked into the database | ||||||
| 	for _, key := range db.Keys() { | 	for _, key := range db.Keys() { | ||||||
| @@ -61,51 +69,38 @@ func TestIntermediateLeaks(t *testing.T) { | |||||||
| 	transState, _ := New(common.Hash{}, transDb) | 	transState, _ := New(common.Hash{}, transDb) | ||||||
| 	finalState, _ := New(common.Hash{}, finalDb) | 	finalState, _ := New(common.Hash{}, finalDb) | ||||||
|  |  | ||||||
| 	// Update the states with some objects | 	modify := func(state *StateDB, addr common.Address, i, tweak byte) { | ||||||
|  | 		state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) | ||||||
|  | 		state.SetNonce(addr, uint64(42*i+tweak)) | ||||||
|  | 		if i%2 == 0 { | ||||||
|  | 			state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{}) | ||||||
|  | 			state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak}) | ||||||
|  | 		} | ||||||
|  | 		if i%3 == 0 { | ||||||
|  | 			state.SetCode(addr, []byte{i, i, i, i, i, tweak}) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Modify the transient state. | ||||||
| 	for i := byte(0); i < 255; i++ { | 	for i := byte(0); i < 255; i++ { | ||||||
| 		// Create a new state object with some data into the transition database | 		modify(transState, common.Address{byte(i)}, i, 0) | ||||||
| 		obj := transState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) |  | ||||||
| 		obj.SetBalance(big.NewInt(int64(11 * i))) |  | ||||||
| 		obj.SetNonce(uint64(42 * i)) |  | ||||||
| 		if i%2 == 0 { |  | ||||||
| 			obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.BytesToHash([]byte{i, i, i, i, 0})) |  | ||||||
| 	} | 	} | ||||||
| 		if i%3 == 0 { | 	// Write modifications to trie. | ||||||
| 			obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 0}), []byte{i, i, i, i, i, 0}) | 	transState.IntermediateRoot() | ||||||
| 		} |  | ||||||
| 		transState.UpdateStateObject(obj) |  | ||||||
|  |  | ||||||
| 		// Overwrite all the data with new values in the transition database | 	// Overwrite all the data with new values in the transient database. | ||||||
| 		obj.SetBalance(big.NewInt(int64(11*i + 1))) | 	for i := byte(0); i < 255; i++ { | ||||||
| 		obj.SetNonce(uint64(42*i + 1)) | 		modify(transState, common.Address{byte(i)}, i, 99) | ||||||
| 		if i%2 == 0 { | 		modify(finalState, common.Address{byte(i)}, i, 99) | ||||||
| 			obj.SetState(common.BytesToHash([]byte{i, i, i, 0}), common.Hash{}) |  | ||||||
| 			obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) |  | ||||||
| 	} | 	} | ||||||
| 		if i%3 == 0 { |  | ||||||
| 			obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) |  | ||||||
| 		} |  | ||||||
| 		transState.UpdateStateObject(obj) |  | ||||||
|  |  | ||||||
| 		// Create the final state object directly in the final database | 	// Commit and cross check the databases. | ||||||
| 		obj = finalState.GetOrNewStateObject(common.BytesToAddress([]byte{i})) |  | ||||||
| 		obj.SetBalance(big.NewInt(int64(11*i + 1))) |  | ||||||
| 		obj.SetNonce(uint64(42*i + 1)) |  | ||||||
| 		if i%2 == 0 { |  | ||||||
| 			obj.SetState(common.BytesToHash([]byte{i, i, i, 1}), common.BytesToHash([]byte{i, i, i, i, 1})) |  | ||||||
| 		} |  | ||||||
| 		if i%3 == 0 { |  | ||||||
| 			obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i, 1}), []byte{i, i, i, i, i, 1}) |  | ||||||
| 		} |  | ||||||
| 		finalState.UpdateStateObject(obj) |  | ||||||
| 	} |  | ||||||
| 	if _, err := transState.Commit(); err != nil { | 	if _, err := transState.Commit(); err != nil { | ||||||
| 		t.Fatalf("failed to commit transition state: %v", err) | 		t.Fatalf("failed to commit transition state: %v", err) | ||||||
| 	} | 	} | ||||||
| 	if _, err := finalState.Commit(); err != nil { | 	if _, err := finalState.Commit(); err != nil { | ||||||
| 		t.Fatalf("failed to commit final state: %v", err) | 		t.Fatalf("failed to commit final state: %v", err) | ||||||
| 	} | 	} | ||||||
| 	// Cross check the databases to ensure they are the same |  | ||||||
| 	for _, key := range finalDb.Keys() { | 	for _, key := range finalDb.Keys() { | ||||||
| 		if _, err := transDb.Get(key); err != nil { | 		if _, err := transDb.Get(key); err != nil { | ||||||
| 			val, _ := finalDb.Get(key) | 			val, _ := finalDb.Get(key) | ||||||
| @@ -119,3 +114,243 @@ func TestIntermediateLeaks(t *testing.T) { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestSnapshotRandom(t *testing.T) { | ||||||
|  | 	config := &quick.Config{MaxCount: 1000} | ||||||
|  | 	err := quick.Check((*snapshotTest).run, config) | ||||||
|  | 	if cerr, ok := err.(*quick.CheckError); ok { | ||||||
|  | 		test := cerr.In[0].(*snapshotTest) | ||||||
|  | 		t.Errorf("%v:\n%s", test.err, test) | ||||||
|  | 	} else if err != nil { | ||||||
|  | 		t.Error(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // A snapshotTest checks that reverting StateDB snapshots properly undoes all changes | ||||||
|  | // captured by the snapshot. Instances of this test with pseudorandom content are created | ||||||
|  | // by Generate. | ||||||
|  | // | ||||||
|  | // The test works as follows: | ||||||
|  | // | ||||||
|  | // A new state is created and all actions are applied to it. Several snapshots are taken | ||||||
|  | // in between actions. The test then reverts each snapshot. For each snapshot the actions | ||||||
|  | // leading up to it are replayed on a fresh, empty state. The behaviour of all public | ||||||
|  | // accessor methods on the reverted state must match the return value of the equivalent | ||||||
|  | // methods on the replayed state. | ||||||
|  | type snapshotTest struct { | ||||||
|  | 	addrs     []common.Address // all account addresses | ||||||
|  | 	actions   []testAction     // modifications to the state | ||||||
|  | 	snapshots []int            // actions indexes at which snapshot is taken | ||||||
|  | 	err       error            // failure details are reported through this field | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type testAction struct { | ||||||
|  | 	name   string | ||||||
|  | 	fn     func(testAction, *StateDB) | ||||||
|  | 	args   []int64 | ||||||
|  | 	noAddr bool | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // newTestAction creates a random action that changes state. | ||||||
|  | func newTestAction(addr common.Address, r *rand.Rand) testAction { | ||||||
|  | 	actions := []testAction{ | ||||||
|  | 		{ | ||||||
|  | 			name: "SetBalance", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				s.SetBalance(addr, big.NewInt(a.args[0])) | ||||||
|  | 			}, | ||||||
|  | 			args: make([]int64, 1), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "AddBalance", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				s.AddBalance(addr, big.NewInt(a.args[0])) | ||||||
|  | 			}, | ||||||
|  | 			args: make([]int64, 1), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "SetNonce", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				s.SetNonce(addr, uint64(a.args[0])) | ||||||
|  | 			}, | ||||||
|  | 			args: make([]int64, 1), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "SetState", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				var key, val common.Hash | ||||||
|  | 				binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) | ||||||
|  | 				binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) | ||||||
|  | 				s.SetState(addr, key, val) | ||||||
|  | 			}, | ||||||
|  | 			args: make([]int64, 2), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "SetCode", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				code := make([]byte, 16) | ||||||
|  | 				binary.BigEndian.PutUint64(code, uint64(a.args[0])) | ||||||
|  | 				binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) | ||||||
|  | 				s.SetCode(addr, code) | ||||||
|  | 			}, | ||||||
|  | 			args: make([]int64, 2), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "CreateAccount", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				s.CreateAccount(addr) | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "Delete", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				s.Delete(addr) | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "AddRefund", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				s.AddRefund(big.NewInt(a.args[0])) | ||||||
|  | 			}, | ||||||
|  | 			args:   make([]int64, 1), | ||||||
|  | 			noAddr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "AddLog", | ||||||
|  | 			fn: func(a testAction, s *StateDB) { | ||||||
|  | 				data := make([]byte, 2) | ||||||
|  | 				binary.BigEndian.PutUint16(data, uint16(a.args[0])) | ||||||
|  | 				s.AddLog(&vm.Log{Address: addr, Data: data}) | ||||||
|  | 			}, | ||||||
|  | 			args: make([]int64, 1), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	action := actions[r.Intn(len(actions))] | ||||||
|  | 	var nameargs []string | ||||||
|  | 	if !action.noAddr { | ||||||
|  | 		nameargs = append(nameargs, addr.Hex()) | ||||||
|  | 	} | ||||||
|  | 	for _, i := range action.args { | ||||||
|  | 		action.args[i] = rand.Int63n(100) | ||||||
|  | 		nameargs = append(nameargs, fmt.Sprint(action.args[i])) | ||||||
|  | 	} | ||||||
|  | 	action.name += strings.Join(nameargs, ", ") | ||||||
|  | 	return action | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Generate returns a new snapshot test of the given size. All randomness is | ||||||
|  | // derived from r. | ||||||
|  | func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { | ||||||
|  | 	// Generate random actions. | ||||||
|  | 	addrs := make([]common.Address, 50) | ||||||
|  | 	for i := range addrs { | ||||||
|  | 		addrs[i][0] = byte(i) | ||||||
|  | 	} | ||||||
|  | 	actions := make([]testAction, size) | ||||||
|  | 	for i := range actions { | ||||||
|  | 		addr := addrs[r.Intn(len(addrs))] | ||||||
|  | 		actions[i] = newTestAction(addr, r) | ||||||
|  | 	} | ||||||
|  | 	// Generate snapshot indexes. | ||||||
|  | 	nsnapshots := int(math.Sqrt(float64(size))) | ||||||
|  | 	if size > 0 && nsnapshots == 0 { | ||||||
|  | 		nsnapshots = 1 | ||||||
|  | 	} | ||||||
|  | 	snapshots := make([]int, nsnapshots) | ||||||
|  | 	snaplen := len(actions) / nsnapshots | ||||||
|  | 	for i := range snapshots { | ||||||
|  | 		// Try to place the snapshots some number of actions apart from each other. | ||||||
|  | 		snapshots[i] = (i * snaplen) + r.Intn(snaplen) | ||||||
|  | 	} | ||||||
|  | 	return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (test *snapshotTest) String() string { | ||||||
|  | 	out := new(bytes.Buffer) | ||||||
|  | 	sindex := 0 | ||||||
|  | 	for i, action := range test.actions { | ||||||
|  | 		if len(test.snapshots) > sindex && i == test.snapshots[sindex] { | ||||||
|  | 			fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) | ||||||
|  | 			sindex++ | ||||||
|  | 		} | ||||||
|  | 		fmt.Fprintf(out, "%4d: %s\n", i, action.name) | ||||||
|  | 	} | ||||||
|  | 	return out.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (test *snapshotTest) run() bool { | ||||||
|  | 	// Run all actions and create snapshots. | ||||||
|  | 	var ( | ||||||
|  | 		db, _        = ethdb.NewMemDatabase() | ||||||
|  | 		state, _     = New(common.Hash{}, db) | ||||||
|  | 		snapshotRevs = make([]int, len(test.snapshots)) | ||||||
|  | 		sindex       = 0 | ||||||
|  | 	) | ||||||
|  | 	for i, action := range test.actions { | ||||||
|  | 		if len(test.snapshots) > sindex && i == test.snapshots[sindex] { | ||||||
|  | 			snapshotRevs[sindex] = state.Snapshot() | ||||||
|  | 			sindex++ | ||||||
|  | 		} | ||||||
|  | 		action.fn(action, state) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Revert all snapshots in reverse order. Each revert must yield a state | ||||||
|  | 	// that is equivalent to fresh state with all actions up the snapshot applied. | ||||||
|  | 	for sindex--; sindex >= 0; sindex-- { | ||||||
|  | 		checkstate, _ := New(common.Hash{}, db) | ||||||
|  | 		for _, action := range test.actions[:test.snapshots[sindex]] { | ||||||
|  | 			action.fn(action, checkstate) | ||||||
|  | 		} | ||||||
|  | 		state.RevertToSnapshot(snapshotRevs[sindex]) | ||||||
|  | 		if err := test.checkEqual(state, checkstate); err != nil { | ||||||
|  | 			test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) | ||||||
|  | 			return false | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // checkEqual checks that methods of state and checkstate return the same values. | ||||||
|  | func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { | ||||||
|  | 	for _, addr := range test.addrs { | ||||||
|  | 		var err error | ||||||
|  | 		checkeq := func(op string, a, b interface{}) bool { | ||||||
|  | 			if err == nil && !reflect.DeepEqual(a, b) { | ||||||
|  | 				err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) | ||||||
|  | 				return false | ||||||
|  | 			} | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 		// Check basic accessor methods. | ||||||
|  | 		checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) | ||||||
|  | 		checkeq("IsDeleted", state.IsDeleted(addr), checkstate.IsDeleted(addr)) | ||||||
|  | 		checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) | ||||||
|  | 		checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) | ||||||
|  | 		checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) | ||||||
|  | 		checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) | ||||||
|  | 		checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) | ||||||
|  | 		// Check storage. | ||||||
|  | 		if obj := state.GetStateObject(addr); obj != nil { | ||||||
|  | 			obj.ForEachStorage(func(key, val common.Hash) bool { | ||||||
|  | 				return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key)) | ||||||
|  | 			}) | ||||||
|  | 			checkobj := checkstate.GetStateObject(addr) | ||||||
|  | 			checkobj.ForEachStorage(func(key, checkval common.Hash) bool { | ||||||
|  | 				return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) | ||||||
|  | 			}) | ||||||
|  | 		} | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 { | ||||||
|  | 		return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", | ||||||
|  | 			state.GetRefund(), checkstate.GetRefund()) | ||||||
|  | 	} | ||||||
|  | 	if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) { | ||||||
|  | 		return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", | ||||||
|  | 			state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|   | |||||||
| @@ -57,7 +57,7 @@ func makeTestState() (ethdb.Database, common.Hash, []*testAccount) { | |||||||
| 			obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) | 			obj.SetCode(crypto.Keccak256Hash([]byte{i, i, i, i, i}), []byte{i, i, i, i, i}) | ||||||
| 			acc.code = []byte{i, i, i, i, i} | 			acc.code = []byte{i, i, i, i, i} | ||||||
| 		} | 		} | ||||||
| 		state.UpdateStateObject(obj) | 		state.updateStateObject(obj) | ||||||
| 		accounts = append(accounts, acc) | 		accounts = append(accounts, acc) | ||||||
| 	} | 	} | ||||||
| 	root, _ := state.Commit() | 	root, _ := state.Commit() | ||||||
|   | |||||||
| @@ -240,7 +240,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction) error { | |||||||
|  |  | ||||||
| 	// Make sure the account exist. Non existent accounts | 	// Make sure the account exist. Non existent accounts | ||||||
| 	// haven't got funds and well therefor never pass. | 	// haven't got funds and well therefor never pass. | ||||||
| 	if !currentState.HasAccount(from) { | 	if !currentState.Exist(from) { | ||||||
| 		return ErrNonExistentAccount | 		return ErrNonExistentAccount | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -36,9 +36,9 @@ type Environment interface { | |||||||
| 	// The state database | 	// The state database | ||||||
| 	Db() Database | 	Db() Database | ||||||
| 	// Creates a restorable snapshot | 	// Creates a restorable snapshot | ||||||
| 	MakeSnapshot() Database | 	SnapshotDatabase() int | ||||||
| 	// Set database to previous snapshot | 	// Set database to previous snapshot | ||||||
| 	SetSnapshot(Database) | 	RevertToSnapshot(int) | ||||||
| 	// Address of the original invoker (first occurrence of the VM invoker) | 	// Address of the original invoker (first occurrence of the VM invoker) | ||||||
| 	Origin() common.Address | 	Origin() common.Address | ||||||
| 	// The block number this VM is invoked on | 	// The block number this VM is invoked on | ||||||
|   | |||||||
| @@ -187,8 +187,8 @@ func (self *Env) StructLogs() []StructLog { | |||||||
|  |  | ||||||
| //func (self *Env) PrevHash() []byte      { return self.parent } | //func (self *Env) PrevHash() []byte      { return self.parent } | ||||||
| func (self *Env) Coinbase() common.Address { return common.Address{} } | func (self *Env) Coinbase() common.Address { return common.Address{} } | ||||||
| func (self *Env) MakeSnapshot() Database   { return nil } | func (self *Env) SnapshotDatabase() int    { return 0 } | ||||||
| func (self *Env) SetSnapshot(Database)     {} | func (self *Env) RevertToSnapshot(int)     {} | ||||||
| func (self *Env) Time() *big.Int           { return big.NewInt(time.Now().Unix()) } | func (self *Env) Time() *big.Int           { return big.NewInt(time.Now().Unix()) } | ||||||
| func (self *Env) Difficulty() *big.Int     { return big.NewInt(0) } | func (self *Env) Difficulty() *big.Int     { return big.NewInt(0) } | ||||||
| func (self *Env) Db() Database             { return nil } | func (self *Env) Db() Database             { return nil } | ||||||
|   | |||||||
| @@ -100,11 +100,11 @@ func (self *Env) SetDepth(i int) { self.depth = i } | |||||||
| func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { | func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { | ||||||
| 	return self.state.GetBalance(from).Cmp(balance) >= 0 | 	return self.state.GetBalance(from).Cmp(balance) >= 0 | ||||||
| } | } | ||||||
| func (self *Env) MakeSnapshot() vm.Database { | func (self *Env) SnapshotDatabase() int { | ||||||
| 	return self.state.Copy() | 	return self.state.Snapshot() | ||||||
| } | } | ||||||
| func (self *Env) SetSnapshot(copy vm.Database) { | func (self *Env) RevertToSnapshot(snapshot int) { | ||||||
| 	self.state.Set(copy.(*state.StateDB)) | 	self.state.RevertToSnapshot(snapshot) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { | func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { | ||||||
|   | |||||||
| @@ -95,12 +95,12 @@ func (self *VMEnv) CanTransfer(from common.Address, balance *big.Int) bool { | |||||||
| 	return self.state.GetBalance(from).Cmp(balance) >= 0 | 	return self.state.GetBalance(from).Cmp(balance) >= 0 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *VMEnv) MakeSnapshot() vm.Database { | func (self *VMEnv) SnapshotDatabase() int { | ||||||
| 	return self.state.Copy() | 	return self.state.Snapshot() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *VMEnv) SetSnapshot(copy vm.Database) { | func (self *VMEnv) RevertToSnapshot(snapshot int) { | ||||||
| 	self.state.Set(copy.(*state.StateDB)) | 	self.state.RevertToSnapshot(snapshot) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { | func (self *VMEnv) Transfer(from, to vm.Account, amount *big.Int) { | ||||||
|   | |||||||
| @@ -23,7 +23,6 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/ethereum/go-ethereum/common" | 	"github.com/ethereum/go-ethereum/common" | ||||||
| 	"github.com/ethereum/go-ethereum/core/state" | 	"github.com/ethereum/go-ethereum/core/state" | ||||||
| 	"github.com/ethereum/go-ethereum/crypto" |  | ||||||
| 	"github.com/ethereum/go-ethereum/ethdb" | 	"github.com/ethereum/go-ethereum/ethdb" | ||||||
| 	"github.com/ethereum/go-ethereum/trie" | 	"github.com/ethereum/go-ethereum/trie" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| @@ -54,16 +53,13 @@ func makeTestState() (common.Hash, ethdb.Database) { | |||||||
| 	sdb, _ := ethdb.NewMemDatabase() | 	sdb, _ := ethdb.NewMemDatabase() | ||||||
| 	st, _ := state.New(common.Hash{}, sdb) | 	st, _ := state.New(common.Hash{}, sdb) | ||||||
| 	for i := byte(0); i < 100; i++ { | 	for i := byte(0); i < 100; i++ { | ||||||
| 		so := st.GetOrNewStateObject(common.Address{i}) | 		addr := common.Address{i} | ||||||
| 		for j := byte(0); j < 100; j++ { | 		for j := byte(0); j < 100; j++ { | ||||||
| 			val := common.Hash{i, j} | 			st.SetState(addr, common.Hash{j}, common.Hash{i, j}) | ||||||
| 			so.SetState(common.Hash{j}, val) |  | ||||||
| 			so.SetNonce(100) |  | ||||||
| 		} | 		} | ||||||
| 		so.AddBalance(big.NewInt(int64(i))) | 		st.SetNonce(addr, 100) | ||||||
| 		so.SetCode(crypto.Keccak256Hash([]byte{i, i, i}), []byte{i, i, i}) | 		st.AddBalance(addr, big.NewInt(int64(i))) | ||||||
| 		so.UpdateRoot(sdb) | 		st.SetCode(addr, []byte{i, i, i}) | ||||||
| 		st.UpdateStateObject(so) |  | ||||||
| 	} | 	} | ||||||
| 	root, _ := st.Commit() | 	root, _ := st.Commit() | ||||||
| 	return root, sdb | 	return root, sdb | ||||||
|   | |||||||
| @@ -173,7 +173,7 @@ func (self *worker) pending() (*types.Block, *state.StateDB) { | |||||||
| 			self.current.receipts, | 			self.current.receipts, | ||||||
| 		), self.current.state | 		), self.current.state | ||||||
| 	} | 	} | ||||||
| 	return self.current.Block, self.current.state | 	return self.current.Block, self.current.state.Copy() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *worker) start() { | func (self *worker) start() { | ||||||
| @@ -665,7 +665,7 @@ func (env *Work) commitTransactions(mux *event.TypeMux, transactions types.Trans | |||||||
| } | } | ||||||
|  |  | ||||||
| func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { | func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, gp *core.GasPool) (error, vm.Logs) { | ||||||
| 	snap := env.state.Copy() | 	snap := env.state.Snapshot() | ||||||
|  |  | ||||||
| 	// this is a bit of a hack to force jit for the miners | 	// this is a bit of a hack to force jit for the miners | ||||||
| 	config := env.config.VmConfig | 	config := env.config.VmConfig | ||||||
| @@ -676,7 +676,7 @@ func (env *Work) commitTransaction(tx *types.Transaction, bc *core.BlockChain, g | |||||||
|  |  | ||||||
| 	receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) | 	receipt, logs, _, err := core.ApplyTransaction(env.config, bc, gp, env.state, env.header, tx, env.header.GasUsed, config) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		env.state.Set(snap) | 		env.state.RevertToSnapshot(snap) | ||||||
| 		return err, nil | 		return err, nil | ||||||
| 	} | 	} | ||||||
| 	env.txs = append(env.txs, tx) | 	env.txs = append(env.txs, tx) | ||||||
|   | |||||||
| @@ -96,14 +96,7 @@ func BenchStateTest(ruleSet RuleSet, p string, conf bconf, b *testing.B) error { | |||||||
| func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { | func benchStateTest(ruleSet RuleSet, test VmTest, env map[string]string, b *testing.B) { | ||||||
| 	b.StopTimer() | 	b.StopTimer() | ||||||
| 	db, _ := ethdb.NewMemDatabase() | 	db, _ := ethdb.NewMemDatabase() | ||||||
| 	statedb, _ := state.New(common.Hash{}, db) | 	statedb := makePreState(db, test.Pre) | ||||||
| 	for addr, account := range test.Pre { |  | ||||||
| 		obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) |  | ||||||
| 		statedb.SetStateObject(obj) |  | ||||||
| 		for a, v := range account.Storage { |  | ||||||
| 			obj.SetState(common.HexToHash(a), common.HexToHash(v)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	b.StartTimer() | 	b.StartTimer() | ||||||
|  |  | ||||||
| 	RunState(ruleSet, statedb, env, test.Exec) | 	RunState(ruleSet, statedb, env, test.Exec) | ||||||
| @@ -135,14 +128,7 @@ func runStateTests(ruleSet RuleSet, tests map[string]VmTest, skipTests []string) | |||||||
|  |  | ||||||
| func runStateTest(ruleSet RuleSet, test VmTest) error { | func runStateTest(ruleSet RuleSet, test VmTest) error { | ||||||
| 	db, _ := ethdb.NewMemDatabase() | 	db, _ := ethdb.NewMemDatabase() | ||||||
| 	statedb, _ := state.New(common.Hash{}, db) | 	statedb := makePreState(db, test.Pre) | ||||||
| 	for addr, account := range test.Pre { |  | ||||||
| 		obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) |  | ||||||
| 		statedb.SetStateObject(obj) |  | ||||||
| 		for a, v := range account.Storage { |  | ||||||
| 			obj.SetState(common.HexToHash(a), common.HexToHash(v)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// XXX Yeah, yeah... | 	// XXX Yeah, yeah... | ||||||
| 	env := make(map[string]string) | 	env := make(map[string]string) | ||||||
| @@ -234,7 +220,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string | |||||||
| 	} | 	} | ||||||
| 	// Set pre compiled contracts | 	// Set pre compiled contracts | ||||||
| 	vm.Precompiled = vm.PrecompiledContracts() | 	vm.Precompiled = vm.PrecompiledContracts() | ||||||
| 	snapshot := statedb.Copy() | 	snapshot := statedb.Snapshot() | ||||||
| 	gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) | 	gaspool := new(core.GasPool).AddGas(common.Big(env["currentGasLimit"])) | ||||||
|  |  | ||||||
| 	key, _ := hex.DecodeString(tx["secretKey"]) | 	key, _ := hex.DecodeString(tx["secretKey"]) | ||||||
| @@ -244,7 +230,7 @@ func RunState(ruleSet RuleSet, statedb *state.StateDB, env, tx map[string]string | |||||||
| 	vmenv.origin = addr | 	vmenv.origin = addr | ||||||
| 	ret, _, err := core.ApplyMessage(vmenv, message, gaspool) | 	ret, _, err := core.ApplyMessage(vmenv, message, gaspool) | ||||||
| 	if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { | 	if core.IsNonceErr(err) || core.IsInvalidTxErr(err) || core.IsGasLimitErr(err) { | ||||||
| 		statedb.Set(snapshot) | 		statedb.RevertToSnapshot(snapshot) | ||||||
| 	} | 	} | ||||||
| 	statedb.Commit() | 	statedb.Commit() | ||||||
|  |  | ||||||
|   | |||||||
| @@ -103,19 +103,25 @@ func (self Log) Topics() [][]byte { | |||||||
| 	return t | 	return t | ||||||
| } | } | ||||||
|  |  | ||||||
| func StateObjectFromAccount(db ethdb.Database, addr string, account Account, onDirty func(common.Address)) *state.StateObject { | func makePreState(db ethdb.Database, accounts map[string]Account) *state.StateDB { | ||||||
|  | 	statedb, _ := state.New(common.Hash{}, db) | ||||||
|  | 	for addr, account := range accounts { | ||||||
|  | 		insertAccount(statedb, addr, account) | ||||||
|  | 	} | ||||||
|  | 	return statedb | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func insertAccount(state *state.StateDB, saddr string, account Account) { | ||||||
| 	if common.IsHex(account.Code) { | 	if common.IsHex(account.Code) { | ||||||
| 		account.Code = account.Code[2:] | 		account.Code = account.Code[2:] | ||||||
| 	} | 	} | ||||||
| 	code := common.Hex2Bytes(account.Code) | 	addr := common.HexToAddress(saddr) | ||||||
| 	codeHash := crypto.Keccak256Hash(code) | 	state.SetCode(addr, common.Hex2Bytes(account.Code)) | ||||||
| 	obj := state.NewObject(common.HexToAddress(addr), state.Account{ | 	state.SetNonce(addr, common.Big(account.Nonce).Uint64()) | ||||||
| 		Balance:  common.Big(account.Balance), | 	state.SetBalance(addr, common.Big(account.Balance)) | ||||||
| 		CodeHash: codeHash[:], | 	for a, v := range account.Storage { | ||||||
| 		Nonce:    common.Big(account.Nonce).Uint64(), | 		state.SetState(addr, common.HexToHash(a), common.HexToHash(v)) | ||||||
| 	}, onDirty) | 	} | ||||||
| 	obj.SetCode(codeHash, code) |  | ||||||
| 	return obj |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type VmEnv struct { | type VmEnv struct { | ||||||
| @@ -239,11 +245,11 @@ func (self *Env) CanTransfer(from common.Address, balance *big.Int) bool { | |||||||
|  |  | ||||||
| 	return self.state.GetBalance(from).Cmp(balance) >= 0 | 	return self.state.GetBalance(from).Cmp(balance) >= 0 | ||||||
| } | } | ||||||
| func (self *Env) MakeSnapshot() vm.Database { | func (self *Env) SnapshotDatabase() int { | ||||||
| 	return self.state.Copy() | 	return self.state.Snapshot() | ||||||
| } | } | ||||||
| func (self *Env) SetSnapshot(copy vm.Database) { | func (self *Env) RevertToSnapshot(snapshot int) { | ||||||
| 	self.state.Set(copy.(*state.StateDB)) | 	self.state.RevertToSnapshot(snapshot) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { | func (self *Env) Transfer(from, to vm.Account, amount *big.Int) { | ||||||
|   | |||||||
| @@ -101,14 +101,7 @@ func BenchVmTest(p string, conf bconf, b *testing.B) error { | |||||||
| func benchVmTest(test VmTest, env map[string]string, b *testing.B) { | func benchVmTest(test VmTest, env map[string]string, b *testing.B) { | ||||||
| 	b.StopTimer() | 	b.StopTimer() | ||||||
| 	db, _ := ethdb.NewMemDatabase() | 	db, _ := ethdb.NewMemDatabase() | ||||||
| 	statedb, _ := state.New(common.Hash{}, db) | 	statedb := makePreState(db, test.Pre) | ||||||
| 	for addr, account := range test.Pre { |  | ||||||
| 		obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) |  | ||||||
| 		statedb.SetStateObject(obj) |  | ||||||
| 		for a, v := range account.Storage { |  | ||||||
| 			obj.SetState(common.HexToHash(a), common.HexToHash(v)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	b.StartTimer() | 	b.StartTimer() | ||||||
|  |  | ||||||
| 	RunVm(statedb, env, test.Exec) | 	RunVm(statedb, env, test.Exec) | ||||||
| @@ -152,14 +145,7 @@ func runVmTests(tests map[string]VmTest, skipTests []string) error { | |||||||
|  |  | ||||||
| func runVmTest(test VmTest) error { | func runVmTest(test VmTest) error { | ||||||
| 	db, _ := ethdb.NewMemDatabase() | 	db, _ := ethdb.NewMemDatabase() | ||||||
| 	statedb, _ := state.New(common.Hash{}, db) | 	statedb := makePreState(db, test.Pre) | ||||||
| 	for addr, account := range test.Pre { |  | ||||||
| 		obj := StateObjectFromAccount(db, addr, account, statedb.MarkStateObjectDirty) |  | ||||||
| 		statedb.SetStateObject(obj) |  | ||||||
| 		for a, v := range account.Storage { |  | ||||||
| 			obj.SetState(common.HexToHash(a), common.HexToHash(v)) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// XXX Yeah, yeah... | 	// XXX Yeah, yeah... | ||||||
| 	env := make(map[string]string) | 	env := make(map[string]string) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user