p2p: rework protocol API

This commit is contained in:
Felix Lange
2014-11-04 13:21:44 +01:00
parent 8cf9ed0ea5
commit f38052c499
14 changed files with 1042 additions and 1307 deletions

View File

@ -2,43 +2,101 @@ package p2p
import (
"bytes"
"fmt"
"net"
"sort"
"sync"
"time"
"github.com/ethereum/go-ethereum/ethutil"
)
// Protocol is implemented by P2P subprotocols.
type Protocol interface {
Start()
Stop()
HandleIn(*Msg, chan *Msg)
HandleOut(*Msg) bool
// Start is called when the protocol becomes active.
// It should read and write messages from rw.
// Messages must be fully consumed.
//
// The connection is closed when Start returns. It should return
// any protocol-level error (such as an I/O error) that is
// encountered.
Start(peer *Peer, rw MsgReadWriter) error
// Offset should return the number of message codes
// used by the protocol.
Offset() MsgCode
Name() string
}
type MsgReader interface {
ReadMsg() (Msg, error)
}
type MsgWriter interface {
WriteMsg(Msg) error
}
// MsgReadWriter is passed to protocols. Protocol implementations can
// use it to write messages back to a connected peer.
type MsgReadWriter interface {
MsgReader
MsgWriter
}
type MsgHandler func(code MsgCode, data *ethutil.Value) error
// MsgLoop reads messages off the given reader and
// calls the handler function for each decoded message until
// it returns an error or the peer connection is closed.
//
// If a message is larger than the given maximum size, RunProtocol
// returns an appropriate error.n
func MsgLoop(r MsgReader, maxsize uint32, handler MsgHandler) error {
for {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Size > maxsize {
return NewPeerError(InvalidMsg, "size %d exceeds maximum size of %d", msg.Size, maxsize)
}
value, err := msg.Data()
if err != nil {
return err
}
if err := handler(msg.Code, value); err != nil {
return err
}
}
}
// the ÐΞVp2p base protocol
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
}
type bpMsg struct {
code MsgCode
data *ethutil.Value
}
const (
P2PVersion = 0
pingTimeout = 2
pingGracePeriod = 2
p2pVersion = 0
pingTimeout = 2 * time.Second
pingGracePeriod = 2 * time.Second
)
const (
HandshakeMsg = iota
DiscMsg
PingMsg
PongMsg
GetPeersMsg
PeersMsg
offset = 16
// message codes
handshakeMsg = iota
discMsg
pingMsg
pongMsg
getPeersMsg
peersMsg
)
type ProtocolState uint8
const (
nullState = iota
handshakeReceived
baseProtocolOffset MsgCode = 16
baseProtocolMaxMsgSize = 500 * 1024
)
type DiscReason byte
@ -62,7 +120,7 @@ const (
DiscSubprotocolError = 0x10
)
var discReasonToString = map[DiscReason]string{
var discReasonToString = [DiscSubprotocolError + 1]string{
DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol",
@ -82,197 +140,178 @@ func (d DiscReason) String() string {
if len(discReasonToString) < int(d) {
return "Unknown"
}
return discReasonToString[d]
}
type BaseProtocol struct {
peer *Peer
state ProtocolState
stateLock sync.RWMutex
func (bp *baseProtocol) Ping() {
}
func NewBaseProtocol(peer *Peer) *BaseProtocol {
self := &BaseProtocol{
peer: peer,
func (bp *baseProtocol) Offset() MsgCode {
return baseProtocolOffset
}
func (bp *baseProtocol) Start(peer *Peer, rw MsgReadWriter) error {
bp.peer, bp.rw = peer, rw
// Do the handshake.
// TODO: disconnect is valid before handshake, too.
rw.WriteMsg(bp.peer.server.handshakeMsg())
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return NewPeerError(ProtocolBreach, " first message must be handshake")
}
data, err := msg.Data()
if err != nil {
return NewPeerError(InvalidMsg, "%v", err)
}
if err := bp.handleHandshake(data); err != nil {
return err
}
return self
msgin := make(chan bpMsg)
done := make(chan error, 1)
go func() {
done <- MsgLoop(rw, baseProtocolMaxMsgSize,
func(code MsgCode, data *ethutil.Value) error {
msgin <- bpMsg{code, data}
return nil
})
}()
return bp.loop(msgin, done)
}
func (self *BaseProtocol) Start() {
if self.peer != nil {
self.peer.Write("", self.peer.Server().Handshake())
go self.peer.Messenger().PingPong(
pingTimeout*time.Second,
pingGracePeriod*time.Second,
self.Ping,
self.Timeout,
)
}
}
func (bp *baseProtocol) loop(msgin <-chan bpMsg, quit <-chan error) error {
logger.Debugf("pingpong keepalive started at %v\n", time.Now())
messenger := bp.rw.(*proto).messenger
pingTimer := time.NewTimer(pingTimeout)
pinged := true
func (self *BaseProtocol) Stop() {
}
func (self *BaseProtocol) Ping() {
msg, _ := NewMsg(PingMsg)
self.peer.Write("", msg)
}
func (self *BaseProtocol) Timeout() {
self.peerError(PingTimeout, "")
}
func (self *BaseProtocol) Name() string {
return ""
}
func (self *BaseProtocol) Offset() MsgCode {
return offset
}
func (self *BaseProtocol) CheckState(state ProtocolState) bool {
self.stateLock.RLock()
self.stateLock.RUnlock()
if self.state != state {
return false
} else {
return true
}
}
func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) {
if msg.Code() == HandshakeMsg {
self.handleHandshake(msg)
} else {
if !self.CheckState(handshakeReceived) {
self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code())
close(response)
return
}
switch msg.Code() {
case DiscMsg:
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint()))
self.peer.Server().PeerDisconnect() <- DisconnectRequest{
addr: self.peer.Address,
reason: DiscRequested,
for {
select {
case msg := <-msgin:
if err := bp.handle(msg.code, msg.data); err != nil {
return err
}
case PingMsg:
out, _ := NewMsg(PongMsg)
response <- out
case PongMsg:
case GetPeersMsg:
// Peer asked for list of connected peers
if out, err := self.peer.Server().PeersMessage(); err != nil {
response <- out
case err := <-quit:
return err
case <-messenger.pulse:
pingTimer.Reset(pingTimeout)
pinged = false
case <-pingTimer.C:
if pinged {
return NewPeerError(PingTimeout, "")
}
case PeersMsg:
self.handlePeers(msg)
default:
self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code())
logger.Debugf("pinging at %v\n", time.Now())
if err := bp.rw.WriteMsg(NewMsg(pingMsg)); err != nil {
return NewPeerError(WriteError, "%v", err)
}
pinged = true
pingTimer.Reset(pingTimeout)
}
}
close(response)
}
func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) {
// somewhat overly paranoid
allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived)
return
}
func (bp *baseProtocol) handle(code MsgCode, data *ethutil.Value) error {
switch code {
case handshakeMsg:
return NewPeerError(ProtocolBreach, " extra handshake received")
func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) {
err := NewPeerError(errorCode, format, v...)
logger.Warnln(err)
fmt.Println(self.peer, err)
if self.peer != nil {
self.peer.PeerErrorChan() <- err
case discMsg:
logger.Infof("Disconnect requested from peer %v, reason", DiscReason(data.Get(0).Uint()))
bp.peer.server.PeerDisconnect() <- DisconnectRequest{
addr: bp.peer.Address,
reason: DiscRequested,
}
case pingMsg:
return bp.rw.WriteMsg(NewMsg(pongMsg))
case pongMsg:
// reply for ping
case getPeersMsg:
// Peer asked for list of connected peers.
peersRLP := bp.peer.server.encodedPeerList()
if peersRLP != nil {
msg := Msg{
Code: peersMsg,
Size: uint32(len(peersRLP)),
Payload: bytes.NewReader(peersRLP),
}
return bp.rw.WriteMsg(msg)
}
case peersMsg:
bp.handlePeers(data)
default:
return NewPeerError(InvalidMsgCode, "unknown message code %v", code)
}
return nil
}
func (self *BaseProtocol) handlePeers(msg *Msg) {
it := msg.Data().NewIterator()
func (bp *baseProtocol) handlePeers(data *ethutil.Value) {
it := data.NewIterator()
for it.Next() {
ip := net.IP(it.Value().Get(0).Bytes())
port := it.Value().Get(1).Uint()
address := &net.TCPAddr{IP: ip, Port: int(port)}
go self.peer.Server().PeerConnect(address)
go bp.peer.server.PeerConnect(address)
}
}
func (self *BaseProtocol) handleHandshake(msg *Msg) {
self.stateLock.Lock()
defer self.stateLock.Unlock()
if self.state != nullState {
self.peerError(ProtocolBreach, "extra handshake")
return
}
c := msg.Data()
func (bp *baseProtocol) handleHandshake(c *ethutil.Value) error {
var (
p2pVersion = c.Get(0).Uint()
id = c.Get(1).Str()
caps = c.Get(2)
port = c.Get(3).Uint()
pubkey = c.Get(4).Bytes()
remoteVersion = c.Get(0).Uint()
id = c.Get(1).Str()
caps = c.Get(2)
port = c.Get(3).Uint()
pubkey = c.Get(4).Bytes()
)
fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey)
// Check correctness of p2p protocol version
if p2pVersion != P2PVersion {
self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion)
return
if remoteVersion != p2pVersion {
return NewPeerError(P2PVersionMismatch, "Require protocol %d, received %d\n", p2pVersion, remoteVersion)
}
// Handle the pub key (validation, uniqueness)
if len(pubkey) == 0 {
self.peerError(PubkeyMissing, "not supplied in handshake.")
return
return NewPeerError(PubkeyMissing, "not supplied in handshake.")
}
if len(pubkey) != 64 {
self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
return
return NewPeerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8)
}
// Self connect detection
if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 {
self.peerError(PubkeyForbidden, "not allowed to connect to self")
return
// self connect detection
if bytes.Compare(bp.peer.server.ClientIdentity().Pubkey()[1:], pubkey) == 0 {
return NewPeerError(PubkeyForbidden, "not allowed to connect to bp")
}
// register pubkey on server. this also sets the pubkey on the peer (need lock)
if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil {
self.peerError(PubkeyForbidden, err.Error())
return
if err := bp.peer.server.RegisterPubkey(bp.peer, pubkey); err != nil {
return NewPeerError(PubkeyForbidden, err.Error())
}
// check port
if self.peer.Inbound {
if bp.peer.Inbound {
uint16port := uint16(port)
if self.peer.Port > 0 && self.peer.Port != uint16port {
self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port)
return
if bp.peer.Port > 0 && bp.peer.Port != uint16port {
return NewPeerError(PortMismatch, "port mismatch: %v != %v", bp.peer.Port, port)
} else {
self.peer.Port = uint16port
bp.peer.Port = uint16port
}
}
capsIt := caps.NewIterator()
for capsIt.Next() {
cap := capsIt.Value().Str()
self.peer.Caps = append(self.peer.Caps, cap)
bp.peer.Caps = append(bp.peer.Caps, cap)
}
sort.Strings(self.peer.Caps)
self.peer.Messenger().AddProtocols(self.peer.Caps)
self.peer.Id = id
self.state = handshakeReceived
//p.ethereum.PushPeer(p)
// p.ethereum.reactor.Post("peerList", p.ethereum.Peers())
return
sort.Strings(bp.peer.Caps)
bp.rw.(*proto).messenger.setRemoteProtocols(bp.peer.Caps)
bp.peer.Id = id
return nil
}