p2p: fixes for actual connections
The unit test hooks were turned on 'in production'.
This commit is contained in:
		| @@ -174,10 +174,10 @@ func (rw *frameRW) ReadMsg() (msg Msg, err error) { | |||||||
| 	// read magic and payload size | 	// read magic and payload size | ||||||
| 	start := make([]byte, 8) | 	start := make([]byte, 8) | ||||||
| 	if _, err = io.ReadFull(rw.bufconn, start); err != nil { | 	if _, err = io.ReadFull(rw.bufconn, start); err != nil { | ||||||
| 		return msg, newPeerError(errRead, "%v", err) | 		return msg, err | ||||||
| 	} | 	} | ||||||
| 	if !bytes.HasPrefix(start, magicToken) { | 	if !bytes.HasPrefix(start, magicToken) { | ||||||
| 		return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) | 		return msg, fmt.Errorf("bad magic token %x", start[:4], magicToken) | ||||||
| 	} | 	} | ||||||
| 	size := binary.BigEndian.Uint32(start[4:]) | 	size := binary.BigEndian.Uint32(start[4:]) | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										37
									
								
								p2p/peer.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								p2p/peer.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | |||||||
| package p2p | package p2p | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| @@ -71,7 +72,8 @@ type Peer struct { | |||||||
| 	runlock   sync.RWMutex // protects running | 	runlock   sync.RWMutex // protects running | ||||||
| 	running   map[string]*proto | 	running   map[string]*proto | ||||||
|  |  | ||||||
| 	protocolHandshakeEnabled bool | 	// disables protocol handshake, for testing | ||||||
|  | 	noHandshake bool | ||||||
|  |  | ||||||
| 	protoWG  sync.WaitGroup | 	protoWG  sync.WaitGroup | ||||||
| 	protoErr chan error | 	protoErr chan error | ||||||
| @@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) { | |||||||
|  |  | ||||||
| // String implements fmt.Stringer. | // String implements fmt.Stringer. | ||||||
| func (p *Peer) String() string { | func (p *Peer) String() string { | ||||||
| 	return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr()) | 	return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr()) | ||||||
| } | } | ||||||
|  |  | ||||||
| func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { | func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer { | ||||||
| 	logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr()) | 	logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr()) | ||||||
| 	return &Peer{ | 	return &Peer{ | ||||||
| 		Logger:    logger.NewLogger(logtag), | 		Logger:    logger.NewLogger(logtag), | ||||||
| 		rw:        newFrameRW(conn, msgWriteTimeout), | 		rw:        newFrameRW(conn, msgWriteTimeout), | ||||||
| @@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason { | |||||||
| 	var readErr = make(chan error, 1) | 	var readErr = make(chan error, 1) | ||||||
| 	defer p.closeProtocols() | 	defer p.closeProtocols() | ||||||
| 	defer close(p.closed) | 	defer close(p.closed) | ||||||
| 	defer p.rw.Close() |  | ||||||
|  |  | ||||||
| 	// start the read loop |  | ||||||
| 	go func() { readErr <- p.readLoop() }() | 	go func() { readErr <- p.readLoop() }() | ||||||
|  |  | ||||||
| 	if p.protocolHandshakeEnabled { | 	if !p.noHandshake { | ||||||
| 		if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { | 		if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil { | ||||||
| 			p.DebugDetailf("Protocol handshake error: %v\n", err) | 			p.DebugDetailf("Protocol handshake error: %v\n", err) | ||||||
|  | 			p.rw.Close() | ||||||
| 			return DiscProtocolError | 			return DiscProtocolError | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// wait for an error or disconnect | 	// Wait for an error or disconnect. | ||||||
| 	var reason DiscReason | 	var reason DiscReason | ||||||
| 	select { | 	select { | ||||||
| 	case err := <-readErr: | 	case err := <-readErr: | ||||||
| 		// We rely on protocols to abort if there is a write error. It | 		// We rely on protocols to abort if there is a write error. It | ||||||
| 		// might be more robust to handle them here as well. | 		// might be more robust to handle them here as well. | ||||||
| 		p.DebugDetailf("Read error: %v\n", err) | 		p.DebugDetailf("Read error: %v\n", err) | ||||||
| 		reason = DiscNetworkError | 		p.rw.Close() | ||||||
|  | 		return DiscNetworkError | ||||||
|  |  | ||||||
| 	case err := <-p.protoErr: | 	case err := <-p.protoErr: | ||||||
| 		reason = discReasonForError(err) | 		reason = discReasonForError(err) | ||||||
| 	case reason = <-p.disc: | 	case reason = <-p.disc: | ||||||
| 	} | 	} | ||||||
| 	if reason != DiscNetworkError { | 	p.politeDisconnect(reason) | ||||||
| 		p.politeDisconnect(reason) |  | ||||||
| 	} | 	// Wait for readLoop. It will end because conn is now closed. | ||||||
|  | 	<-readErr | ||||||
| 	p.Debugf("Disconnected: %v\n", reason) | 	p.Debugf("Disconnected: %v\n", reason) | ||||||
| 	return reason | 	return reason | ||||||
| } | } | ||||||
| @@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason { | |||||||
| func (p *Peer) politeDisconnect(reason DiscReason) { | func (p *Peer) politeDisconnect(reason DiscReason) { | ||||||
| 	done := make(chan struct{}) | 	done := make(chan struct{}) | ||||||
| 	go func() { | 	go func() { | ||||||
| 		// send reason |  | ||||||
| 		EncodeMsg(p.rw, discMsg, uint(reason)) | 		EncodeMsg(p.rw, discMsg, uint(reason)) | ||||||
| 		// discard any data that might arrive | 		// Wait for the other side to close the connection. | ||||||
|  | 		// Discard any data that they send until then. | ||||||
| 		io.Copy(ioutil.Discard, p.rw) | 		io.Copy(ioutil.Discard, p.rw) | ||||||
| 		close(done) | 		close(done) | ||||||
| 	}() | 	}() | ||||||
| @@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) { | |||||||
| 	case <-done: | 	case <-done: | ||||||
| 	case <-time.After(disconnectGracePeriod): | 	case <-time.After(disconnectGracePeriod): | ||||||
| 	} | 	} | ||||||
|  | 	p.rw.Close() | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Peer) readLoop() error { | func (p *Peer) readLoop() error { | ||||||
| 	if p.protocolHandshakeEnabled { | 	if !p.noHandshake { | ||||||
| 		if err := readProtocolHandshake(p, p.rw); err != nil { | 		if err := readProtocolHandshake(p, p.rw); err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error { | |||||||
| 		return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) | 		return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code) | ||||||
| 	} | 	} | ||||||
| 	if msg.Size > baseProtocolMaxMsgSize { | 	if msg.Size > baseProtocolMaxMsgSize { | ||||||
| 		return newPeerError(errMisc, "message too big") | 		return newPeerError(errInvalidMsg, "message too big") | ||||||
| 	} | 	} | ||||||
| 	var hs handshake | 	var hs handshake | ||||||
| 	if err := msg.Decode(&hs); err != nil { | 	if err := msg.Decode(&hs); err != nil { | ||||||
| @@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto { | |||||||
| 		err := impl.Run(p, rw) | 		err := impl.Run(p, rw) | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
| 			p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) | 			p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version) | ||||||
| 			err = newPeerError(errMisc, "protocol returned") | 			err = errors.New("protocol returned") | ||||||
| 		} else { | 		} else { | ||||||
| 			p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) | 			p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err) | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -123,7 +123,7 @@ func discReasonForError(err error) DiscReason { | |||||||
| 		return DiscProtocolError | 		return DiscProtocolError | ||||||
| 	case errPingTimeout: | 	case errPingTimeout: | ||||||
| 		return DiscReadTimeout | 		return DiscReadTimeout | ||||||
| 	case errRead, errWrite, errMisc: | 	case errRead, errWrite: | ||||||
| 		return DiscNetworkError | 		return DiscNetworkError | ||||||
| 	default: | 	default: | ||||||
| 		return DiscSubprotocolError | 		return DiscSubprotocolError | ||||||
|   | |||||||
| @@ -30,10 +30,10 @@ var discard = Protocol{ | |||||||
| 	}, | 	}, | ||||||
| } | } | ||||||
|  |  | ||||||
| func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { | func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { | ||||||
| 	conn1, conn2 := net.Pipe() | 	conn1, conn2 := net.Pipe() | ||||||
| 	peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) | 	peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) | ||||||
| 	peer.protocolHandshakeEnabled = handshake | 	peer.noHandshake = noHandshake | ||||||
| 	errc := make(chan DiscReason, 1) | 	errc := make(chan DiscReason, 1) | ||||||
| 	go func() { errc <- peer.run() }() | 	go func() { errc <- peer.run() }() | ||||||
| 	return newFrameRW(conn2, msgWriteTimeout), peer, errc | 	return newFrameRW(conn2, msgWriteTimeout), peer, errc | ||||||
| @@ -61,7 +61,7 @@ func TestPeerProtoReadMsg(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rw, peer, errc := testPeer(false, []Protocol{proto}) | 	rw, peer, errc := testPeer(true, []Protocol{proto}) | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
| 	peer.startSubprotocols([]Cap{proto.cap()}) | 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||||
|  |  | ||||||
| @@ -100,7 +100,7 @@ func TestPeerProtoReadLargeMsg(t *testing.T) { | |||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rw, peer, errc := testPeer(false, []Protocol{proto}) | 	rw, peer, errc := testPeer(true, []Protocol{proto}) | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
| 	peer.startSubprotocols([]Cap{proto.cap()}) | 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||||
|  |  | ||||||
| @@ -130,7 +130,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { | |||||||
| 			return nil | 			return nil | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 	rw, peer, _ := testPeer(false, []Protocol{proto}) | 	rw, peer, _ := testPeer(true, []Protocol{proto}) | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
| 	peer.startSubprotocols([]Cap{proto.cap()}) | 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||||
|  |  | ||||||
| @@ -142,7 +142,7 @@ func TestPeerProtoEncodeMsg(t *testing.T) { | |||||||
| func TestPeerWriteForBroadcast(t *testing.T) { | func TestPeerWriteForBroadcast(t *testing.T) { | ||||||
| 	defer testlog(t).detach() | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
| 	rw, peer, peerErr := testPeer(false, []Protocol{discard}) | 	rw, peer, peerErr := testPeer(true, []Protocol{discard}) | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
| 	peer.startSubprotocols([]Cap{discard.cap()}) | 	peer.startSubprotocols([]Cap{discard.cap()}) | ||||||
|  |  | ||||||
| @@ -179,7 +179,7 @@ func TestPeerWriteForBroadcast(t *testing.T) { | |||||||
| func TestPeerPing(t *testing.T) { | func TestPeerPing(t *testing.T) { | ||||||
| 	defer testlog(t).detach() | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
| 	rw, _, _ := testPeer(false, nil) | 	rw, _, _ := testPeer(true, nil) | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
| 	if err := EncodeMsg(rw, pingMsg); err != nil { | 	if err := EncodeMsg(rw, pingMsg); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| @@ -192,7 +192,7 @@ func TestPeerPing(t *testing.T) { | |||||||
| func TestPeerDisconnect(t *testing.T) { | func TestPeerDisconnect(t *testing.T) { | ||||||
| 	defer testlog(t).detach() | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
| 	rw, _, disc := testPeer(false, nil) | 	rw, _, disc := testPeer(true, nil) | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
| 	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { | 	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| @@ -233,7 +233,7 @@ func TestPeerHandshake(t *testing.T) { | |||||||
| 		{Name: "c", Version: 3, Length: 1, Run: run}, | 		{Name: "c", Version: 3, Length: 1, Run: run}, | ||||||
| 		{Name: "d", Version: 4, Length: 1, Run: run}, | 		{Name: "d", Version: 4, Length: 1, Run: run}, | ||||||
| 	} | 	} | ||||||
| 	rw, p, disc := testPeer(true, protocols) | 	rw, p, disc := testPeer(false, protocols) | ||||||
| 	p.remoteID = remote.ourID | 	p.remoteID = remote.ourID | ||||||
| 	defer rw.Close() | 	defer rw.Close() | ||||||
|  |  | ||||||
| @@ -269,6 +269,7 @@ func TestPeerHandshake(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	close(stop) | 	close(stop) | ||||||
|  | 	expectMsg(rw, discMsg, nil) | ||||||
| 	t.Logf("disc reason: %v", <-disc) | 	t.Logf("disc reason: %v", <-disc) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -408,7 +408,9 @@ func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	srv.newPeerHook(p) | 	if srv.newPeerHook != nil { | ||||||
|  | 		srv.newPeerHook(p) | ||||||
|  | 	} | ||||||
| 	p.run() | 	p.run() | ||||||
| 	srv.removePeer(p) | 	srv.removePeer(p) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -118,6 +118,7 @@ func TestServerBroadcast(t *testing.T) { | |||||||
| 	srv := startTestServer(t, func(p *Peer) { | 	srv := startTestServer(t, func(p *Peer) { | ||||||
| 		p.protocols = []Protocol{discard} | 		p.protocols = []Protocol{discard} | ||||||
| 		p.startSubprotocols([]Cap{discard.cap()}) | 		p.startSubprotocols([]Cap{discard.cap()}) | ||||||
|  | 		p.noHandshake = true | ||||||
| 		connected.Done() | 		connected.Done() | ||||||
| 	}) | 	}) | ||||||
| 	defer srv.Stop() | 	defer srv.Stop() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user