p2p/testing: check for all expectations in TestExchanges

Handle all expectations in ProtocolSession.TestExchanges in any
order that are received.
This commit is contained in:
Janos Guljas
2018-02-12 18:40:25 +01:00
parent 407339085f
commit e07603bbc4
2 changed files with 205 additions and 62 deletions

View File

@@ -24,7 +24,11 @@ that can be used to send and receive messages
package testing
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"strings"
"sync"
"testing"
@@ -34,6 +38,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
)
@@ -152,7 +157,7 @@ type mockNode struct {
testNode
trigger chan *Trigger
expect chan *Expect
expect chan []Expect
err chan error
stop chan struct{}
stopOnce sync.Once
@@ -161,7 +166,7 @@ type mockNode struct {
func newMockNode() *mockNode {
mock := &mockNode{
trigger: make(chan *Trigger),
expect: make(chan *Expect),
expect: make(chan []Expect),
err: make(chan error),
stop: make(chan struct{}),
}
@@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
select {
case trig := <-m.trigger:
m.err <- p2p.Send(rw, trig.Code, trig.Msg)
case exp := <-m.expect:
m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg)
case exps := <-m.expect:
m.err <- expectMsgs(rw, exps)
case <-m.stop:
return nil
}
@@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error {
return <-m.err
}
func (m *mockNode) Expect(exp *Expect) error {
func (m *mockNode) Expect(exp ...Expect) error {
m.expect <- exp
return <-m.err
}
@@ -198,3 +203,67 @@ func (m *mockNode) Stop() error {
m.stopOnce.Do(func() { close(m.stop) })
return nil
}
func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
matched := make([]bool, len(exps))
for {
msg, err := rw.ReadMsg()
if err != nil {
if err == io.EOF {
break
}
return err
}
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err
}
var found bool
for i, exp := range exps {
if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
if matched[i] {
return fmt.Errorf("message #%d received two times", i)
}
matched[i] = true
found = true
break
}
}
if !found {
expected := make([]string, 0)
for i, exp := range exps {
if matched[i] {
continue
}
expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
}
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
}
done := true
for _, m := range matched {
if !m {
done = false
break
}
}
if done {
return nil
}
}
for i, m := range matched {
if !m {
return fmt.Errorf("expected message #%d not received", i)
}
}
return nil
}
// mustEncodeMsg uses rlp to encode a message.
// In case of error it panics.
func mustEncodeMsg(msg interface{}) []byte {
contentEnc, err := rlp.EncodeToBytes(msg)
if err != nil {
panic("content encode error: " + err.Error())
}
return contentEnc
}