p2p/testing: check for all expectations in TestExchanges
Handle all expectations in ProtocolSession.TestExchanges in any order that are received.
This commit is contained in:
		| @@ -19,13 +19,17 @@ package testing | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/ethereum/go-ethereum/log" | ||||
| 	"github.com/ethereum/go-ethereum/p2p" | ||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||
| 	"github.com/ethereum/go-ethereum/p2p/simulations/adapters" | ||||
| ) | ||||
|  | ||||
| var errTimedOut = errors.New("timed out") | ||||
|  | ||||
| // ProtocolSession is a quasi simulation of a pivot node running | ||||
| // a service and a number of dummy peers that can send (trigger) or | ||||
| // receive (expect) messages | ||||
| @@ -46,6 +50,7 @@ type Exchange struct { | ||||
| 	Label    string | ||||
| 	Triggers []Trigger | ||||
| 	Expects  []Expect | ||||
| 	Timeout  time.Duration | ||||
| } | ||||
|  | ||||
| // Trigger is part of the exchange, incoming message for the pivot node | ||||
| @@ -102,78 +107,147 @@ func (self *ProtocolSession) trigger(trig Trigger) error { | ||||
| } | ||||
|  | ||||
| // expect checks an expectation of a message sent out by the pivot node | ||||
| func (self *ProtocolSession) expect(exp Expect) error { | ||||
| 	if exp.Msg == nil { | ||||
| 		return errors.New("no message to expect") | ||||
| 	} | ||||
| 	simNode, ok := self.adapter.GetNode(exp.Peer) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs)) | ||||
| 	} | ||||
| 	mockNode, ok := simNode.Services()[0].(*mockNode) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer) | ||||
| func (self *ProtocolSession) expect(exps []Expect) error { | ||||
| 	// construct a map of expectations for each node | ||||
| 	peerExpects := make(map[discover.NodeID][]Expect) | ||||
| 	for _, exp := range exps { | ||||
| 		if exp.Msg == nil { | ||||
| 			return errors.New("no message to expect") | ||||
| 		} | ||||
| 		peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) | ||||
| 	} | ||||
|  | ||||
| 	// construct a map of mockNodes for each node | ||||
| 	mockNodes := make(map[discover.NodeID]*mockNode) | ||||
| 	for nodeID := range peerExpects { | ||||
| 		simNode, ok := self.adapter.GetNode(nodeID) | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs)) | ||||
| 		} | ||||
| 		mockNode, ok := simNode.Services()[0].(*mockNode) | ||||
| 		if !ok { | ||||
| 			return fmt.Errorf("trigger: peer %v is not a mock", nodeID) | ||||
| 		} | ||||
| 		mockNodes[nodeID] = mockNode | ||||
| 	} | ||||
|  | ||||
| 	// done chanell cancels all created goroutines when function returns | ||||
| 	done := make(chan struct{}) | ||||
| 	defer close(done) | ||||
| 	// errc catches the first error from | ||||
| 	errc := make(chan error) | ||||
|  | ||||
| 	wg := &sync.WaitGroup{} | ||||
| 	wg.Add(len(mockNodes)) | ||||
| 	for nodeID, mockNode := range mockNodes { | ||||
| 		nodeID := nodeID | ||||
| 		mockNode := mockNode | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
|  | ||||
| 			// Sum all Expect timeouts to give the maximum | ||||
| 			// time for all expectations to finish. | ||||
| 			// mockNode.Expect checks all received messages against | ||||
| 			// a list of expected messages and timeout for each | ||||
| 			// of them can not be checked separately. | ||||
| 			var t time.Duration | ||||
| 			for _, exp := range peerExpects[nodeID] { | ||||
| 				if exp.Timeout == time.Duration(0) { | ||||
| 					t += 2000 * time.Millisecond | ||||
| 				} else { | ||||
| 					t += exp.Timeout | ||||
| 				} | ||||
| 			} | ||||
| 			alarm := time.NewTimer(t) | ||||
| 			defer alarm.Stop() | ||||
|  | ||||
| 			// expectErrc is used to check if error returned | ||||
| 			// from mockNode.Expect is not nil and to send it to | ||||
| 			// errc only in that case. | ||||
| 			// done channel will be closed when function | ||||
| 			expectErrc := make(chan error) | ||||
| 			go func() { | ||||
| 				select { | ||||
| 				case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): | ||||
| 				case <-done: | ||||
| 				case <-alarm.C: | ||||
| 				} | ||||
| 			}() | ||||
|  | ||||
| 			select { | ||||
| 			case err := <-expectErrc: | ||||
| 				if err != nil { | ||||
| 					select { | ||||
| 					case errc <- err: | ||||
| 					case <-done: | ||||
| 					case <-alarm.C: | ||||
| 						errc <- errTimedOut | ||||
| 					} | ||||
| 				} | ||||
| 			case <-done: | ||||
| 			case <-alarm.C: | ||||
| 				errc <- errTimedOut | ||||
| 			} | ||||
|  | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		errc <- mockNode.Expect(&exp) | ||||
| 		wg.Wait() | ||||
| 		// close errc when all goroutines finish to return nill err from errc | ||||
| 		close(errc) | ||||
| 	}() | ||||
|  | ||||
| 	t := exp.Timeout | ||||
| 	if t == time.Duration(0) { | ||||
| 		t = 2000 * time.Millisecond | ||||
| 	} | ||||
| 	select { | ||||
| 	case err := <-errc: | ||||
| 		return err | ||||
| 	case <-time.After(t): | ||||
| 		return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer) | ||||
| 	} | ||||
| 	return <-errc | ||||
| } | ||||
|  | ||||
| // TestExchanges tests a series of exchanges against the session | ||||
| func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error { | ||||
| 	// launch all triggers of this exchanges | ||||
|  | ||||
| 	for _, e := range exchanges { | ||||
| 		errc := make(chan error, len(e.Triggers)+len(e.Expects)) | ||||
| 		for _, trig := range e.Triggers { | ||||
| 			errc <- self.trigger(trig) | ||||
| 		} | ||||
|  | ||||
| 		// each expectation is spawned in separate go-routine | ||||
| 		// expectations of an exchange are conjunctive but unordered, i.e., | ||||
| 		// only all of them arriving constitutes a pass | ||||
| 		// each expectation is meant to be for a different peer, otherwise they are expected to panic | ||||
| 		// testing of an exchange blocks until all expectations are decided | ||||
| 		// an expectation is decided if | ||||
| 		//  expected message arrives OR | ||||
| 		// an unexpected message arrives (panic) | ||||
| 		// times out on their individual timeout | ||||
| 		for _, ex := range e.Expects { | ||||
| 			// expect msg spawned to separate go routine | ||||
| 			go func(exp Expect) { | ||||
| 				errc <- self.expect(exp) | ||||
| 			}(ex) | ||||
| 		} | ||||
|  | ||||
| 		// time out globally or finish when all expectations satisfied | ||||
| 		timeout := time.After(5 * time.Second) | ||||
| 		for i := 0; i < len(e.Triggers)+len(e.Expects); i++ { | ||||
| 			select { | ||||
| 			case err := <-errc: | ||||
| 				if err != nil { | ||||
| 					return fmt.Errorf("exchange failed with: %v", err) | ||||
| 				} | ||||
| 			case <-timeout: | ||||
| 				return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label) | ||||
| 			} | ||||
| 	for i, e := range exchanges { | ||||
| 		if err := self.testExchange(e); err != nil { | ||||
| 			return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) | ||||
| 		} | ||||
| 		log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // testExchange tests a single Exchange. | ||||
| // Default timeout value is 2 seconds. | ||||
| func (self *ProtocolSession) testExchange(e Exchange) error { | ||||
| 	errc := make(chan error) | ||||
| 	done := make(chan struct{}) | ||||
| 	defer close(done) | ||||
|  | ||||
| 	go func() { | ||||
| 		for _, trig := range e.Triggers { | ||||
| 			err := self.trigger(trig) | ||||
| 			if err != nil { | ||||
| 				errc <- err | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		select { | ||||
| 		case errc <- self.expect(e.Expects): | ||||
| 		case <-done: | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	// time out globally or finish when all expectations satisfied | ||||
| 	t := e.Timeout | ||||
| 	if t == 0 { | ||||
| 		t = 2000 * time.Millisecond | ||||
| 	} | ||||
| 	alarm := time.NewTimer(t) | ||||
| 	select { | ||||
| 	case err := <-errc: | ||||
| 		return err | ||||
| 	case <-alarm.C: | ||||
| 		return errTimedOut | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestDisconnected tests the disconnections given as arguments | ||||
| // the disconnect structs describe what disconnect error is expected on which peer | ||||
| func (self *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { | ||||
|   | ||||
| @@ -24,7 +24,11 @@ that can be used to send and receive messages | ||||
| package testing | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"io/ioutil" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
|  | ||||
| @@ -34,6 +38,7 @@ import ( | ||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||
| 	"github.com/ethereum/go-ethereum/p2p/simulations" | ||||
| 	"github.com/ethereum/go-ethereum/p2p/simulations/adapters" | ||||
| 	"github.com/ethereum/go-ethereum/rlp" | ||||
| 	"github.com/ethereum/go-ethereum/rpc" | ||||
| ) | ||||
|  | ||||
| @@ -152,7 +157,7 @@ type mockNode struct { | ||||
| 	testNode | ||||
|  | ||||
| 	trigger  chan *Trigger | ||||
| 	expect   chan *Expect | ||||
| 	expect   chan []Expect | ||||
| 	err      chan error | ||||
| 	stop     chan struct{} | ||||
| 	stopOnce sync.Once | ||||
| @@ -161,7 +166,7 @@ type mockNode struct { | ||||
| func newMockNode() *mockNode { | ||||
| 	mock := &mockNode{ | ||||
| 		trigger: make(chan *Trigger), | ||||
| 		expect:  make(chan *Expect), | ||||
| 		expect:  make(chan []Expect), | ||||
| 		err:     make(chan error), | ||||
| 		stop:    make(chan struct{}), | ||||
| 	} | ||||
| @@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { | ||||
| 		select { | ||||
| 		case trig := <-m.trigger: | ||||
| 			m.err <- p2p.Send(rw, trig.Code, trig.Msg) | ||||
| 		case exp := <-m.expect: | ||||
| 			m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg) | ||||
| 		case exps := <-m.expect: | ||||
| 			m.err <- expectMsgs(rw, exps) | ||||
| 		case <-m.stop: | ||||
| 			return nil | ||||
| 		} | ||||
| @@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error { | ||||
| 	return <-m.err | ||||
| } | ||||
|  | ||||
| func (m *mockNode) Expect(exp *Expect) error { | ||||
| func (m *mockNode) Expect(exp ...Expect) error { | ||||
| 	m.expect <- exp | ||||
| 	return <-m.err | ||||
| } | ||||
| @@ -198,3 +203,67 @@ func (m *mockNode) Stop() error { | ||||
| 	m.stopOnce.Do(func() { close(m.stop) }) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { | ||||
| 	matched := make([]bool, len(exps)) | ||||
| 	for { | ||||
| 		msg, err := rw.ReadMsg() | ||||
| 		if err != nil { | ||||
| 			if err == io.EOF { | ||||
| 				break | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| 		actualContent, err := ioutil.ReadAll(msg.Payload) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		var found bool | ||||
| 		for i, exp := range exps { | ||||
| 			if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) { | ||||
| 				if matched[i] { | ||||
| 					return fmt.Errorf("message #%d received two times", i) | ||||
| 				} | ||||
| 				matched[i] = true | ||||
| 				found = true | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		if !found { | ||||
| 			expected := make([]string, 0) | ||||
| 			for i, exp := range exps { | ||||
| 				if matched[i] { | ||||
| 					continue | ||||
| 				} | ||||
| 				expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg))) | ||||
| 			} | ||||
| 			return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) | ||||
| 		} | ||||
| 		done := true | ||||
| 		for _, m := range matched { | ||||
| 			if !m { | ||||
| 				done = false | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 		if done { | ||||
| 			return nil | ||||
| 		} | ||||
| 	} | ||||
| 	for i, m := range matched { | ||||
| 		if !m { | ||||
| 			return fmt.Errorf("expected message #%d not received", i) | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // mustEncodeMsg uses rlp to encode a message. | ||||
| // In case of error it panics. | ||||
| func mustEncodeMsg(msg interface{}) []byte { | ||||
| 	contentEnc, err := rlp.EncodeToBytes(msg) | ||||
| 	if err != nil { | ||||
| 		panic("content encode error: " + err.Error()) | ||||
| 	} | ||||
| 	return contentEnc | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user