les, les/lespay: implement new server pool (#20758)
This PR reimplements the light client server pool. It is also a first step toward moving certain logic into a new lespay package, which will contain the implementation of the lespay token sale functions, the token buying and selling logic, and other components related to peer selection/prioritization and service quality evaluation. In the long term this package will be reusable for incentivizing future protocols. Since the LES peer logic is now based on enode.Iterator, it can use DNS-based fallback discovery to find servers. The function of the new components is described in this document: https://gist.github.com/zsfelfoldi/3c7ace895234b7b345ab4f71dab102d4
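As background for the enode.Iterator-based design (not part of the diff itself): the p2p dialer consumes any source that implements the iterator contract. A minimal sketch, using the real enode.IterNodes and enode.ReadNodes helpers; the static node list is a stand-in for the pool's dialIterator or a DNS iterator.

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/enode"
)

func main() {
	var static []*enode.Node        // would come from DNS discovery or the server pool
	it := enode.IterNodes(static)   // wraps a slice as an enode.Iterator
	nodes := enode.ReadNodes(it, 5) // reads up to 5 unique dial candidates
	fmt.Println("candidates:", len(nodes))
}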
les/client.go

@@ -51,16 +51,17 @@ import (
 type LightEthereum struct {
 	lesCommons

 	peers          *serverPeerSet
 	reqDist        *requestDistributor
 	retriever      *retrieveManager
 	odr            *LesOdr
 	relay          *lesTxRelay
 	handler        *clientHandler
 	txPool         *light.TxPool
 	blockchain     *light.LightChain
 	serverPool     *serverPool
 	valueTracker   *lpc.ValueTracker
+	dialCandidates enode.Iterator

 	bloomRequests chan chan *bloombits.Retrieval // Channel receiving bloom data retrieval requests
 	bloomIndexer  *core.ChainIndexer             // Bloom indexer operating during block imports
@@ -104,11 +105,19 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 		engine:        eth.CreateConsensusEngine(ctx, chainConfig, &config.Ethash, nil, false, chainDb),
 		bloomRequests: make(chan chan *bloombits.Retrieval),
 		bloomIndexer:  eth.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
-		serverPool:    newServerPool(chainDb, config.UltraLightServers),
 		valueTracker:  lpc.NewValueTracker(lespayDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)),
 	}
 	peers.subscribe((*vtSubscription)(leth.valueTracker))
-	leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool)
+
+	dnsdisc, err := leth.setupDiscovery(&ctx.Config.P2P)
+	if err != nil {
+		return nil, err
+	}
+	leth.serverPool = newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, dnsdisc, time.Second, nil, &mclock.System{}, config.UltraLightServers)
+	peers.subscribe(leth.serverPool)
+	leth.dialCandidates = leth.serverPool.dialIterator
+
+	leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout)
 	leth.relay = newLesTxRelay(peers, leth.retriever)

 	leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.retriever)
@@ -140,11 +149,6 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 	leth.chtIndexer.Start(leth.blockchain)
 	leth.bloomIndexer.Start(leth.blockchain)
-
-	leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
-	if leth.handler.ulc != nil {
-		log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
-		leth.blockchain.DisableCheckFreq()
-	}
 	// Rewind the chain in case of an incompatible config upgrade.
 	if compat, ok := genesisErr.(*params.ConfigCompatError); ok {
 		log.Warn("Rewinding chain to upgrade configuration", "err", compat)
@@ -159,6 +163,11 @@ func New(ctx *node.ServiceContext, config *eth.Config) (*LightEthereum, error) {
 	}
 	leth.ApiBackend.gpo = gasprice.NewOracle(leth.ApiBackend, gpoParams)
+
+	leth.handler = newClientHandler(config.UltraLightServers, config.UltraLightFraction, checkpoint, leth)
+	if leth.handler.ulc != nil {
+		log.Warn("Ultra light client is enabled", "trustedNodes", len(leth.handler.ulc.keys), "minTrustedFraction", leth.handler.ulc.fraction)
+		leth.blockchain.DisableCheckFreq()
+	}
 	return leth, nil
 }
@@ -260,7 +269,7 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
 			return p.Info()
 		}
 		return nil
-	})
+	}, s.dialCandidates)
 }

 // Start implements node.Service, starting all internal goroutines needed by the
@@ -268,15 +277,12 @@ func (s *LightEthereum) Protocols() []p2p.Protocol {
 func (s *LightEthereum) Start(srvr *p2p.Server) error {
 	log.Warn("Light client mode is an experimental feature")

+	s.serverPool.start()
 	// Start bloom request workers.
 	s.wg.Add(bloomServiceThreads)
 	s.startBloomHandlers(params.BloomBitsBlocksClient)

 	s.netRPCService = ethapi.NewPublicNetAPI(srvr, s.config.NetworkId)
-
-	// clients are searching for the first advertised protocol in the list
-	protocolVersion := AdvertiseProtocolVersions[0]
-	s.serverPool.start(srvr, lesTopic(s.blockchain.Genesis().Hash(), protocolVersion))
 	return nil
 }
@@ -284,6 +290,8 @@ func (s *LightEthereum) Start(srvr *p2p.Server) error {
 // Ethereum protocol.
 func (s *LightEthereum) Stop() error {
 	close(s.closeCh)
+	s.serverPool.stop()
+	s.valueTracker.Stop()
 	s.peers.close()
 	s.reqDist.close()
 	s.odr.Stop()

@@ -295,8 +303,6 @@ func (s *LightEthereum) Stop() error {
 	s.txPool.Stop()
 	s.engine.Close()
 	s.eventMux.Stop()
-	s.serverPool.stop()
-	s.valueTracker.Stop()
 	s.chainDb.Close()
 	s.wg.Wait()
 	log.Info("Light ethereum stopped")
les/client_handler.go

@@ -64,7 +64,7 @@ func newClientHandler(ulcServers []string, ulcFraction int, checkpoint *params.T
 	if checkpoint != nil {
 		height = (checkpoint.SectionIndex+1)*params.CHTFrequency - 1
 	}
-	handler.fetcher = newLightFetcher(handler)
+	handler.fetcher = newLightFetcher(handler, backend.serverPool.getTimeout)
 	handler.downloader = downloader.New(height, backend.chainDb, nil, backend.eventMux, nil, backend.blockchain, handler.removePeer)
 	handler.backend.peers.subscribe((*downloaderPeerNotify)(handler))
 	return handler
@@ -85,14 +85,9 @@ func (h *clientHandler) runPeer(version uint, p *p2p.Peer, rw p2p.MsgReadWriter)
 	}
 	peer := newServerPeer(int(version), h.backend.config.NetworkId, trusted, p, newMeteredMsgWriter(rw, int(version)))
 	defer peer.close()
-	peer.poolEntry = h.backend.serverPool.connect(peer, peer.Node())
-	if peer.poolEntry == nil {
-		return p2p.DiscRequested
-	}
 	h.wg.Add(1)
 	defer h.wg.Done()
 	err := h.handle(peer)
-	h.backend.serverPool.disconnect(peer.poolEntry)
 	return err
 }
@@ -129,10 +124,6 @@ func (h *clientHandler) handle(p *serverPeer) error {

 	h.fetcher.announce(p, &announceData{Hash: p.headInfo.Hash, Number: p.headInfo.Number, Td: p.headInfo.Td})

-	// pool entry can be nil during the unit test.
-	if p.poolEntry != nil {
-		h.backend.serverPool.registered(p.poolEntry)
-	}
 	// Mark the peer starts to be served.
 	atomic.StoreUint32(&p.serving, 1)
 	defer atomic.StoreUint32(&p.serving, 0)
les/commons.go

@@ -81,7 +81,7 @@ type NodeInfo struct {
 }

 // makeProtocols creates protocol descriptors for the given LES versions.
-func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}) []p2p.Protocol {
+func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p *p2p.Peer, rw p2p.MsgReadWriter) error, peerInfo func(id enode.ID) interface{}, dialCandidates enode.Iterator) []p2p.Protocol {
 	protos := make([]p2p.Protocol, len(versions))
 	for i, version := range versions {
 		version := version

@@ -93,7 +93,8 @@ func (c *lesCommons) makeProtocols(versions []uint, runPeer func(version uint, p
 			Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
 				return runPeer(version, peer, rw)
 			},
-			PeerInfo: peerInfo,
+			PeerInfo:       peerInfo,
+			DialCandidates: dialCandidates,
 		}
 	}
 	return protos
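The DialCandidates field is how a p2p.Protocol advertises its dial source to the p2p dialer. A minimal sketch of the descriptor shape makeProtocols builds here; the name/version/length values are placeholders, not the exact LES advertisement:

package main

import (
	"github.com/ethereum/go-ethereum/p2p"
	"github.com/ethereum/go-ethereum/p2p/enode"
)

// newProtocol wires an iterator into a protocol descriptor; in this PR the
// LES server passes nil and the client passes the pool's dialIterator.
func newProtocol(dialCandidates enode.Iterator) p2p.Protocol {
	return p2p.Protocol{
		Name:    "les",
		Version: 3,
		Length:  24,
		Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
			return nil // a real handler would run the wire protocol here
		},
		DialCandidates: dialCandidates,
	}
}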
les/distributor.go

@@ -180,12 +180,11 @@ func (d *requestDistributor) loop() {
 type selectPeerItem struct {
 	peer   distPeer
 	req    *distReq
-	weight int64
+	weight uint64
 }

-// Weight implements wrsItem interface
-func (sp selectPeerItem) Weight() int64 {
-	return sp.weight
+func selectPeerWeight(i interface{}) uint64 {
+	return i.(selectPeerItem).weight
 }

 // nextRequest returns the next possible request from any peer, along with the

@@ -220,9 +219,9 @@ func (d *requestDistributor) nextRequest() (distPeer, *distReq, time.Duration) {
 			wait, bufRemain := peer.waitBefore(cost)
 			if wait == 0 {
 				if sel == nil {
-					sel = utils.NewWeightedRandomSelect()
+					sel = utils.NewWeightedRandomSelect(selectPeerWeight)
 				}
-				sel.Update(selectPeerItem{peer: peer, req: req, weight: int64(bufRemain*1000000) + 1})
+				sel.Update(selectPeerItem{peer: peer, req: req, weight: uint64(bufRemain*1000000) + 1})
 			} else {
 				if bestWait == 0 || wait < bestWait {
 					bestWait = wait
les/enr_entry.go

@@ -17,6 +17,9 @@
 package les

 import (
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/dnsdisc"
+	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/rlp"
 )

@@ -30,3 +33,12 @@ type lesEntry struct {
 func (e lesEntry) ENRKey() string {
 	return "les"
 }
+
+// setupDiscovery creates the node discovery source for the eth protocol.
+func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) {
+	if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 {
+		return nil, nil
+	}
+	client := dnsdisc.NewClient(dnsdisc.Config{})
+	return client.NewIterator(eth.config.DiscoveryURLs...)
+}
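For illustration, DNS discovery URLs use the enrtree scheme. A standalone sketch of what setupDiscovery does; the URL shown is believed to be go-ethereum's public mainnet list of that era, but treat it as an example value:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/p2p/dnsdisc"
)

func main() {
	client := dnsdisc.NewClient(dnsdisc.Config{})
	it, err := client.NewIterator("enrtree://AKA3AM6LPBYEUDMVNU3BSVQJ5AD45Y7YPOHJLEF6W26QOE4VTUDPE@all.mainnet.ethdisco.net")
	if err != nil {
		fmt.Println("bad URL:", err)
		return
	}
	defer it.Close()
	if it.Next() { // blocks while the ENR tree is resolved via DNS
		fmt.Println("first discovered node:", it.Node().URLv4())
	}
}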
les/fetcher.go

@@ -40,8 +40,9 @@ const (
 // ODR system to ensure that we only request data related to a certain block from peers who have already processed
 // and announced that block.
 type lightFetcher struct {
-	handler *clientHandler
-	chain   *light.LightChain
+	handler            *clientHandler
+	chain              *light.LightChain
+	softRequestTimeout func() time.Duration

 	lock sync.Mutex // lock protects access to the fetcher's internal state variables except sent requests
 	maxConfirmedTd *big.Int

@@ -109,18 +110,19 @@ type fetchResponse struct {
 }

 // newLightFetcher creates a new light fetcher
-func newLightFetcher(h *clientHandler) *lightFetcher {
+func newLightFetcher(h *clientHandler, softRequestTimeout func() time.Duration) *lightFetcher {
 	f := &lightFetcher{
-		handler:        h,
-		chain:          h.backend.blockchain,
-		peers:          make(map[*serverPeer]*fetcherPeerInfo),
-		deliverChn:     make(chan fetchResponse, 100),
-		requested:      make(map[uint64]fetchRequest),
-		timeoutChn:     make(chan uint64),
-		requestTrigger: make(chan struct{}, 1),
-		syncDone:       make(chan *serverPeer),
-		closeCh:        make(chan struct{}),
-		maxConfirmedTd: big.NewInt(0),
+		handler:            h,
+		chain:              h.backend.blockchain,
+		peers:              make(map[*serverPeer]*fetcherPeerInfo),
+		deliverChn:         make(chan fetchResponse, 100),
+		requested:          make(map[uint64]fetchRequest),
+		timeoutChn:         make(chan uint64),
+		requestTrigger:     make(chan struct{}, 1),
+		syncDone:           make(chan *serverPeer),
+		closeCh:            make(chan struct{}),
+		maxConfirmedTd:     big.NewInt(0),
+		softRequestTimeout: softRequestTimeout,
 	}
 	h.backend.peers.subscribe(f)

@@ -163,7 +165,7 @@ func (f *lightFetcher) syncLoop() {
 				f.lock.Unlock()
 			} else {
 				go func() {
-					time.Sleep(softRequestTimeout)
+					time.Sleep(f.softRequestTimeout())
 					f.reqMu.Lock()
 					req, ok := f.requested[reqID]
 					if ok {

@@ -187,7 +189,6 @@ func (f *lightFetcher) syncLoop() {
 			}
 			f.reqMu.Unlock()
 			if ok {
-				f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), true)
 				req.peer.Log().Debug("Fetching data timed out hard")
 				go f.handler.removePeer(req.peer.id)
 			}

@@ -201,9 +202,6 @@ func (f *lightFetcher) syncLoop() {
 				delete(f.requested, resp.reqID)
 			}
 			f.reqMu.Unlock()
-			if ok {
-				f.handler.backend.serverPool.adjustResponseTime(req.peer.poolEntry, time.Duration(mclock.Now()-req.sent), req.timeout)
-			}
 			f.lock.Lock()
 			if !ok || !(f.syncing || f.processResponse(req, resp)) {
 				resp.peer.Log().Debug("Failed processing response")

@@ -879,12 +877,10 @@ func (f *lightFetcher) checkUpdateStats(p *serverPeer, newEntry *updateStatsEntr
 		fp.firstUpdateStats = newEntry
 	}
 	for fp.firstUpdateStats != nil && fp.firstUpdateStats.time <= now-mclock.AbsTime(blockDelayTimeout) {
-		f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, blockDelayTimeout)
 		fp.firstUpdateStats = fp.firstUpdateStats.next
 	}
 	if fp.confirmedTd != nil {
 		for fp.firstUpdateStats != nil && fp.firstUpdateStats.td.Cmp(fp.confirmedTd) <= 0 {
-			f.handler.backend.serverPool.adjustBlockDelay(p.poolEntry, time.Duration(now-fp.firstUpdateStats.time))
 			fp.firstUpdateStats = fp.firstUpdateStats.next
 		}
 	}
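The fetcher (and, below, the retrieve manager) now obtains its soft request timeout from an injected closure instead of the removed package-level constant, which lets the server pool adapt the suggested timeout at runtime. A sketch of the pattern; constTimeout is a hypothetical helper, while in production the closure is serverPool.getTimeout:

// constTimeout returns a fixed-timeout source, as the tests in this PR do
// inline with func() time.Duration { return time.Millisecond * 500 }.
func constTimeout(d time.Duration) func() time.Duration {
	return func() time.Duration { return d }
}

// in a test:        rm := newRetrieveManager(peers, dist, constTimeout(500*time.Millisecond))
// in les/client.go: leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout)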
les/lespay/client/fillset.go (new file, 107 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
	"sync"

	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

// FillSet tries to read nodes from an input iterator and add them to a node set by
// setting the specified node state flag(s) until the size of the set reaches the target.
// Note that other mechanisms (like other FillSet instances reading from different inputs)
// can also set the same flag(s) and FillSet will always care about the total number of
// nodes having those flags.
type FillSet struct {
	lock          sync.Mutex
	cond          *sync.Cond
	ns            *nodestate.NodeStateMachine
	input         enode.Iterator
	closed        bool
	flags         nodestate.Flags
	count, target int
}

// NewFillSet creates a new FillSet
func NewFillSet(ns *nodestate.NodeStateMachine, input enode.Iterator, flags nodestate.Flags) *FillSet {
	fs := &FillSet{
		ns:    ns,
		input: input,
		flags: flags,
	}
	fs.cond = sync.NewCond(&fs.lock)

	ns.SubscribeState(flags, func(n *enode.Node, oldState, newState nodestate.Flags) {
		fs.lock.Lock()
		if oldState.Equals(flags) {
			fs.count--
		}
		if newState.Equals(flags) {
			fs.count++
		}
		if fs.target > fs.count {
			fs.cond.Signal()
		}
		fs.lock.Unlock()
	})

	go fs.readLoop()
	return fs
}

// readLoop keeps reading nodes from the input and setting the specified flags for them
// whenever the node set size is under the current target
func (fs *FillSet) readLoop() {
	for {
		fs.lock.Lock()
		for fs.target <= fs.count && !fs.closed {
			fs.cond.Wait()
		}

		fs.lock.Unlock()
		if !fs.input.Next() {
			return
		}
		fs.ns.SetState(fs.input.Node(), fs.flags, nodestate.Flags{}, 0)
	}
}

// SetTarget sets the current target for node set size. If the previous target was not
// reached and FillSet was still waiting for the next node from the input then the next
// incoming node will be added to the set regardless of the target. This ensures that
// all nodes coming from the input are eventually added to the set.
func (fs *FillSet) SetTarget(target int) {
	fs.lock.Lock()
	defer fs.lock.Unlock()

	fs.target = target
	if fs.target > fs.count {
		fs.cond.Signal()
	}
}

// Close shuts FillSet down and closes the input iterator
func (fs *FillSet) Close() {
	fs.lock.Lock()
	defer fs.lock.Unlock()

	fs.closed = true
	fs.input.Close()
	fs.cond.Signal()
}
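A minimal usage sketch for FillSet, assuming the import path introduced by this PR; the flag name and the nil input iterator are illustrative stand-ins for a real discovery source:

package main

import (
	"github.com/ethereum/go-ethereum/common/mclock"
	lpc "github.com/ethereum/go-ethereum/les/lespay/client"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
	setup := &nodestate.Setup{}
	sfDiscovered := setup.NewFlag("discovered")
	ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.System{}, setup)

	input := enode.IterNodes(nil) // stand-in for a DNS discovery iterator
	fs := lpc.NewFillSet(ns, input, sfDiscovered)
	ns.Start()
	fs.SetTarget(10) // keep ~10 nodes carrying the flag; refills when flags are lost
	// ... run ...
	fs.Close()
	ns.Stop()
}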
les/lespay/client/fillset_test.go (new file, 113 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
	"math/rand"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

type testIter struct {
	waitCh chan struct{}
	nodeCh chan *enode.Node
	node   *enode.Node
}

func (i *testIter) Next() bool {
	i.waitCh <- struct{}{}
	i.node = <-i.nodeCh
	return i.node != nil
}

func (i *testIter) Node() *enode.Node {
	return i.node
}

func (i *testIter) Close() {}

func (i *testIter) push() {
	var id enode.ID
	rand.Read(id[:])
	i.nodeCh <- enode.SignNull(new(enr.Record), id)
}

func (i *testIter) waiting(timeout time.Duration) bool {
	select {
	case <-i.waitCh:
		return true
	case <-time.After(timeout):
		return false
	}
}

func TestFillSet(t *testing.T) {
	ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
	iter := &testIter{
		waitCh: make(chan struct{}),
		nodeCh: make(chan *enode.Node),
	}
	fs := NewFillSet(ns, iter, sfTest1)
	ns.Start()

	expWaiting := func(i int, push bool) {
		for ; i > 0; i-- {
			if !iter.waiting(time.Second * 10) {
				t.Fatalf("FillSet not waiting for new nodes")
			}
			if push {
				iter.push()
			}
		}
	}

	expNotWaiting := func() {
		if iter.waiting(time.Millisecond * 100) {
			t.Fatalf("FillSet unexpectedly waiting for new nodes")
		}
	}

	expNotWaiting()
	fs.SetTarget(3)
	expWaiting(3, true)
	expNotWaiting()
	fs.SetTarget(100)
	expWaiting(2, true)
	expWaiting(1, false)
	// lower the target before the previous one has been filled up
	fs.SetTarget(0)
	iter.push()
	expNotWaiting()
	fs.SetTarget(10)
	expWaiting(4, true)
	expNotWaiting()
	// remove all previously set flags
	ns.ForEach(sfTest1, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
		ns.SetState(node, nodestate.Flags{}, sfTest1, 0)
	})
	// now expect FillSet to fill the set up again with 10 new nodes
	expWaiting(10, true)
	expNotWaiting()

	fs.Close()
	ns.Stop()
}
les/lespay/client/queueiterator.go (new file, 123 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
	"sync"

	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

// QueueIterator returns nodes from the specified selectable set in the same order as
// they entered the set.
type QueueIterator struct {
	lock sync.Mutex
	cond *sync.Cond

	ns           *nodestate.NodeStateMachine
	queue        []*enode.Node
	nextNode     *enode.Node
	waitCallback func(bool)
	fifo, closed bool
}

// NewQueueIterator creates a new QueueIterator. Nodes are selectable if they have all the required
// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
// disables further selectability until it is removed or times out.
func NewQueueIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, fifo bool, waitCallback func(bool)) *QueueIterator {
	qi := &QueueIterator{
		ns:           ns,
		fifo:         fifo,
		waitCallback: waitCallback,
	}
	qi.cond = sync.NewCond(&qi.lock)

	ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
		oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
		newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
		if newMatch == oldMatch {
			return
		}

		qi.lock.Lock()
		defer qi.lock.Unlock()

		if newMatch {
			qi.queue = append(qi.queue, n)
		} else {
			id := n.ID()
			for i, qn := range qi.queue {
				if qn.ID() == id {
					copy(qi.queue[i:len(qi.queue)-1], qi.queue[i+1:])
					qi.queue = qi.queue[:len(qi.queue)-1]
					break
				}
			}
		}
		qi.cond.Signal()
	})
	return qi
}

// Next moves to the next selectable node.
func (qi *QueueIterator) Next() bool {
	qi.lock.Lock()
	if !qi.closed && len(qi.queue) == 0 {
		if qi.waitCallback != nil {
			qi.waitCallback(true)
		}
		for !qi.closed && len(qi.queue) == 0 {
			qi.cond.Wait()
		}
		if qi.waitCallback != nil {
			qi.waitCallback(false)
		}
	}
	if qi.closed {
		qi.nextNode = nil
		qi.lock.Unlock()
		return false
	}
	// Move to the next node in queue.
	if qi.fifo {
		qi.nextNode = qi.queue[0]
		copy(qi.queue[:len(qi.queue)-1], qi.queue[1:])
		qi.queue = qi.queue[:len(qi.queue)-1]
	} else {
		qi.nextNode = qi.queue[len(qi.queue)-1]
		qi.queue = qi.queue[:len(qi.queue)-1]
	}
	qi.lock.Unlock()
	return true
}

// Close ends the iterator.
func (qi *QueueIterator) Close() {
	qi.lock.Lock()
	qi.closed = true
	qi.lock.Unlock()
	qi.cond.Signal()
}

// Node returns the current node.
func (qi *QueueIterator) Node() *enode.Node {
	qi.lock.Lock()
	defer qi.lock.Unlock()

	return qi.nextNode
}
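A minimal usage sketch, with hypothetical flag names mirroring the tests below: nodes become selectable once they carry sfCandidate, and setting sfDialed on selection keeps them from being returned again until that flag is cleared or times out.

package main

import (
	"fmt"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	lpc "github.com/ethereum/go-ethereum/les/lespay/client"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
	setup := &nodestate.Setup{}
	sfCandidate := setup.NewFlag("candidate")
	sfDialed := setup.NewFlag("dialed")
	ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.System{}, setup)
	qi := lpc.NewQueueIterator(ns, sfCandidate, sfDialed, true, nil) // true = FIFO order
	ns.Start()

	node := enode.SignNull(new(enr.Record), enode.ID{1})
	ns.SetState(node, sfCandidate, nodestate.Flags{}, 0) // enqueue the node
	if qi.Next() {
		n := qi.Node()
		// flag as dialed for 10s so it is not reselected immediately
		ns.SetState(n, sfDialed, nodestate.Flags{}, time.Second*10)
		fmt.Println("dial", n.ID())
	}
	qi.Close()
	ns.Stop()
}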
les/lespay/client/queueiterator_test.go (new file, 106 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

func testNodeID(i int) enode.ID {
	return enode.ID{42, byte(i % 256), byte(i / 256)}
}

func testNodeIndex(id enode.ID) int {
	if id[0] != 42 {
		return -1
	}
	return int(id[1]) + int(id[2])*256
}

func testNode(i int) *enode.Node {
	return enode.SignNull(new(enr.Record), testNodeID(i))
}

func TestQueueIteratorFIFO(t *testing.T) {
	testQueueIterator(t, true)
}

func TestQueueIteratorLIFO(t *testing.T) {
	testQueueIterator(t, false)
}

func testQueueIterator(t *testing.T, fifo bool) {
	ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
	qi := NewQueueIterator(ns, sfTest2, sfTest3.Or(sfTest4), fifo, nil)
	ns.Start()
	for i := 1; i <= iterTestNodeCount; i++ {
		ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
	}
	next := func() int {
		ch := make(chan struct{})
		go func() {
			qi.Next()
			close(ch)
		}()
		select {
		case <-ch:
		case <-time.After(time.Second * 5):
			t.Fatalf("Iterator.Next() timeout")
		}
		node := qi.Node()
		ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
		return testNodeIndex(node.ID())
	}
	exp := func(i int) {
		n := next()
		if n != i {
			t.Errorf("Wrong item returned by iterator (expected %d, got %d)", i, n)
		}
	}
	explist := func(list []int) {
		for i := range list {
			if fifo {
				exp(list[i])
			} else {
				exp(list[len(list)-1-i])
			}
		}
	}

	ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
	explist([]int{1, 2, 3})
	ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(5), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(5), sfTest3, nodestate.Flags{}, 0)
	explist([]int{4, 6})
	ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
	ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
	ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
	ns.SetState(testNode(2), sfTest3, nodestate.Flags{}, 0)
	ns.SetState(testNode(2), nodestate.Flags{}, sfTest3, 0)
	explist([]int{1, 3, 2})
	ns.Stop()
}
les/lespay/client/valuetracker.go

@@ -213,6 +213,15 @@ func (vt *ValueTracker) StatsExpirer() *utils.Expirer {
 	return &vt.statsExpirer
 }

+// StatsExpFactor returns the current expiration factor so that other values can be expired
+// with the same rate as the service value statistics.
+func (vt *ValueTracker) StatsExpFactor() utils.ExpirationFactor {
+	vt.statsExpLock.RLock()
+	defer vt.statsExpLock.RUnlock()
+
+	return vt.statsExpFactor
+}
+
 // loadFromDb loads the value tracker's state from the database and converts saved
 // request basket index mapping if it does not match the specified index to name mapping.
 func (vt *ValueTracker) loadFromDb(mapping []string) error {

@@ -500,16 +509,3 @@ func (vt *ValueTracker) RequestStats() []RequestStatsItem {
 	}
 	return res
 }
-
-// TotalServiceValue returns the total service value provided by the given node (as
-// a function of the weights which are calculated from the request timeout value).
-func (vt *ValueTracker) TotalServiceValue(nv *NodeValueTracker, weights ResponseTimeWeights) float64 {
-	vt.statsExpLock.RLock()
-	expFactor := vt.statsExpFactor
-	vt.statsExpLock.RUnlock()
-
-	nv.lock.Lock()
-	defer nv.lock.Unlock()
-
-	return nv.rtStats.Value(weights, expFactor)
-}
les/lespay/client/wrsiterator.go (new file, 128 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
	"sync"

	"github.com/ethereum/go-ethereum/les/utils"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

// WrsIterator returns nodes from the specified selectable set with a weighted random
// selection. Selection weights are provided by a callback function.
type WrsIterator struct {
	lock sync.Mutex
	cond *sync.Cond

	ns       *nodestate.NodeStateMachine
	wrs      *utils.WeightedRandomSelect
	nextNode *enode.Node
	closed   bool
}

// NewWrsIterator creates a new WrsIterator. Nodes are selectable if they have all the required
// and none of the disabled flags set. When a node is selected the selectedFlag is set which also
// disables further selectability until it is removed or times out.
func NewWrsIterator(ns *nodestate.NodeStateMachine, requireFlags, disableFlags nodestate.Flags, weightField nodestate.Field) *WrsIterator {
	wfn := func(i interface{}) uint64 {
		n := ns.GetNode(i.(enode.ID))
		if n == nil {
			return 0
		}
		wt, _ := ns.GetField(n, weightField).(uint64)
		return wt
	}

	w := &WrsIterator{
		ns:  ns,
		wrs: utils.NewWeightedRandomSelect(wfn),
	}
	w.cond = sync.NewCond(&w.lock)

	ns.SubscribeField(weightField, func(n *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
		if state.HasAll(requireFlags) && state.HasNone(disableFlags) {
			w.lock.Lock()
			w.wrs.Update(n.ID())
			w.lock.Unlock()
			w.cond.Signal()
		}
	})

	ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState nodestate.Flags) {
		oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags)
		newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags)
		if newMatch == oldMatch {
			return
		}

		w.lock.Lock()
		if newMatch {
			w.wrs.Update(n.ID())
		} else {
			w.wrs.Remove(n.ID())
		}
		w.lock.Unlock()
		w.cond.Signal()
	})
	return w
}

// Next selects the next node.
func (w *WrsIterator) Next() bool {
	w.nextNode = w.chooseNode()
	return w.nextNode != nil
}

func (w *WrsIterator) chooseNode() *enode.Node {
	w.lock.Lock()
	defer w.lock.Unlock()

	for {
		for !w.closed && w.wrs.IsEmpty() {
			w.cond.Wait()
		}
		if w.closed {
			return nil
		}
		// Choose the next node at random. Even though w.wrs is guaranteed
		// non-empty here, Choose might return nil if all items have weight
		// zero.
		if c := w.wrs.Choose(); c != nil {
			id := c.(enode.ID)
			w.wrs.Remove(id)
			return w.ns.GetNode(id)
		}
	}

}

// Close ends the iterator.
func (w *WrsIterator) Close() {
	w.lock.Lock()
	w.closed = true
	w.lock.Unlock()
	w.cond.Signal()
}

// Node returns the current node.
func (w *WrsIterator) Node() *enode.Node {
	w.lock.Lock()
	defer w.lock.Unlock()
	return w.nextNode
}
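A minimal usage sketch, with hypothetical flag and field names mirroring the tests below: selectable nodes carry sfCandidate, their selection probability comes from the uint64 weight field, and sfDialed disables reselection.

package main

import (
	"fmt"
	"reflect"

	"github.com/ethereum/go-ethereum/common/mclock"
	lpc "github.com/ethereum/go-ethereum/les/lespay/client"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

func main() {
	setup := &nodestate.Setup{}
	sfCandidate := setup.NewFlag("candidate")
	sfDialed := setup.NewFlag("dialed")
	sfiWeight := setup.NewField("weight", reflect.TypeOf(uint64(0)))
	ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.System{}, setup)
	w := lpc.NewWrsIterator(ns, sfCandidate, sfDialed, sfiWeight)
	ns.Start()

	for i := byte(1); i <= 3; i++ {
		node := enode.SignNull(new(enr.Record), enode.ID{i})
		ns.SetState(node, sfCandidate, nodestate.Flags{}, 0)
		ns.SetField(node, sfiWeight, uint64(i)*100) // higher weight, higher selection chance
	}
	if w.Next() {
		fmt.Println("selected", w.Node().ID())
	}
	w.Close()
	ns.Stop()
}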
les/lespay/client/wrsiterator_test.go (new file, 103 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package client

import (
	"reflect"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/p2p/nodestate"
)

var (
	testSetup     = &nodestate.Setup{}
	sfTest1       = testSetup.NewFlag("test1")
	sfTest2       = testSetup.NewFlag("test2")
	sfTest3       = testSetup.NewFlag("test3")
	sfTest4       = testSetup.NewFlag("test4")
	sfiTestWeight = testSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
)

const iterTestNodeCount = 6

func TestWrsIterator(t *testing.T) {
	ns := nodestate.NewNodeStateMachine(nil, nil, &mclock.Simulated{}, testSetup)
	w := NewWrsIterator(ns, sfTest2, sfTest3.Or(sfTest4), sfiTestWeight)
	ns.Start()
	for i := 1; i <= iterTestNodeCount; i++ {
		ns.SetState(testNode(i), sfTest1, nodestate.Flags{}, 0)
		ns.SetField(testNode(i), sfiTestWeight, uint64(1))
	}
	next := func() int {
		ch := make(chan struct{})
		go func() {
			w.Next()
			close(ch)
		}()
		select {
		case <-ch:
		case <-time.After(time.Second * 5):
			t.Fatalf("Iterator.Next() timeout")
		}
		node := w.Node()
		ns.SetState(node, sfTest4, nodestate.Flags{}, 0)
		return testNodeIndex(node.ID())
	}
	set := make(map[int]bool)
	expset := func() {
		for len(set) > 0 {
			n := next()
			if !set[n] {
				t.Errorf("Item returned by iterator not in the expected set (got %d)", n)
			}
			delete(set, n)
		}
	}

	ns.SetState(testNode(1), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(2), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(3), sfTest2, nodestate.Flags{}, 0)
	set[1] = true
	set[2] = true
	set[3] = true
	expset()
	ns.SetState(testNode(4), sfTest2, nodestate.Flags{}, 0)
	ns.SetState(testNode(5), sfTest2.Or(sfTest3), nodestate.Flags{}, 0)
	ns.SetState(testNode(6), sfTest2, nodestate.Flags{}, 0)
	set[4] = true
	set[6] = true
	expset()
	ns.SetField(testNode(2), sfiTestWeight, uint64(0))
	ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
	ns.SetState(testNode(2), nodestate.Flags{}, sfTest4, 0)
	ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
	set[1] = true
	set[3] = true
	expset()
	ns.SetField(testNode(2), sfiTestWeight, uint64(1))
	ns.SetState(testNode(2), nodestate.Flags{}, sfTest2, 0)
	ns.SetState(testNode(1), nodestate.Flags{}, sfTest4, 0)
	ns.SetState(testNode(2), sfTest2, sfTest4, 0)
	ns.SetState(testNode(3), nodestate.Flags{}, sfTest4, 0)
	set[1] = true
	set[2] = true
	set[3] = true
	expset()
	ns.Stop()
}
les/metrics.go

@@ -107,6 +107,13 @@ var (

 	requestRTT       = metrics.NewRegisteredTimer("les/client/req/rtt", nil)
 	requestSendDelay = metrics.NewRegisteredTimer("les/client/req/sendDelay", nil)
+
+	serverSelectableGauge = metrics.NewRegisteredGauge("les/client/serverPool/selectable", nil)
+	serverDialedMeter     = metrics.NewRegisteredMeter("les/client/serverPool/dialed", nil)
+	serverConnectedGauge  = metrics.NewRegisteredGauge("les/client/serverPool/connected", nil)
+	sessionValueMeter     = metrics.NewRegisteredMeter("les/client/serverPool/sessionValue", nil)
+	totalValueGauge       = metrics.NewRegisteredGauge("les/client/serverPool/totalValue", nil)
+	suggestedTimeoutGauge = metrics.NewRegisteredGauge("les/client/serverPool/timeout", nil)
 )

 // meteredMsgReadWriter is a wrapper around a p2p.MsgReadWriter, capable of
les/peer.go

@@ -336,7 +336,6 @@ type serverPeer struct {
 	checkpointNumber uint64                   // The block height which the checkpoint is registered.
 	checkpoint       params.TrustedCheckpoint // The advertised checkpoint sent by server.

-	poolEntry *poolEntry              // Statistic for server peer.
 	fcServer *flowcontrol.ServerNode // Client side mirror token bucket.

 	vtLock       sync.Mutex
 	valueTracker *lpc.ValueTracker
les/protocol.go

@@ -130,7 +130,6 @@ func init() {
 		}
 		requestMapping[uint32(code)] = rm
 	}
-
 }

 type errCode int
les/retrieve.go

@@ -24,22 +24,20 @@ import (
 	"sync"
 	"time"

 	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/light"
 )

 var (
 	retryQueue         = time.Millisecond * 100
-	softRequestTimeout = time.Millisecond * 500
 	hardRequestTimeout = time.Second * 10
 )

 // retrieveManager is a layer on top of requestDistributor which takes care of
 // matching replies by request ID and handles timeouts and resends if necessary.
 type retrieveManager struct {
-	dist       *requestDistributor
-	peers      *serverPeerSet
-	serverPool peerSelector
+	dist               *requestDistributor
+	peers              *serverPeerSet
+	softRequestTimeout func() time.Duration

 	lock     sync.RWMutex
 	sentReqs map[uint64]*sentReq

@@ -48,11 +46,6 @@ type retrieveManager struct {
 // validatorFunc is a function that processes a reply message
 type validatorFunc func(distPeer, *Msg) error

-// peerSelector receives feedback info about response times and timeouts
-type peerSelector interface {
-	adjustResponseTime(*poolEntry, time.Duration, bool)
-}
-
 // sentReq represents a request sent and tracked by retrieveManager
 type sentReq struct {
 	rm *retrieveManager

@@ -99,12 +92,12 @@ const (
 )

 // newRetrieveManager creates the retrieve manager
-func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, serverPool peerSelector) *retrieveManager {
+func newRetrieveManager(peers *serverPeerSet, dist *requestDistributor, srto func() time.Duration) *retrieveManager {
 	return &retrieveManager{
-		peers:      peers,
-		dist:       dist,
-		serverPool: serverPool,
-		sentReqs:   make(map[uint64]*sentReq),
+		peers:              peers,
+		dist:               dist,
+		sentReqs:           make(map[uint64]*sentReq),
+		softRequestTimeout: srto,
 	}
 }

@@ -325,8 +318,7 @@ func (r *sentReq) tryRequest() {
 		return
 	}

 	reqSent := mclock.Now()
-	srto, hrto := false, false
+	hrto := false

 	r.lock.RLock()
 	s, ok := r.sentTo[p]

@@ -338,11 +330,7 @@ func (r *sentReq) tryRequest() {
 	defer func() {
 		// send feedback to server pool and remove peer if hard timeout happened
 		pp, ok := p.(*serverPeer)
-		if ok && r.rm.serverPool != nil {
-			respTime := time.Duration(mclock.Now() - reqSent)
-			r.rm.serverPool.adjustResponseTime(pp.poolEntry, respTime, srto)
-		}
-		if hrto {
+		if hrto && ok {
 			pp.Log().Debug("Request timed out hard")
 			if r.rm.peers != nil {
 				r.rm.peers.unregister(pp.id)

@@ -363,8 +351,7 @@ func (r *sentReq) tryRequest() {
 		}
 		r.eventsCh <- reqPeerEvent{event, p}
 		return
-	case <-time.After(softRequestTimeout):
-		srto = true
+	case <-time.After(r.rm.softRequestTimeout()):
 		r.eventsCh <- reqPeerEvent{rpSoftTimeout, p}
 	}
les/server.go

@@ -157,7 +157,7 @@ func (s *LesServer) Protocols() []p2p.Protocol {
 			return p.Info()
 		}
 		return nil
-	})
+	}, nil)
 	// Add "les" ENR entries.
 	for i := range ps {
 		ps[i].Attributes = []enr.Entry{&lesEntry{}}
les/serverpool.go (1263 lines changed; file diff suppressed because it is too large)
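Although the serverpool.go diff is suppressed, the pool's externally visible surface can be read off its call sites elsewhere in this PR. A signature sketch inferred from les/client.go and serverpool_test.go; the parameter names are guesses, not the actual declaration:

// newServerPool(db, dbKey, valueTracker, discovery, mixTimeout, query, clock, trustedURLs)
// e.g. newServerPool(lespayDb, []byte("serverpool:"), leth.valueTracker, dnsdisc,
//      time.Second, nil, &mclock.System{}, config.UltraLightServers)

The rest of the surface used by the diff: registerPeer/unregisterPeer (invoked through the peer set subscription), dialIterator (an enode.Iterator consumed by the p2p dialer), getTimeout (the suggested soft request timeout), and start/stop.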
les/serverpool_test.go (new file, 352 lines)
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package les

import (
	"math/rand"
	"sync/atomic"
	"testing"
	"time"

	"github.com/ethereum/go-ethereum/common/mclock"
	"github.com/ethereum/go-ethereum/ethdb"
	"github.com/ethereum/go-ethereum/ethdb/memorydb"
	lpc "github.com/ethereum/go-ethereum/les/lespay/client"
	"github.com/ethereum/go-ethereum/p2p"
	"github.com/ethereum/go-ethereum/p2p/enode"
	"github.com/ethereum/go-ethereum/p2p/enr"
)

const (
	spTestNodes  = 1000
	spTestTarget = 5
	spTestLength = 10000
	spMinTotal   = 40000
	spMaxTotal   = 50000
)

func testNodeID(i int) enode.ID {
	return enode.ID{42, byte(i % 256), byte(i / 256)}
}

func testNodeIndex(id enode.ID) int {
	if id[0] != 42 {
		return -1
	}
	return int(id[1]) + int(id[2])*256
}

type serverPoolTest struct {
	db                   ethdb.KeyValueStore
	clock                *mclock.Simulated
	quit                 chan struct{}
	preNeg, preNegFail   bool
	vt                   *lpc.ValueTracker
	sp                   *serverPool
	input                enode.Iterator
	testNodes            []spTestNode
	trusted              []string
	waitCount, waitEnded int32

	cycle, conn, servedConn  int
	serviceCycles, dialCount int
	disconnect               map[int][]int
}

type spTestNode struct {
	connectCycles, waitCycles int
	nextConnCycle, totalConn  int
	connected, service        bool
	peer                      *serverPeer
}

func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest {
	nodes := make([]*enode.Node, spTestNodes)
	for i := range nodes {
		nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i))
	}
	return &serverPoolTest{
		clock:      &mclock.Simulated{},
		db:         memorydb.New(),
		input:      enode.CycleNodes(nodes),
		testNodes:  make([]spTestNode, spTestNodes),
		preNeg:     preNeg,
		preNegFail: preNegFail,
	}
}

func (s *serverPoolTest) beginWait() {
	// ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state
	for atomic.AddInt32(&s.waitCount, 1) > preNegLimit {
		atomic.AddInt32(&s.waitCount, -1)
		s.clock.Run(time.Second)
	}
}

func (s *serverPoolTest) endWait() {
	atomic.AddInt32(&s.waitCount, -1)
	atomic.AddInt32(&s.waitEnded, 1)
}

func (s *serverPoolTest) addTrusted(i int) {
	s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String())
}

func (s *serverPoolTest) start() {
	var testQuery queryFunc
	if s.preNeg {
		testQuery = func(node *enode.Node) int {
			idx := testNodeIndex(node.ID())
			n := &s.testNodes[idx]
			canConnect := !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle
			if s.preNegFail {
				// simulate a scenario where UDP queries never work
				s.beginWait()
				s.clock.Sleep(time.Second * 5)
				s.endWait()
				return -1
			} else {
				switch idx % 3 {
				case 0:
					// pre-neg returns true only if connection is possible
					if canConnect {
						return 1
					} else {
						return 0
					}
				case 1:
					// pre-neg returns true but connection might still fail
					return 1
				case 2:
					// pre-neg returns true if connection is possible, otherwise timeout (node unresponsive)
					if canConnect {
						return 1
					} else {
						s.beginWait()
						s.clock.Sleep(time.Second * 5)
						s.endWait()
						return -1
					}
				}
				return -1
			}
		}
	}

	s.vt = lpc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000))
	s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, s.input, 0, testQuery, s.clock, s.trusted)
	s.sp.validSchemes = enode.ValidSchemesForTesting
	s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) }
	s.disconnect = make(map[int][]int)
	s.sp.start()
	s.quit = make(chan struct{})
	go func() {
		last := int32(-1)
		for {
			select {
			case <-time.After(time.Millisecond * 100):
				c := atomic.LoadInt32(&s.waitEnded)
				if c == last {
					// advance clock if test is stuck (might happen in rare cases)
					s.clock.Run(time.Second)
				}
				last = c
			case <-s.quit:
				return
			}
		}
	}()
}

func (s *serverPoolTest) stop() {
	close(s.quit)
	s.sp.stop()
	s.vt.Stop()
	for i := range s.testNodes {
		n := &s.testNodes[i]
		if n.connected {
			n.totalConn += s.cycle
		}
		n.connected = false
		n.peer = nil
		n.nextConnCycle = 0
	}
	s.conn, s.servedConn = 0, 0
}

func (s *serverPoolTest) run() {
	for count := spTestLength; count > 0; count-- {
		if dcList := s.disconnect[s.cycle]; dcList != nil {
			for _, idx := range dcList {
				n := &s.testNodes[idx]
				s.sp.unregisterPeer(n.peer)
				n.totalConn += s.cycle
				n.connected = false
				n.peer = nil
				s.conn--
				if n.service {
					s.servedConn--
				}
				n.nextConnCycle = s.cycle + n.waitCycles
			}
			delete(s.disconnect, s.cycle)
		}
		if s.conn < spTestTarget {
			s.dialCount++
			s.beginWait()
			s.sp.dialIterator.Next()
			s.endWait()
			dial := s.sp.dialIterator.Node()
			id := dial.ID()
			idx := testNodeIndex(id)
			n := &s.testNodes[idx]
			if !n.connected && n.connectCycles != 0 && s.cycle >= n.nextConnCycle {
				s.conn++
				if n.service {
					s.servedConn++
				}
				n.totalConn -= s.cycle
				n.connected = true
				dc := s.cycle + n.connectCycles
				s.disconnect[dc] = append(s.disconnect[dc], idx)
				n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}}
				s.sp.registerPeer(n.peer)
				if n.service {
					s.vt.Served(s.vt.GetNode(id), []lpc.ServedRequest{{ReqType: 0, Amount: 100}}, 0)
				}
			}
		}
		s.serviceCycles += s.servedConn
		s.clock.Run(time.Second)
		s.cycle++
	}
}

func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) {
	for ; count > 0; count-- {
		idx := rand.Intn(spTestNodes)
		for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected {
			idx = rand.Intn(spTestNodes)
		}
		res = append(res, idx)
		s.testNodes[idx] = spTestNode{
			connectCycles: conn,
			waitCycles:    wait,
			service:       service,
		}
		if trusted {
			s.addTrusted(idx)
		}
	}
	return
}

func (s *serverPoolTest) resetNodes() {
	for i, n := range s.testNodes {
		if n.connected {
			n.totalConn += s.cycle
			s.sp.unregisterPeer(n.peer)
		}
		s.testNodes[i] = spTestNode{totalConn: n.totalConn}
	}
	s.conn, s.servedConn = 0, 0
	s.disconnect = make(map[int][]int)
	s.trusted = nil
}

func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) {
	var sum int
	for _, idx := range nodes {
		n := &s.testNodes[idx]
		if n.connected {
			n.totalConn += s.cycle
		}
		sum += n.totalConn
		n.totalConn = 0
		if n.connected {
			n.totalConn -= s.cycle
		}
	}
	if sum < spMinTotal || sum > spMaxTotal {
		t.Errorf("Total connection amount %d outside expected range %d to %d", sum, spMinTotal, spMaxTotal)
	}
}

func TestServerPool(t *testing.T)               { testServerPool(t, false, false) }
func TestServerPoolWithPreNeg(t *testing.T)     { testServerPool(t, true, false) }
func TestServerPoolWithPreNegFail(t *testing.T) { testServerPool(t, true, true) }
func testServerPool(t *testing.T, preNeg, fail bool) {
	s := newServerPoolTest(preNeg, fail)
	nodes := s.setNodes(100, 200, 200, true, false)
	s.setNodes(100, 20, 20, false, false)
	s.start()
	s.run()
	s.stop()
	s.checkNodes(t, nodes)
}

func TestServerPoolChangedNodes(t *testing.T)           { testServerPoolChangedNodes(t, false) }
func TestServerPoolChangedNodesWithPreNeg(t *testing.T) { testServerPoolChangedNodes(t, true) }
func testServerPoolChangedNodes(t *testing.T, preNeg bool) {
	s := newServerPoolTest(preNeg, false)
	nodes := s.setNodes(100, 200, 200, true, false)
	s.setNodes(100, 20, 20, false, false)
	s.start()
	s.run()
	s.checkNodes(t, nodes)
	for i := 0; i < 3; i++ {
		s.resetNodes()
		nodes := s.setNodes(100, 200, 200, true, false)
		s.setNodes(100, 20, 20, false, false)
		s.run()
		s.checkNodes(t, nodes)
	}
	s.stop()
}

func TestServerPoolRestartNoDiscovery(t *testing.T) { testServerPoolRestartNoDiscovery(t, false) }
func TestServerPoolRestartNoDiscoveryWithPreNeg(t *testing.T) {
	testServerPoolRestartNoDiscovery(t, true)
}
func testServerPoolRestartNoDiscovery(t *testing.T, preNeg bool) {
	s := newServerPoolTest(preNeg, false)
	nodes := s.setNodes(100, 200, 200, true, false)
	s.setNodes(100, 20, 20, false, false)
	s.start()
	s.run()
	s.stop()
	s.checkNodes(t, nodes)
	s.input = nil
	s.start()
	s.run()
	s.stop()
	s.checkNodes(t, nodes)
}

func TestServerPoolTrustedNoDiscovery(t *testing.T) { testServerPoolTrustedNoDiscovery(t, false) }
func TestServerPoolTrustedNoDiscoveryWithPreNeg(t *testing.T) {
	testServerPoolTrustedNoDiscovery(t, true)
}
func testServerPoolTrustedNoDiscovery(t *testing.T, preNeg bool) {
	s := newServerPoolTest(preNeg, false)
	trusted := s.setNodes(200, 200, 200, true, true)
	s.input = nil
	s.start()
	s.run()
	s.stop()
	s.checkNodes(t, trusted)
}
les/test_helper.go

@@ -508,7 +508,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer
 		clock = &mclock.Simulated{}
 	}
 	dist := newRequestDistributor(speers, clock)
-	rm := newRetrieveManager(speers, dist, nil)
+	rm := newRetrieveManager(speers, dist, func() time.Duration { return time.Millisecond * 500 })
 	odr := NewLesOdr(cdb, light.TestClientIndexerConfig, rm)

 	sindexers := testIndexers(sdb, nil, light.TestServerIndexerConfig)
les/utils/expiredvalue.go

@@ -63,14 +63,7 @@ func ExpFactor(logOffset Fixed64) ExpirationFactor {
 // Value calculates the expired value based on a floating point base and integer
 // power-of-2 exponent. This function should be used by multi-value expired structures.
 func (e ExpirationFactor) Value(base float64, exp uint64) float64 {
-	res := base / e.Factor
-	if exp > e.Exp {
-		res *= float64(uint64(1) << (exp - e.Exp))
-	}
-	if exp < e.Exp {
-		res /= float64(uint64(1) << (e.Exp - exp))
-	}
-	return res
+	return base / e.Factor * math.Pow(2, float64(int64(exp-e.Exp)))
}

 // value calculates the value at the given moment.
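To make the replacement concrete: the result is base / Factor * 2^(exp - Exp), and because the uint64 subtraction is reinterpreted as int64, an exp smaller than e.Exp yields a negative exponent and math.Pow divides instead of multiplying. A small illustrative check with assumed values:

package main

import (
	"fmt"
	"math"
)

// value replicates ExpirationFactor.Value for illustration: the wrapped
// uint64 subtraction becomes a signed power-of-2 exponent.
func value(base, factor float64, exp, expBase uint64) float64 {
	return base / factor * math.Pow(2, float64(int64(exp-expBase)))
}

func main() {
	fmt.Println(value(3, 1.5, 12, 10)) // 3/1.5 * 2^2  = 8
	fmt.Println(value(3, 1.5, 9, 10))  // 3/1.5 * 2^-1 = 1
}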
les/utils/weighted_select.go

@@ -16,28 +16,44 @@

 package utils

-import "math/rand"
+import (
+	"math/rand"
+)

-// wrsItem interface should be implemented by any entries that are to be selected from
-// a WeightedRandomSelect set. Note that recalculating monotonously decreasing item
-// weights on-demand (without constantly calling Update) is allowed
-type wrsItem interface {
-	Weight() int64
-}
-
-// WeightedRandomSelect is capable of weighted random selection from a set of items
-type WeightedRandomSelect struct {
-	root *wrsNode
-	idx  map[wrsItem]int
-}
+type (
+	// WeightedRandomSelect is capable of weighted random selection from a set of items
+	WeightedRandomSelect struct {
+		root *wrsNode
+		idx  map[WrsItem]int
+		wfn  WeightFn
+	}
+	WrsItem  interface{}
+	WeightFn func(interface{}) uint64
+)

 // NewWeightedRandomSelect returns a new WeightedRandomSelect structure
-func NewWeightedRandomSelect() *WeightedRandomSelect {
-	return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[wrsItem]int)}
+func NewWeightedRandomSelect(wfn WeightFn) *WeightedRandomSelect {
+	return &WeightedRandomSelect{root: &wrsNode{maxItems: wrsBranches}, idx: make(map[WrsItem]int), wfn: wfn}
 }

+// Update updates an item's weight, adds it if it was non-existent or removes it if
+// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
+func (w *WeightedRandomSelect) Update(item WrsItem) {
+	w.setWeight(item, w.wfn(item))
+}
+
+// Remove removes an item from the set
+func (w *WeightedRandomSelect) Remove(item WrsItem) {
+	w.setWeight(item, 0)
+}
+
+// IsEmpty returns true if the set is empty
+func (w *WeightedRandomSelect) IsEmpty() bool {
+	return w.root.sumWeight == 0
+}
+
 // setWeight sets an item's weight to a specific value (removes it if zero)
-func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) {
+func (w *WeightedRandomSelect) setWeight(item WrsItem, weight uint64) {
 	idx, ok := w.idx[item]
 	if ok {
 		w.root.setWeight(idx, weight)

@@ -58,33 +74,22 @@ func (w *WeightedRandomSelect) setWeight(item wrsItem, weight int64) {
 	}
 }

-// Update updates an item's weight, adds it if it was non-existent or removes it if
-// the new weight is zero. Note that explicitly updating decreasing weights is not necessary.
-func (w *WeightedRandomSelect) Update(item wrsItem) {
-	w.setWeight(item, item.Weight())
-}
-
-// Remove removes an item from the set
-func (w *WeightedRandomSelect) Remove(item wrsItem) {
-	w.setWeight(item, 0)
-}
-
 // Choose randomly selects an item from the set, with a chance proportional to its
 // current weight. If the weight of the chosen element has been decreased since the
 // last stored value, returns it with a newWeight/oldWeight chance, otherwise just
 // updates its weight and selects another one
-func (w *WeightedRandomSelect) Choose() wrsItem {
+func (w *WeightedRandomSelect) Choose() WrsItem {
 	for {
 		if w.root.sumWeight == 0 {
 			return nil
 		}
-		val := rand.Int63n(w.root.sumWeight)
+		val := uint64(rand.Int63n(int64(w.root.sumWeight)))
 		choice, lastWeight := w.root.choose(val)
-		weight := choice.Weight()
+		weight := w.wfn(choice)
 		if weight != lastWeight {
 			w.setWeight(choice, weight)
 		}
-		if weight >= lastWeight || rand.Int63n(lastWeight) < weight {
+		if weight >= lastWeight || uint64(rand.Int63n(int64(lastWeight))) < weight {
 			return choice
 		}
 	}

@@ -92,16 +97,16 @@ func (w *WeightedRandomSelect) Choose() wrsItem {

 const wrsBranches = 8 // max number of branches in the wrsNode tree

-// wrsNode is a node of a tree structure that can store wrsItems or further wrsNodes.
+// wrsNode is a node of a tree structure that can store WrsItems or further wrsNodes.
 type wrsNode struct {
 	items     [wrsBranches]interface{}
-	weights   [wrsBranches]int64
-	sumWeight int64
+	weights   [wrsBranches]uint64
+	sumWeight uint64
 	level, itemCnt, maxItems int
 }

 // insert recursively inserts a new item to the tree and returns the item index
-func (n *wrsNode) insert(item wrsItem, weight int64) int {
+func (n *wrsNode) insert(item WrsItem, weight uint64) int {
 	branch := 0
 	for n.items[branch] != nil && (n.level == 0 || n.items[branch].(*wrsNode).itemCnt == n.items[branch].(*wrsNode).maxItems) {
 		branch++

@@ -129,7 +134,7 @@ func (n *wrsNode) insert(item wrsItem, weight int64) int {

 // setWeight updates the weight of a certain item (which should exist) and returns
 // the change of the last weight value stored in the tree
-func (n *wrsNode) setWeight(idx int, weight int64) int64 {
+func (n *wrsNode) setWeight(idx int, weight uint64) uint64 {
 	if n.level == 0 {
 		oldWeight := n.weights[idx]
 		n.weights[idx] = weight

@@ -152,12 +157,12 @@ func (n *wrsNode) setWeight(idx int, weight int64) int64 {
 	return diff
 }

-// Choose recursively selects an item from the tree and returns it along with its weight
-func (n *wrsNode) choose(val int64) (wrsItem, int64) {
+// choose recursively selects an item from the tree and returns it along with its weight
+func (n *wrsNode) choose(val uint64) (WrsItem, uint64) {
 	for i, w := range n.weights {
 		if val < w {
 			if n.level == 0 {
-				return n.items[i].(wrsItem), n.weights[i]
+				return n.items[i].(WrsItem), n.weights[i]
 			}
 			return n.items[i].(*wrsNode).choose(val)
 		}
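To illustrate the new callback-based API, a minimal sketch with a hypothetical item type; weights now live outside the item and are supplied through the WeightFn callback passed to the constructor:

package main

import (
	"fmt"

	"github.com/ethereum/go-ethereum/les/utils"
)

// candidate is a hypothetical selectable item.
type candidate struct {
	name   string
	weight uint64
}

func main() {
	wrs := utils.NewWeightedRandomSelect(func(i interface{}) uint64 {
		return i.(*candidate).weight // called on demand, so weights may decay over time
	})
	wrs.Update(&candidate{name: "a", weight: 1})
	wrs.Update(&candidate{name: "b", weight: 3}) // ~3x more likely than "a"
	if c := wrs.Choose(); c != nil {
		fmt.Println("picked", c.(*candidate).name)
	}
}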
les/utils/weighted_select_test.go

@@ -26,17 +26,18 @@ type testWrsItem struct {
 	widx *int
 }

-func (t *testWrsItem) Weight() int64 {
+func testWeight(i interface{}) uint64 {
+	t := i.(*testWrsItem)
 	w := *t.widx
 	if w == -1 || w == t.idx {
-		return int64(t.idx + 1)
+		return uint64(t.idx + 1)
 	}
 	return 0
 }

 func TestWeightedRandomSelect(t *testing.T) {
 	testFn := func(cnt int) {
-		s := NewWeightedRandomSelect()
+		s := NewWeightedRandomSelect(testWeight)
 		w := -1
 		list := make([]testWrsItem, cnt)
 		for i := range list {