From 4515f84958874c44cd826a0a141b70867758cb0d Mon Sep 17 00:00:00 2001 From: rjl493456442 Date: Wed, 3 Jun 2020 14:10:41 +0800 Subject: [PATCH] accounts/abi/bind/backends: unify simulatedBackend and RPC --- accounts/abi/bind/backends/simulated.go | 39 +++++++++++++++----- accounts/abi/bind/backends/simulated_test.go | 32 +++++++++++----- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 78e14cc09d..cadf095efe 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -344,6 +344,21 @@ func (b *SimulatedBackend) PendingCodeAt(ctx context.Context, contract common.Ad return b.pendingState.GetCode(contract), nil } +type revertError struct { + error + errData interface{} // additional data +} + +func (e revertError) ErrorCode() int { + // revert errors are execution errors. + // See: https://github.com/ethereum/wiki/wiki/JSON-RPC-Error-Codes-Improvement-Proposal + return 3 +} + +func (e revertError) ErrorData() interface{} { + return e.errData +} + // CallContract executes a contract call. func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallMsg, blockNumber *big.Int) ([]byte, error) { b.mu.Lock() @@ -364,7 +379,10 @@ func (b *SimulatedBackend) CallContract(ctx context.Context, call ethereum.CallM if len(res.Revert()) > 0 { reason, err := abi.UnpackRevert(res.Revert()) if err == nil { - return nil, fmt.Errorf("execution reverted: %v", reason) + return nil, &revertError{ + error: errors.New("execution reverted"), + errData: reason, + } } } return res.Return(), res.Err @@ -384,7 +402,10 @@ func (b *SimulatedBackend) PendingCallContract(ctx context.Context, call ethereu if len(res.Revert()) > 0 { reason, err := abi.UnpackRevert(res.Revert()) if err == nil { - return nil, fmt.Errorf("execution reverted: %v", reason) + return nil, &revertError{ + error: errors.New("execution reverted"), + errData: reason, + } } } return res.Return(), res.Err @@ -486,16 +507,16 @@ func (b *SimulatedBackend) EstimateGas(ctx context.Context, call ethereum.CallMs } if failed { if result != nil && result.Err != vm.ErrOutOfGas { - errMsg := fmt.Sprintf("always failing transaction (%v)", result.Err) if len(result.Revert()) > 0 { - ret, err := abi.UnpackRevert(result.Revert()) - if err != nil { - errMsg += fmt.Sprintf(" (%#x)", result.Revert()) - } else { - errMsg += fmt.Sprintf(" (%s)", ret) + reason, err := abi.UnpackRevert(result.Revert()) + if err == nil { + return 0, &revertError{ + error: errors.New("execution reverted"), + errData: reason, + } } } - return 0, errors.New(errMsg) + return 0, result.Err } // Otherwise, the specified gas cap is too low return 0, fmt.Errorf("gas required exceeds allowance (%d)", cap) diff --git a/accounts/abi/bind/backends/simulated_test.go b/accounts/abi/bind/backends/simulated_test.go index 9631bbb6e8..2d246d353e 100644 --- a/accounts/abi/bind/backends/simulated_test.go +++ b/accounts/abi/bind/backends/simulated_test.go @@ -20,8 +20,8 @@ import ( "bytes" "context" "errors" - "fmt" "math/big" + "reflect" "strings" "testing" "time" @@ -388,6 +388,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { message ethereum.CallMsg expect uint64 expectError error + expectData interface{} }{ {"plain transfer(valid)", ethereum.CallMsg{ From: addr, @@ -396,7 +397,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: big.NewInt(1), Data: nil, - }, params.TxGas, nil}, + }, params.TxGas, nil, nil}, {"plain transfer(invalid)", ethereum.CallMsg{ From: addr, @@ -405,7 +406,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: big.NewInt(1), Data: nil, - }, 0, errors.New("always failing transaction (execution reverted)")}, + }, 0, errors.New("execution reverted"), nil}, {"Revert", ethereum.CallMsg{ From: addr, @@ -414,7 +415,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: nil, Data: common.Hex2Bytes("d8b98391"), - }, 0, errors.New("always failing transaction (execution reverted) (revert reason)")}, + }, 0, errors.New("execution reverted"), "revert reason"}, {"PureRevert", ethereum.CallMsg{ From: addr, @@ -423,7 +424,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: nil, Data: common.Hex2Bytes("aa8b1d30"), - }, 0, errors.New("always failing transaction (execution reverted)")}, + }, 0, errors.New("execution reverted"), nil}, {"OOG", ethereum.CallMsg{ From: addr, @@ -432,7 +433,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: nil, Data: common.Hex2Bytes("50f6fe34"), - }, 0, errors.New("gas required exceeds allowance (100000)")}, + }, 0, errors.New("gas required exceeds allowance (100000)"), nil}, {"Assert", ethereum.CallMsg{ From: addr, @@ -441,7 +442,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: nil, Data: common.Hex2Bytes("b9b046f9"), - }, 0, errors.New("always failing transaction (invalid opcode: opcode 0xfe not defined)")}, + }, 0, errors.New("invalid opcode: opcode 0xfe not defined"), nil}, {"Valid", ethereum.CallMsg{ From: addr, @@ -450,7 +451,7 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { GasPrice: big.NewInt(0), Value: nil, Data: common.Hex2Bytes("e09fface"), - }, 21275, nil}, + }, 21275, nil, nil}, } for _, c := range cases { got, err := sim.EstimateGas(context.Background(), c.message) @@ -461,6 +462,13 @@ func TestSimulatedBackend_EstimateGas(t *testing.T) { if c.expectError.Error() != err.Error() { t.Fatalf("Expect error, want %v, got %v", c.expectError, err) } + if c.expectData != nil { + if rerr, ok := err.(*revertError); !ok { + t.Fatalf("Expect revert error, got %T", err) + } else if !reflect.DeepEqual(rerr.ErrorData(), c.expectData) { + t.Fatalf("Error data mismatch, want %v, got %v", c.expectData, rerr.ErrorData()) + } + } continue } if got != c.expect { @@ -1041,8 +1049,12 @@ func TestSimulatedBackend_CallContractRevert(t *testing.T) { t.Errorf("result from %v was not nil: %v", key, res) } if val != nil { - if err.Error() != fmt.Sprintf("execution reverted: %v", val) { - t.Errorf("error was malformed: got %v want %v", err, fmt.Errorf("execution reverted: %v", val)) + rerr, ok := err.(*revertError) + if !ok { + t.Errorf("expect revert error") + } + if !reflect.DeepEqual(rerr.ErrorData(), val) { + t.Errorf("error was malformed: got %v want %v", rerr.ErrorData(), val) } } else { // revert(0x0,0x0)