p2p/testing: check for all expectations in TestExchanges
Handle all expectations in ProtocolSession.TestExchanges in any order that are received.
This commit is contained in:
@@ -24,7 +24,11 @@ that can be used to send and receive messages
|
||||
package testing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
@@ -34,6 +38,7 @@ import (
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
"github.com/ethereum/go-ethereum/p2p/simulations"
|
||||
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
"github.com/ethereum/go-ethereum/rpc"
|
||||
)
|
||||
|
||||
@@ -152,7 +157,7 @@ type mockNode struct {
|
||||
testNode
|
||||
|
||||
trigger chan *Trigger
|
||||
expect chan *Expect
|
||||
expect chan []Expect
|
||||
err chan error
|
||||
stop chan struct{}
|
||||
stopOnce sync.Once
|
||||
@@ -161,7 +166,7 @@ type mockNode struct {
|
||||
func newMockNode() *mockNode {
|
||||
mock := &mockNode{
|
||||
trigger: make(chan *Trigger),
|
||||
expect: make(chan *Expect),
|
||||
expect: make(chan []Expect),
|
||||
err: make(chan error),
|
||||
stop: make(chan struct{}),
|
||||
}
|
||||
@@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
|
||||
select {
|
||||
case trig := <-m.trigger:
|
||||
m.err <- p2p.Send(rw, trig.Code, trig.Msg)
|
||||
case exp := <-m.expect:
|
||||
m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg)
|
||||
case exps := <-m.expect:
|
||||
m.err <- expectMsgs(rw, exps)
|
||||
case <-m.stop:
|
||||
return nil
|
||||
}
|
||||
@@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error {
|
||||
return <-m.err
|
||||
}
|
||||
|
||||
func (m *mockNode) Expect(exp *Expect) error {
|
||||
func (m *mockNode) Expect(exp ...Expect) error {
|
||||
m.expect <- exp
|
||||
return <-m.err
|
||||
}
|
||||
@@ -198,3 +203,67 @@ func (m *mockNode) Stop() error {
|
||||
m.stopOnce.Do(func() { close(m.stop) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
|
||||
matched := make([]bool, len(exps))
|
||||
for {
|
||||
msg, err := rw.ReadMsg()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
actualContent, err := ioutil.ReadAll(msg.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var found bool
|
||||
for i, exp := range exps {
|
||||
if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
|
||||
if matched[i] {
|
||||
return fmt.Errorf("message #%d received two times", i)
|
||||
}
|
||||
matched[i] = true
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
expected := make([]string, 0)
|
||||
for i, exp := range exps {
|
||||
if matched[i] {
|
||||
continue
|
||||
}
|
||||
expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
|
||||
}
|
||||
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
|
||||
}
|
||||
done := true
|
||||
for _, m := range matched {
|
||||
if !m {
|
||||
done = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if done {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
for i, m := range matched {
|
||||
if !m {
|
||||
return fmt.Errorf("expected message #%d not received", i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mustEncodeMsg uses rlp to encode a message.
|
||||
// In case of error it panics.
|
||||
func mustEncodeMsg(msg interface{}) []byte {
|
||||
contentEnc, err := rlp.EncodeToBytes(msg)
|
||||
if err != nil {
|
||||
panic("content encode error: " + err.Error())
|
||||
}
|
||||
return contentEnc
|
||||
}
|
||||
|
Reference in New Issue
Block a user