329 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			329 lines
		
	
	
		
			8.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package p2p
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"fmt"
 | 
						|
	"io/ioutil"
 | 
						|
	"net"
 | 
						|
	"reflect"
 | 
						|
	"sort"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
						|
	"github.com/ethereum/go-ethereum/rlp"
 | 
						|
)
 | 
						|
 | 
						|
var discard = Protocol{
 | 
						|
	Name:   "discard",
 | 
						|
	Length: 1,
 | 
						|
	Run: func(p *Peer, rw MsgReadWriter) error {
 | 
						|
		for {
 | 
						|
			msg, err := rw.ReadMsg()
 | 
						|
			if err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
			if err = msg.Discard(); err != nil {
 | 
						|
				return err
 | 
						|
			}
 | 
						|
		}
 | 
						|
	},
 | 
						|
}
 | 
						|
 | 
						|
func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
 | 
						|
	conn1, conn2 := net.Pipe()
 | 
						|
	peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
 | 
						|
	peer.noHandshake = noHandshake
 | 
						|
	errc := make(chan DiscReason, 1)
 | 
						|
	go func() { errc <- peer.run() }()
 | 
						|
	return newFrameRW(conn2, msgWriteTimeout), peer, errc
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerProtoReadMsg(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	done := make(chan struct{})
 | 
						|
	proto := Protocol{
 | 
						|
		Name:   "a",
 | 
						|
		Length: 5,
 | 
						|
		Run: func(peer *Peer, rw MsgReadWriter) error {
 | 
						|
			if err := expectMsg(rw, 2, []uint{1}); err != nil {
 | 
						|
				t.Error(err)
 | 
						|
			}
 | 
						|
			if err := expectMsg(rw, 3, []uint{2}); err != nil {
 | 
						|
				t.Error(err)
 | 
						|
			}
 | 
						|
			if err := expectMsg(rw, 4, []uint{3}); err != nil {
 | 
						|
				t.Error(err)
 | 
						|
			}
 | 
						|
			close(done)
 | 
						|
			return nil
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	rw, peer, errc := testPeer(true, []Protocol{proto})
 | 
						|
	defer rw.Close()
 | 
						|
	peer.startSubprotocols([]Cap{proto.cap()})
 | 
						|
 | 
						|
	EncodeMsg(rw, baseProtocolLength+2, 1)
 | 
						|
	EncodeMsg(rw, baseProtocolLength+3, 2)
 | 
						|
	EncodeMsg(rw, baseProtocolLength+4, 3)
 | 
						|
 | 
						|
	select {
 | 
						|
	case <-done:
 | 
						|
	case err := <-errc:
 | 
						|
		t.Errorf("peer returned: %v", err)
 | 
						|
	case <-time.After(2 * time.Second):
 | 
						|
		t.Errorf("receive timeout")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerProtoReadLargeMsg(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	msgsize := uint32(10 * 1024 * 1024)
 | 
						|
	done := make(chan struct{})
 | 
						|
	proto := Protocol{
 | 
						|
		Name:   "a",
 | 
						|
		Length: 5,
 | 
						|
		Run: func(peer *Peer, rw MsgReadWriter) error {
 | 
						|
			msg, err := rw.ReadMsg()
 | 
						|
			if err != nil {
 | 
						|
				t.Errorf("read error: %v", err)
 | 
						|
			}
 | 
						|
			if msg.Size != msgsize+4 {
 | 
						|
				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
 | 
						|
			}
 | 
						|
			msg.Discard()
 | 
						|
			close(done)
 | 
						|
			return nil
 | 
						|
		},
 | 
						|
	}
 | 
						|
 | 
						|
	rw, peer, errc := testPeer(true, []Protocol{proto})
 | 
						|
	defer rw.Close()
 | 
						|
	peer.startSubprotocols([]Cap{proto.cap()})
 | 
						|
 | 
						|
	EncodeMsg(rw, 18, make([]byte, msgsize))
 | 
						|
	select {
 | 
						|
	case <-done:
 | 
						|
	case err := <-errc:
 | 
						|
		t.Errorf("peer returned: %v", err)
 | 
						|
	case <-time.After(2 * time.Second):
 | 
						|
		t.Errorf("receive timeout")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerProtoEncodeMsg(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	proto := Protocol{
 | 
						|
		Name:   "a",
 | 
						|
		Length: 2,
 | 
						|
		Run: func(peer *Peer, rw MsgReadWriter) error {
 | 
						|
			if err := EncodeMsg(rw, 2); err == nil {
 | 
						|
				t.Error("expected error for out-of-range msg code, got nil")
 | 
						|
			}
 | 
						|
			if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
 | 
						|
				t.Errorf("write error: %v", err)
 | 
						|
			}
 | 
						|
			return nil
 | 
						|
		},
 | 
						|
	}
 | 
						|
	rw, peer, _ := testPeer(true, []Protocol{proto})
 | 
						|
	defer rw.Close()
 | 
						|
	peer.startSubprotocols([]Cap{proto.cap()})
 | 
						|
 | 
						|
	if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerWriteForBroadcast(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	rw, peer, peerErr := testPeer(true, []Protocol{discard})
 | 
						|
	defer rw.Close()
 | 
						|
	peer.startSubprotocols([]Cap{discard.cap()})
 | 
						|
 | 
						|
	// test write errors
 | 
						|
	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
 | 
						|
		t.Errorf("expected error for unknown protocol, got nil")
 | 
						|
	}
 | 
						|
	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
 | 
						|
		t.Errorf("expected error for out-of-range msg code, got nil")
 | 
						|
	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
 | 
						|
		t.Errorf("wrong error for out-of-range msg code, got %#v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// setup for reading the message on the other end
 | 
						|
	read := make(chan struct{})
 | 
						|
	go func() {
 | 
						|
		if err := expectMsg(rw, 16, nil); err != nil {
 | 
						|
			t.Error()
 | 
						|
		}
 | 
						|
		close(read)
 | 
						|
	}()
 | 
						|
 | 
						|
	// test successful write
 | 
						|
	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
 | 
						|
		t.Errorf("expect no error for known protocol: %v", err)
 | 
						|
	}
 | 
						|
	select {
 | 
						|
	case <-read:
 | 
						|
	case err := <-peerErr:
 | 
						|
		t.Fatalf("peer stopped: %v", err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerPing(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	rw, _, _ := testPeer(true, nil)
 | 
						|
	defer rw.Close()
 | 
						|
	if err := EncodeMsg(rw, pingMsg); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if err := expectMsg(rw, pongMsg, nil); err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerDisconnect(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	rw, _, disc := testPeer(true, nil)
 | 
						|
	defer rw.Close()
 | 
						|
	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
 | 
						|
		t.Fatal(err)
 | 
						|
	}
 | 
						|
	if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
 | 
						|
		t.Error(err)
 | 
						|
	}
 | 
						|
	rw.Close() // make test end faster
 | 
						|
	if reason := <-disc; reason != DiscRequested {
 | 
						|
		t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPeerHandshake(t *testing.T) {
 | 
						|
	defer testlog(t).detach()
 | 
						|
 | 
						|
	// remote has two matching protocols: a and c
 | 
						|
	remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
 | 
						|
	remoteID := randomID()
 | 
						|
	remote.ourID = &remoteID
 | 
						|
	remote.ourName = "remote peer"
 | 
						|
 | 
						|
	start := make(chan string)
 | 
						|
	stop := make(chan struct{})
 | 
						|
	run := func(p *Peer, rw MsgReadWriter) error {
 | 
						|
		name := rw.(*proto).name
 | 
						|
		if name != "a" && name != "c" {
 | 
						|
			t.Errorf("protocol %q should not be started", name)
 | 
						|
		} else {
 | 
						|
			start <- name
 | 
						|
		}
 | 
						|
		<-stop
 | 
						|
		return nil
 | 
						|
	}
 | 
						|
	protocols := []Protocol{
 | 
						|
		{Name: "a", Version: 1, Length: 1, Run: run},
 | 
						|
		{Name: "b", Version: 2, Length: 1, Run: run},
 | 
						|
		{Name: "c", Version: 3, Length: 1, Run: run},
 | 
						|
		{Name: "d", Version: 4, Length: 1, Run: run},
 | 
						|
	}
 | 
						|
	rw, p, disc := testPeer(false, protocols)
 | 
						|
	p.remoteID = remote.ourID
 | 
						|
	defer rw.Close()
 | 
						|
 | 
						|
	// run the handshake
 | 
						|
	remoteProtocols := []Protocol{protocols[0], protocols[2]}
 | 
						|
	if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
 | 
						|
		t.Fatalf("handshake write error: %v", err)
 | 
						|
	}
 | 
						|
	if err := readProtocolHandshake(remote, rw); err != nil {
 | 
						|
		t.Fatalf("handshake read error: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// check that all protocols have been started
 | 
						|
	var started []string
 | 
						|
	for i := 0; i < 2; i++ {
 | 
						|
		select {
 | 
						|
		case name := <-start:
 | 
						|
			started = append(started, name)
 | 
						|
		case <-time.After(100 * time.Millisecond):
 | 
						|
		}
 | 
						|
	}
 | 
						|
	sort.Strings(started)
 | 
						|
	if !reflect.DeepEqual(started, []string{"a", "c"}) {
 | 
						|
		t.Errorf("wrong protocols started: %v", started)
 | 
						|
	}
 | 
						|
 | 
						|
	// check that metadata has been set
 | 
						|
	if p.ID() != remoteID {
 | 
						|
		t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
 | 
						|
	}
 | 
						|
	if p.Name() != remote.ourName {
 | 
						|
		t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
 | 
						|
	}
 | 
						|
 | 
						|
	close(stop)
 | 
						|
	expectMsg(rw, discMsg, nil)
 | 
						|
	t.Logf("disc reason: %v", <-disc)
 | 
						|
}
 | 
						|
 | 
						|
func TestNewPeer(t *testing.T) {
 | 
						|
	name := "nodename"
 | 
						|
	caps := []Cap{{"foo", 2}, {"bar", 3}}
 | 
						|
	id := randomID()
 | 
						|
	p := NewPeer(id, name, caps)
 | 
						|
	if p.ID() != id {
 | 
						|
		t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
 | 
						|
	}
 | 
						|
	if p.Name() != name {
 | 
						|
		t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
 | 
						|
	}
 | 
						|
	if !reflect.DeepEqual(p.Caps(), caps) {
 | 
						|
		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
 | 
						|
	}
 | 
						|
 | 
						|
	p.Disconnect(DiscAlreadyConnected) // Should not hang
 | 
						|
}
 | 
						|
 | 
						|
// expectMsg reads a message from r and verifies that its
 | 
						|
// code and encoded RLP content match the provided values.
 | 
						|
// If content is nil, the payload is discarded and not verified.
 | 
						|
func expectMsg(r MsgReader, code uint64, content interface{}) error {
 | 
						|
	msg, err := r.ReadMsg()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if msg.Code != code {
 | 
						|
		return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
 | 
						|
	}
 | 
						|
	if content == nil {
 | 
						|
		return msg.Discard()
 | 
						|
	} else {
 | 
						|
		contentEnc, err := rlp.EncodeToBytes(content)
 | 
						|
		if err != nil {
 | 
						|
			panic("content encode error: " + err.Error())
 | 
						|
		}
 | 
						|
		// skip over list header in encoded value. this is temporary.
 | 
						|
		contentEncR := bytes.NewReader(contentEnc)
 | 
						|
		if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
 | 
						|
			panic("content must encode as RLP list")
 | 
						|
		}
 | 
						|
		contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
 | 
						|
 | 
						|
		actualContent, err := ioutil.ReadAll(msg.Payload)
 | 
						|
		if err != nil {
 | 
						|
			return err
 | 
						|
		}
 | 
						|
		if !bytes.Equal(actualContent, contentEnc) {
 | 
						|
			return fmt.Errorf("message payload mismatch:\ngot:  %x\nwant: %x", actualContent, contentEnc)
 | 
						|
		}
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 |