swarm/pss: Message handler refactor (#18169)
This commit is contained in:
171
swarm/pss/pss.go
171
swarm/pss/pss.go
@@ -23,11 +23,13 @@ import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/sha3"
|
||||
"github.com/ethereum/go-ethereum/metrics"
|
||||
"github.com/ethereum/go-ethereum/p2p"
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
@@ -136,10 +138,10 @@ type Pss struct {
|
||||
symKeyDecryptCacheCapacity int // max amount of symkeys to keep.
|
||||
|
||||
// message handling
|
||||
handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle()
|
||||
handlersMu sync.RWMutex
|
||||
allowRaw bool
|
||||
hashPool sync.Pool
|
||||
handlers map[Topic]map[*handler]bool // topic and version based pss payload handlers. See pss.Handle()
|
||||
handlersMu sync.RWMutex
|
||||
hashPool sync.Pool
|
||||
topicHandlerCaps map[Topic]*handlerCaps // caches capabilities of each topic's handlers (see handlerCap* consts in types.go)
|
||||
|
||||
// process
|
||||
quitC chan struct{}
|
||||
@@ -180,11 +182,12 @@ func NewPss(k *network.Kademlia, params *PssParams) (*Pss, error) {
|
||||
symKeyDecryptCache: make([]*string, params.SymKeyCacheCapacity),
|
||||
symKeyDecryptCacheCapacity: params.SymKeyCacheCapacity,
|
||||
|
||||
handlers: make(map[Topic]map[*Handler]bool),
|
||||
allowRaw: params.AllowRaw,
|
||||
handlers: make(map[Topic]map[*handler]bool),
|
||||
topicHandlerCaps: make(map[Topic]*handlerCaps),
|
||||
|
||||
hashPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return storage.MakeHashFunc(storage.DefaultHash)()
|
||||
return sha3.NewKeccak256()
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -313,30 +316,54 @@ func (p *Pss) PublicKey() *ecdsa.PublicKey {
|
||||
//
|
||||
// Returns a deregister function which needs to be called to
|
||||
// deregister the handler,
|
||||
func (p *Pss) Register(topic *Topic, handler Handler) func() {
|
||||
func (p *Pss) Register(topic *Topic, hndlr *handler) func() {
|
||||
p.handlersMu.Lock()
|
||||
defer p.handlersMu.Unlock()
|
||||
handlers := p.handlers[*topic]
|
||||
if handlers == nil {
|
||||
handlers = make(map[*Handler]bool)
|
||||
handlers = make(map[*handler]bool)
|
||||
p.handlers[*topic] = handlers
|
||||
log.Debug("registered handler", "caps", hndlr.caps)
|
||||
}
|
||||
handlers[&handler] = true
|
||||
return func() { p.deregister(topic, &handler) }
|
||||
if hndlr.caps == nil {
|
||||
hndlr.caps = &handlerCaps{}
|
||||
}
|
||||
handlers[hndlr] = true
|
||||
if _, ok := p.topicHandlerCaps[*topic]; !ok {
|
||||
p.topicHandlerCaps[*topic] = &handlerCaps{}
|
||||
}
|
||||
if hndlr.caps.raw {
|
||||
p.topicHandlerCaps[*topic].raw = true
|
||||
}
|
||||
if hndlr.caps.prox {
|
||||
p.topicHandlerCaps[*topic].prox = true
|
||||
}
|
||||
return func() { p.deregister(topic, hndlr) }
|
||||
}
|
||||
func (p *Pss) deregister(topic *Topic, h *Handler) {
|
||||
func (p *Pss) deregister(topic *Topic, hndlr *handler) {
|
||||
p.handlersMu.Lock()
|
||||
defer p.handlersMu.Unlock()
|
||||
handlers := p.handlers[*topic]
|
||||
if len(handlers) == 1 {
|
||||
if len(handlers) > 1 {
|
||||
delete(p.handlers, *topic)
|
||||
// topic caps might have changed now that a handler is gone
|
||||
caps := &handlerCaps{}
|
||||
for h := range handlers {
|
||||
if h.caps.raw {
|
||||
caps.raw = true
|
||||
}
|
||||
if h.caps.prox {
|
||||
caps.prox = true
|
||||
}
|
||||
}
|
||||
p.topicHandlerCaps[*topic] = caps
|
||||
return
|
||||
}
|
||||
delete(handlers, h)
|
||||
delete(handlers, hndlr)
|
||||
}
|
||||
|
||||
// get all registered handlers for respective topics
|
||||
func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
|
||||
func (p *Pss) getHandlers(topic Topic) map[*handler]bool {
|
||||
p.handlersMu.RLock()
|
||||
defer p.handlersMu.RUnlock()
|
||||
return p.handlers[topic]
|
||||
@@ -348,12 +375,11 @@ func (p *Pss) getHandlers(topic Topic) map[*Handler]bool {
|
||||
// Only passes error to pss protocol handler if payload is not valid pssmsg
|
||||
func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
|
||||
metrics.GetOrRegisterCounter("pss.handlepssmsg", nil).Inc(1)
|
||||
|
||||
pssmsg, ok := msg.(*PssMsg)
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid message type. Expected *PssMsg, got %T ", msg)
|
||||
}
|
||||
log.Trace("handler", "self", label(p.Kademlia.BaseAddr()), "topic", label(pssmsg.Payload.Topic[:]))
|
||||
if int64(pssmsg.Expire) < time.Now().Unix() {
|
||||
metrics.GetOrRegisterCounter("pss.expire", nil).Inc(1)
|
||||
log.Warn("pss filtered expired message", "from", common.ToHex(p.Kademlia.BaseAddr()), "to", common.ToHex(pssmsg.To))
|
||||
@@ -365,13 +391,34 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
|
||||
}
|
||||
p.addFwdCache(pssmsg)
|
||||
|
||||
if !p.isSelfPossibleRecipient(pssmsg) {
|
||||
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()))
|
||||
psstopic := Topic(pssmsg.Payload.Topic)
|
||||
|
||||
// raw is simplest handler contingency to check, so check that first
|
||||
var isRaw bool
|
||||
if pssmsg.isRaw() {
|
||||
if !p.topicHandlerCaps[psstopic].raw {
|
||||
log.Debug("No handler for raw message", "topic", psstopic)
|
||||
return nil
|
||||
}
|
||||
isRaw = true
|
||||
}
|
||||
|
||||
// check if we can be recipient:
|
||||
// - no prox handler on message and partial address matches
|
||||
// - prox handler on message and we are in prox regardless of partial address match
|
||||
// store this result so we don't calculate again on every handler
|
||||
var isProx bool
|
||||
if _, ok := p.topicHandlerCaps[psstopic]; ok {
|
||||
isProx = p.topicHandlerCaps[psstopic].prox
|
||||
}
|
||||
isRecipient := p.isSelfPossibleRecipient(pssmsg, isProx)
|
||||
if !isRecipient {
|
||||
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(p.BaseAddr()), "prox", isProx)
|
||||
return p.enqueue(pssmsg)
|
||||
}
|
||||
|
||||
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()))
|
||||
if err := p.process(pssmsg); err != nil {
|
||||
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(p.BaseAddr()), "prox", isProx, "raw", isRaw, "topic", label(pssmsg.Payload.Topic[:]))
|
||||
if err := p.process(pssmsg, isRaw, isProx); err != nil {
|
||||
qerr := p.enqueue(pssmsg)
|
||||
if qerr != nil {
|
||||
return fmt.Errorf("process fail: processerr %v, queueerr: %v", err, qerr)
|
||||
@@ -384,7 +431,7 @@ func (p *Pss) handlePssMsg(ctx context.Context, msg interface{}) error {
|
||||
// Entry point to processing a message for which the current node can be the intended recipient.
|
||||
// Attempts symmetric and asymmetric decryption with stored keys.
|
||||
// Dispatches message to all handlers matching the message topic
|
||||
func (p *Pss) process(pssmsg *PssMsg) error {
|
||||
func (p *Pss) process(pssmsg *PssMsg, raw bool, prox bool) error {
|
||||
metrics.GetOrRegisterCounter("pss.process", nil).Inc(1)
|
||||
|
||||
var err error
|
||||
@@ -397,10 +444,8 @@ func (p *Pss) process(pssmsg *PssMsg) error {
|
||||
|
||||
envelope := pssmsg.Payload
|
||||
psstopic := Topic(envelope.Topic)
|
||||
if pssmsg.isRaw() {
|
||||
if !p.allowRaw {
|
||||
return errors.New("raw message support disabled")
|
||||
}
|
||||
|
||||
if raw {
|
||||
payload = pssmsg.Payload.Data
|
||||
} else {
|
||||
if pssmsg.isSym() {
|
||||
@@ -422,19 +467,27 @@ func (p *Pss) process(pssmsg *PssMsg) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
p.executeHandlers(psstopic, payload, from, asymmetric, keyid)
|
||||
p.executeHandlers(psstopic, payload, from, raw, prox, asymmetric, keyid)
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, asymmetric bool, keyid string) {
|
||||
func (p *Pss) executeHandlers(topic Topic, payload []byte, from *PssAddress, raw bool, prox bool, asymmetric bool, keyid string) {
|
||||
handlers := p.getHandlers(topic)
|
||||
peer := p2p.NewPeer(enode.ID{}, fmt.Sprintf("%x", from), []p2p.Cap{})
|
||||
for f := range handlers {
|
||||
err := (*f)(payload, peer, asymmetric, keyid)
|
||||
for h := range handlers {
|
||||
if !h.caps.raw && raw {
|
||||
log.Warn("norawhandler")
|
||||
continue
|
||||
}
|
||||
if !h.caps.prox && prox {
|
||||
log.Warn("noproxhandler")
|
||||
continue
|
||||
}
|
||||
err := (h.f)(payload, peer, asymmetric, keyid)
|
||||
if err != nil {
|
||||
log.Warn("Pss handler %p failed: %v", f, err)
|
||||
log.Warn("Pss handler failed", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -445,9 +498,23 @@ func (p *Pss) isSelfRecipient(msg *PssMsg) bool {
|
||||
}
|
||||
|
||||
// test match of leftmost bytes in given message to node's Kademlia address
|
||||
func (p *Pss) isSelfPossibleRecipient(msg *PssMsg) bool {
|
||||
func (p *Pss) isSelfPossibleRecipient(msg *PssMsg, prox bool) bool {
|
||||
local := p.Kademlia.BaseAddr()
|
||||
return bytes.Equal(msg.To, local[:len(msg.To)])
|
||||
|
||||
// if a partial address matches we are possible recipient regardless of prox
|
||||
// if not and prox is not set, we are surely not
|
||||
if bytes.Equal(msg.To, local[:len(msg.To)]) {
|
||||
|
||||
return true
|
||||
} else if !prox {
|
||||
return false
|
||||
}
|
||||
|
||||
depth := p.Kademlia.NeighbourhoodDepth()
|
||||
po, _ := p.Kademlia.Pof(p.Kademlia.BaseAddr(), msg.To, 0)
|
||||
log.Trace("selfpossible", "po", po, "depth", depth)
|
||||
|
||||
return depth <= po
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
@@ -684,9 +751,6 @@ func (p *Pss) enqueue(msg *PssMsg) error {
|
||||
//
|
||||
// Will fail if raw messages are disallowed
|
||||
func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
|
||||
if !p.allowRaw {
|
||||
return errors.New("Raw messages not enabled")
|
||||
}
|
||||
pssMsgParams := &msgParams{
|
||||
raw: true,
|
||||
}
|
||||
@@ -699,7 +763,17 @@ func (p *Pss) SendRaw(address PssAddress, topic Topic, msg []byte) error {
|
||||
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
|
||||
pssMsg.Payload = payload
|
||||
p.addFwdCache(pssMsg)
|
||||
return p.enqueue(pssMsg)
|
||||
err := p.enqueue(pssMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if we have a proxhandler on this topic
|
||||
// also deliver message to ourselves
|
||||
if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
|
||||
return p.process(pssMsg, true, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send a message using symmetric encryption
|
||||
@@ -800,7 +874,16 @@ func (p *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key []by
|
||||
pssMsg.To = to
|
||||
pssMsg.Expire = uint32(time.Now().Add(p.msgTTL).Unix())
|
||||
pssMsg.Payload = envelope
|
||||
return p.enqueue(pssMsg)
|
||||
err = p.enqueue(pssMsg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, ok := p.topicHandlerCaps[topic]; ok {
|
||||
if p.isSelfPossibleRecipient(pssMsg, true) && p.topicHandlerCaps[topic].prox {
|
||||
return p.process(pssMsg, true, true)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct
|
||||
@@ -895,6 +978,10 @@ func (p *Pss) cleanFwdCache() {
|
||||
}
|
||||
}
|
||||
|
||||
func label(b []byte) string {
|
||||
return fmt.Sprintf("%04x", b[:2])
|
||||
}
|
||||
|
||||
// add a message to the cache
|
||||
func (p *Pss) addFwdCache(msg *PssMsg) error {
|
||||
metrics.GetOrRegisterCounter("pss.addfwdcache", nil).Inc(1)
|
||||
@@ -934,10 +1021,14 @@ func (p *Pss) checkFwdCache(msg *PssMsg) bool {
|
||||
|
||||
// Digest of message
|
||||
func (p *Pss) digest(msg *PssMsg) pssDigest {
|
||||
hasher := p.hashPool.Get().(storage.SwarmHash)
|
||||
return p.digestBytes(msg.serialize())
|
||||
}
|
||||
|
||||
func (p *Pss) digestBytes(msg []byte) pssDigest {
|
||||
hasher := p.hashPool.Get().(hash.Hash)
|
||||
defer p.hashPool.Put(hasher)
|
||||
hasher.Reset()
|
||||
hasher.Write(msg.serialize())
|
||||
hasher.Write(msg)
|
||||
digest := pssDigest{}
|
||||
key := hasher.Sum(nil)
|
||||
copy(digest[:], key[:digestLength])
|
||||
|
Reference in New Issue
Block a user