p2p: fixes for actual connections
The unit test hooks were turned on 'in production'.
This commit is contained in:
37
p2p/peer.go
37
p2p/peer.go
@ -1,6 +1,7 @@
|
||||
package p2p
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@ -71,7 +72,8 @@ type Peer struct {
|
||||
runlock sync.RWMutex // protects running
|
||||
running map[string]*proto
|
||||
|
||||
protocolHandshakeEnabled bool
|
||||
// disables protocol handshake, for testing
|
||||
noHandshake bool
|
||||
|
||||
protoWG sync.WaitGroup
|
||||
protoErr chan error
|
||||
@ -134,11 +136,11 @@ func (p *Peer) Disconnect(reason DiscReason) {
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (p *Peer) String() string {
|
||||
return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
|
||||
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
|
||||
}
|
||||
|
||||
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
|
||||
logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
|
||||
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
|
||||
return &Peer{
|
||||
Logger: logger.NewLogger(logtag),
|
||||
rw: newFrameRW(conn, msgWriteTimeout),
|
||||
@ -164,33 +166,35 @@ func (p *Peer) run() DiscReason {
|
||||
var readErr = make(chan error, 1)
|
||||
defer p.closeProtocols()
|
||||
defer close(p.closed)
|
||||
defer p.rw.Close()
|
||||
|
||||
// start the read loop
|
||||
go func() { readErr <- p.readLoop() }()
|
||||
|
||||
if p.protocolHandshakeEnabled {
|
||||
if !p.noHandshake {
|
||||
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
|
||||
p.DebugDetailf("Protocol handshake error: %v\n", err)
|
||||
p.rw.Close()
|
||||
return DiscProtocolError
|
||||
}
|
||||
}
|
||||
|
||||
// wait for an error or disconnect
|
||||
// Wait for an error or disconnect.
|
||||
var reason DiscReason
|
||||
select {
|
||||
case err := <-readErr:
|
||||
// We rely on protocols to abort if there is a write error. It
|
||||
// might be more robust to handle them here as well.
|
||||
p.DebugDetailf("Read error: %v\n", err)
|
||||
reason = DiscNetworkError
|
||||
p.rw.Close()
|
||||
return DiscNetworkError
|
||||
|
||||
case err := <-p.protoErr:
|
||||
reason = discReasonForError(err)
|
||||
case reason = <-p.disc:
|
||||
}
|
||||
if reason != DiscNetworkError {
|
||||
p.politeDisconnect(reason)
|
||||
}
|
||||
p.politeDisconnect(reason)
|
||||
|
||||
// Wait for readLoop. It will end because conn is now closed.
|
||||
<-readErr
|
||||
p.Debugf("Disconnected: %v\n", reason)
|
||||
return reason
|
||||
}
|
||||
@ -198,9 +202,9 @@ func (p *Peer) run() DiscReason {
|
||||
func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
// send reason
|
||||
EncodeMsg(p.rw, discMsg, uint(reason))
|
||||
// discard any data that might arrive
|
||||
// Wait for the other side to close the connection.
|
||||
// Discard any data that they send until then.
|
||||
io.Copy(ioutil.Discard, p.rw)
|
||||
close(done)
|
||||
}()
|
||||
@ -208,10 +212,11 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
|
||||
case <-done:
|
||||
case <-time.After(disconnectGracePeriod):
|
||||
}
|
||||
p.rw.Close()
|
||||
}
|
||||
|
||||
func (p *Peer) readLoop() error {
|
||||
if p.protocolHandshakeEnabled {
|
||||
if !p.noHandshake {
|
||||
if err := readProtocolHandshake(p, p.rw); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -264,7 +269,7 @@ func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
|
||||
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
|
||||
}
|
||||
if msg.Size > baseProtocolMaxMsgSize {
|
||||
return newPeerError(errMisc, "message too big")
|
||||
return newPeerError(errInvalidMsg, "message too big")
|
||||
}
|
||||
var hs handshake
|
||||
if err := msg.Decode(&hs); err != nil {
|
||||
@ -326,7 +331,7 @@ func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
|
||||
err := impl.Run(p, rw)
|
||||
if err == nil {
|
||||
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
|
||||
err = newPeerError(errMisc, "protocol returned")
|
||||
err = errors.New("protocol returned")
|
||||
} else {
|
||||
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user