p2p: disable encryption handshake

The diff is a bit bigger than expected because the protocol handshake
logic has moved out of Peer. This is necessary because the protocol
handshake will have custom framing in the final protocol.
This commit is contained in:
Felix Lange
2015-02-19 01:52:03 +01:00
parent 4322632c59
commit 73f94f3755
7 changed files with 274 additions and 314 deletions

View File

@ -33,37 +33,14 @@ const (
peersMsg = 0x05
)
// handshake is the RLP structure of the protocol handshake.
type handshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
NodeID discover.NodeID
}
// Peer represents a connected remote node.
type Peer struct {
// Peers have all the log methods.
// Use them to display messages related to the peer.
*logger.Logger
infoMu sync.Mutex
name string
caps []Cap
ourID, remoteID *discover.NodeID
ourName string
rw *frameRW
// These fields maintain the running protocols.
protocols []Protocol
runlock sync.RWMutex // protects running
running map[string]*proto
// disables protocol handshake, for testing
noHandshake bool
rw *conn
running map[string]*protoRW
protoWG sync.WaitGroup
protoErr chan error
@ -73,36 +50,27 @@ type Peer struct {
// NewPeer returns a peer for testing purposes.
func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
conn, _ := net.Pipe()
peer := newPeer(conn, nil, "", nil, &id)
peer.setHandshakeInfo(name, caps)
pipe, _ := net.Pipe()
conn := newConn(pipe, &protoHandshake{ID: id, Name: name, Caps: caps})
peer := newPeer(conn, nil)
close(peer.closed) // ensures Disconnect doesn't block
return peer
}
// ID returns the node's public key.
func (p *Peer) ID() discover.NodeID {
return *p.remoteID
return p.rw.ID
}
// Name returns the node name that the remote node advertised.
func (p *Peer) Name() string {
// this needs a lock because the information is part of the
// protocol handshake.
p.infoMu.Lock()
name := p.name
p.infoMu.Unlock()
return name
return p.rw.Name
}
// Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap {
// this needs a lock because the information is part of the
// protocol handshake.
p.infoMu.Lock()
caps := p.caps
p.infoMu.Unlock()
return caps
// TODO: maybe return copy
return p.rw.Caps
}
// RemoteAddr returns the remote address of the network connection.
@ -126,30 +94,20 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer.
func (p *Peer) String() string {
return fmt.Sprintf("Peer %.8x %v", p.remoteID[:], p.RemoteAddr())
return fmt.Sprintf("Peer %.8x %v", p.rw.ID[:], p.RemoteAddr())
}
func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
logtag := fmt.Sprintf("Peer %.8x %v", remoteID[:], conn.RemoteAddr())
return &Peer{
Logger: logger.NewLogger(logtag),
rw: newFrameRW(conn, msgWriteTimeout),
ourID: ourID,
ourName: ourName,
remoteID: remoteID,
protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
func newPeer(conn *conn, protocols []Protocol) *Peer {
logtag := fmt.Sprintf("Peer %.8x %v", conn.ID[:], conn.RemoteAddr())
p := &Peer{
Logger: logger.NewLogger(logtag),
rw: conn,
running: matchProtocols(protocols, conn.Caps, conn),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
}
func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
p.infoMu.Lock()
p.name = name
p.caps = caps
p.infoMu.Unlock()
return p
}
func (p *Peer) run() DiscReason {
@ -157,16 +115,9 @@ func (p *Peer) run() DiscReason {
defer p.closeProtocols()
defer close(p.closed)
p.startProtocols()
go func() { readErr <- p.readLoop() }()
if !p.noHandshake {
if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
p.DebugDetailf("Protocol handshake error: %v\n", err)
p.rw.Close()
return DiscProtocolError
}
}
// Wait for an error or disconnect.
var reason DiscReason
select {
@ -206,11 +157,6 @@ func (p *Peer) politeDisconnect(reason DiscReason) {
}
func (p *Peer) readLoop() error {
if !p.noHandshake {
if err := readProtocolHandshake(p, p.rw); err != nil {
return err
}
}
for {
msg, err := p.rw.ReadMsg()
if err != nil {
@ -249,105 +195,51 @@ func (p *Peer) handle(msg Msg) error {
return nil
}
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
// read and handle remote handshake
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Code == discMsg {
// disconnect before protocol handshake is valid according to the
// spec and we send it ourself if Server.addPeer fails.
var reason DiscReason
rlp.Decode(msg.Payload, &reason)
return discRequestedError(reason)
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errInvalidMsg, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if hs.NodeID == *p.remoteID {
return newPeerError(errPubkeyForbidden, "node ID mismatch")
}
// TODO: remove Caps with empty name
p.setHandshakeInfo(hs.Name, hs.Caps)
p.startSubprotocols(hs.Caps)
return nil
}
func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
var caps []interface{}
for _, proto := range ps {
caps = append(caps, proto.cap())
}
return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
}
// startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) {
// matchProtocols creates structures for matching named subprotocols.
func matchProtocols(protocols []Protocol, caps []Cap, rw MsgReadWriter) map[string]*protoRW {
sort.Sort(capsByName(caps))
p.runlock.Lock()
defer p.runlock.Unlock()
offset := baseProtocolLength
result := make(map[string]*protoRW)
outer:
for _, cap := range caps {
for _, proto := range p.protocols {
if proto.Name == cap.Name &&
proto.Version == cap.Version &&
p.running[cap.Name] == nil {
p.running[cap.Name] = p.startProto(offset, proto)
for _, proto := range protocols {
if proto.Name == cap.Name && proto.Version == cap.Version && result[cap.Name] == nil {
result[cap.Name] = &protoRW{Protocol: proto, offset: offset, in: make(chan Msg), w: rw}
offset += proto.Length
continue outer
}
}
}
return result
}
func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
rw := &proto{
name: impl.Name,
in: make(chan Msg),
offset: offset,
maxcode: impl.Length,
w: p.rw,
func (p *Peer) startProtocols() {
for _, proto := range p.running {
proto := proto
p.DebugDetailf("Starting protocol %s/%d\n", proto.Name, proto.Version)
p.protoWG.Add(1)
go func() {
err := proto.Run(p, proto)
if err == nil {
p.DebugDetailf("Protocol %s/%d returned\n", proto.Name, proto.Version)
err = errors.New("protocol returned")
} else {
p.DebugDetailf("Protocol %s/%d error: %v\n", proto.Name, proto.Version, err)
}
select {
case p.protoErr <- err:
case <-p.closed:
}
p.protoWG.Done()
}()
}
p.protoWG.Add(1)
go func() {
err := impl.Run(p, rw)
if err == nil {
p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
err = errors.New("protocol returned")
} else {
p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
}
select {
case p.protoErr <- err:
case <-p.closed:
}
p.protoWG.Done()
}()
return rw
}
// getProto finds the protocol responsible for handling
// the given message code.
func (p *Peer) getProto(code uint64) (*proto, error) {
p.runlock.RLock()
defer p.runlock.RUnlock()
func (p *Peer) getProto(code uint64) (*protoRW, error) {
for _, proto := range p.running {
if code >= proto.offset && code < proto.offset+proto.maxcode {
if code >= proto.offset && code < proto.offset+proto.Length {
return proto, nil
}
}
@ -355,46 +247,43 @@ func (p *Peer) getProto(code uint64) (*proto, error) {
}
func (p *Peer) closeProtocols() {
p.runlock.RLock()
for _, p := range p.running {
close(p.in)
}
p.runlock.RUnlock()
p.protoWG.Wait()
}
// writeProtoMsg sends the given message on behalf of the given named protocol.
// this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock()
proto, ok := p.running[protoName]
p.runlock.RUnlock()
if !ok {
return fmt.Errorf("protocol %s not handled by peer", protoName)
}
if msg.Code >= proto.maxcode {
if msg.Code >= proto.Length {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
}
msg.Code += proto.offset
return p.rw.WriteMsg(msg)
}
type proto struct {
name string
in chan Msg
maxcode, offset uint64
w MsgWriter
type protoRW struct {
Protocol
in chan Msg
offset uint64
w MsgWriter
}
func (rw *proto) WriteMsg(msg Msg) error {
if msg.Code >= rw.maxcode {
func (rw *protoRW) WriteMsg(msg Msg) error {
if msg.Code >= rw.Length {
return newPeerError(errInvalidMsgCode, "not handled")
}
msg.Code += rw.offset
return rw.w.WriteMsg(msg)
}
func (rw *proto) ReadMsg() (Msg, error) {
func (rw *protoRW) ReadMsg() (Msg, error) {
msg, ok := <-rw.in
if !ok {
return msg, io.EOF