swarm/pss: Message handler refactor (#18169)

This commit is contained in:
lash
2018-11-26 13:52:04 +01:00
committed by Anton Evangelatov
parent ca228569e4
commit 197d609b9a
10 changed files with 644 additions and 109 deletions

View File

@@ -23,11 +23,13 @@ import (
"crypto/rand"
"errors"
"fmt"
"hash"
"sync"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/sha3"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/enode"
@@ -136,10 +138,10 @@ type Pss struct {
symKeyDecryptCacheCapacity int // max amount of symkeys to keep.
// message handling
handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle()
handlersMu sync.RWMutex
allowRaw bool
hashPool sync.Pool
handlers map[Topic]map[*handler]bool // topic and version based pss payload handlers. See pss.Handle()
handlersMu sync.RWMutex
hashPool sync.Pool
topicHandlerCaps map[Topic]*handlerCaps // caches capabilities of each topic's handlers (see handlerCap* consts in types.go)
// process
quitC chan struct{}
@@ -180,11 +182,12 @@ func NewPss(k *network.Kademlia, params *PssParams) (*Pss, error) {
symKeyDecryptCache: make([]*string, params.SymKeyCacheCapacity),
symKeyDecryptCacheCapacity: params.SymKeyCacheCapacity,
handlers: make(map[Topic]map[*Handler]bool),
allowRaw: params.AllowRaw,
handlers: make(map[Topic]map[*handler]bool),
topicHandlerCaps: make(map[Topic]*handlerCaps),
hashPool: sync.Pool{
New: func() interface{} {
return storage.MakeHashFunc(storage.DefaultHash)()
return sha3.NewKeccak256()
},
},
}
@@ -313,30 +316,54 @@ func (p *Pss) PublicKey() *ecdsa.PublicKey {
//
// Returns a deregister function which needs to be called to
// deregister the handler,
func (p *Pss) Register(topic *Topic, handler Handler) func() {
func (p *Pss) Register(topic *Topic, hndlr *handler) func() {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
handlers := p.handlers[*topic]
if handlers == nil {
handlers = make(map[*Handler]bool)
handlers = make(map[*handler]bool)
p.handlers[*topic] = handlers
log.Debug("registered handler", "caps", hndlr.caps)
}
handlers[&handler] = true
return func() { p.deregister(topic, &handler) }
if hndlr.caps == nil {
hndlr.caps = &handlerCaps{}
}
handlers[hndlr] = true
if _, ok := p.topicHandlerCaps[*topic]; !ok {
p.topicHandlerCaps[*topic] = &handlerCaps{}
}
if hndlr.caps.raw {
p.topicHandlerCaps[*topic].raw = true
}
if hndlr.caps.prox {
p.topicHandlerCaps[*topic].prox = true
}
return func() { p.deregister(topic, hndlr) }
}
func (p *Pss) deregister(topic *Topic, h *Handler) {
func (p *Pss) deregister(topic *Topic, hndlr *handler) {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
handlers := p.handlers[*topic]
if len(handlers) == 1 {
if len(handlers) > 1 {
delete(p.handlers, *topic)
// topic caps might have changed now that a handler is gone
caps := &handlerCaps{}
for h := range handlers {
if h.caps.raw {
caps.raw = true
}
if h.caps.prox {
caps.prox = true
}
}
p.topicHandlerCaps[*topic] = caps
return
}
delete(handlers, h)
delete(handlers, hndlr)
}
// get all registered handlers for respective topics
func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
func (p *Pss) getHandlers(topic Topic) map[*handler]bool {
p.handlersMu.RLock()
defer p.handlersMu.RUnlock()
return p.handlers[topic]
@@ -348,12 +375,11 @@ func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
// Only passes error to pss protocol handler if payload is not valid pssmsg
func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
metrics.GetOrRegisterCounter("pss.handlepssmsg", nil).Inc(1)
pssmsg, ok := msg.(*PssMsg)
if !ok {
return fmt.Errorf("invalid message type. Expected *PssMsg, got %T ", msg)
}
log.Trace("handler", "self", label(p.Kademlia.BaseAddr()), "topic", label(pssmsg.Payload.Topic[:]))
if int64(pssmsg.Expire) < time.Now().Unix() {
metrics.GetOrRegisterCounter("pss.expire", nil).Inc(1)
log.Warn("pss filtered expired message", "from", common.ToHex(p.Kademlia.BaseAddr()), "to", common.ToHex(pssmsg.To))
@@ -365,13 +391,34 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
}
p.addFwdCache(pssmsg)
if !p.isSelfPossibleRecipient(pssmsg) {
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()))
psstopic := Topic(pssmsg.Payload.Topic)
// raw is simplest handler contingency to check, so check that first
var isRaw bool
if pssmsg.isRaw() {
if !p.topicHandlerCaps[psstopic].raw {
log.Debug("No handler for raw message", "topic", psstopic)
return nil
}
isRaw = true
}
// check if we can be recipient:
// - no prox handler on message and partial address matches
// - prox handler on message and we are in prox regardless of partial address match
// store this result so we don't calculate again on every handler
var isProx bool
if _, ok := p.topicHandlerCaps[psstopic]; ok {
isProx = p.topicHandlerCaps[psstopic].prox
}
isRecipient := p.isSelfPossibleRecipient(pssmsg, isProx)
if !isRecipient {
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()), "prox", isProx)
return p.enqueue(pssmsg)
}
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()))
if err := p.process(pssmsg); err != nil {
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()), "prox", isProx, "raw", isRaw, "topic", label(pssmsg.Payload.Topic[:]))
if err := p.process(pssmsg, isRaw, isProx); err != nil {
qerr := p.enqueue(pssmsg)
if qerr != nil {
return fmt.Errorf("process fail: processerr %v, queueerr: %v", err, qerr)
@@ -384,7 +431,7 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
// Entry point to processing a message for which the current node can be the intended recipient.
// Attempts symmetric and asymmetric decryption with stored keys.
// Dispatches message to all handlers matching the message topic
func (p *Pss) process(pssmsg *PssMsg) error {
func (p *Pss) process(pssmsg *PssMsg, raw bool, prox bool) error {
metrics.GetOrRegisterCounter("pss.process", nil).Inc(1)
var err error
@@ -397,10 +444,8 @@ func (p *Pss) process(pssmsg *PssMsg) error {
envelope := pssmsg.Payload
psstopic := Topic(envelope.Topic)
if pssmsg.isRaw() {
if !p.allowRaw {
return errors.New("raw message support disabled")
}
if raw {
payload = pssmsg.Payload.Data
} else {
if pssmsg.isSym() {
@@ -422,19 +467,27 @@ func (p *Pss) process(pssmsg *PssMsg) error {
return err
}
}
p.executeHandlers(psstopic, payload, from, asymmetric, keyid)
p.executeHandlers(psstopic, payload, from, raw, prox, asymmetric, keyid)
return nil
}
func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, asymmetric bool, keyid string) {
func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, raw bool, prox bool, asymmetric bool, keyid string) {
handlers := p.getHandlers(topic)
peer := p2p.NewPeer(enode.ID{}, fmt.Sprintf("%x", from), []p2p.Cap{})
for f := range handlers {
err := (*f)(payload, peer, asymmetric, keyid)
for h := range handlers {
if !h.caps.raw && raw {
log.Warn("norawhandler")
continue
}
if !h.caps.prox && prox {
log.Warn("noproxhandler")
continue
}
err := (h.f)(payload, peer, asymmetric, keyid)
if err != nil {
log.Warn("Pss handler %p failed: %v", f, err)
log.Warn("Pss handler failed", "err", err)
}
}
}
@@ -445,9 +498,23 @@ func (p *Pss) isSelfRecipient(msg *PssMsg) bool {
}
// test match of leftmost bytes in given message to node's Kademlia address
func (p *Pss) isSelfPossibleRecipient(msg *PssMsg) bool {
func (p *Pss) isSelfPossibleRecipient(msg *PssMsg, prox bool) bool {
local := p.Kademlia.BaseAddr()
return bytes.Equal(msg.To, local[:len(msg.To)])
// if a partial address matches we are possible recipient regardless of prox
// if not and prox is not set, we are surely not
if bytes.Equal(msg.To, local[:len(msg.To)]) {
return true
} else if !prox {
return false
}
depth := p.Kademlia.NeighbourhoodDepth()
po, _ := p.Kademlia.Pof(p.Kademlia.BaseAddr(), msg.To, 0)
log.Trace("selfpossible", "po", po, "depth", depth)
return depth <= po
}
/////////////////////////////////////////////////////////////////////
@@ -684,9 +751,6 @@ func (p *Pss) enqueue(msg *PssMsg) error {
//
// Will fail if raw messages are disallowed
func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
if !p.allowRaw {
return errors.New("Raw messages not enabled")
}
pssMsgParams := &msgParams{
raw: true,
}
@@ -699,7 +763,17 @@ func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
pssMsg.Payload = payload
p.addFwdCache(pssMsg)
return p.enqueue(pssMsg)
err := p.enqueue(pssMsg)
if err != nil {
return err
}
// if we have a proxhandler on this topic
// also deliver message to ourselves
if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
return p.process(pssMsg, true, true)
}
return nil
}
// Send a message using symmetric encryption
@@ -800,7 +874,16 @@ func (p *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key []by
pssMsg.To = to
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
pssMsg.Payload = envelope
return p.enqueue(pssMsg)
err = p.enqueue(pssMsg)
if err != nil {
return err
}
if _, ok := p.topicHandlerCaps[topic]; ok {
if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
return p.process(pssMsg, true, true)
}
}
return nil
}
// Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct
@@ -895,6 +978,10 @@ func (p *Pss) cleanFwdCache() {
}
}
func label(b []byte) string {
return fmt.Sprintf("%04x", b[:2])
}
// add a message to the cache
func (p *Pss) addFwdCache(msg *PssMsg) error {
metrics.GetOrRegisterCounter("pss.addfwdcache", nil).Inc(1)
@@ -934,10 +1021,14 @@ func (p *Pss) checkFwdCache(msg *PssMsg) bool {
// Digest of message
func (p *Pss) digest(msg *PssMsg) pssDigest {
hasher := p.hashPool.Get().(storage.SwarmHash)
return p.digestBytes(msg.serialize())
}
func (p *Pss) digestBytes(msg []byte) pssDigest {
hasher := p.hashPool.Get().(hash.Hash)
defer p.hashPool.Put(hasher)
hasher.Reset()
hasher.Write(msg.serialize())
hasher.Write(msg)
digest := pssDigest{}
key := hasher.Sum(nil)
copy(digest[:], key[:digestLength])