p2p: enforce connection retry limit on server side (#19684)
The dialer limits itself to one attempt every 30s. Apply the same limit in Server and reject peers which try to connect too eagerly. The check against the limit happens right after accepting the connection. Further changes in this commit ensure we pass the Server logger down to Peer instances, discovery and dialState. Unit test logging now works in all Server tests.
This commit is contained in:
@ -19,6 +19,7 @@ package p2p
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"errors"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
@ -26,6 +27,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/internal/testlog"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/p2p/enode"
|
||||
"github.com/ethereum/go-ethereum/p2p/enr"
|
||||
@ -74,6 +76,7 @@ func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *
|
||||
MaxPeers: 10,
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
PrivateKey: newkey(),
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
server := &Server{
|
||||
Config: config,
|
||||
@ -359,6 +362,7 @@ func TestServerAtCap(t *testing.T) {
|
||||
PrivateKey: newkey(),
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
TrustedNodes: []*enode.Node{newNode(trustedID, nil)},
|
||||
},
|
||||
}
|
||||
@ -377,19 +381,19 @@ func TestServerAtCap(t *testing.T) {
|
||||
// Inject a few connections to fill up the peer set.
|
||||
for i := 0; i < 10; i++ {
|
||||
c := newconn(randomID())
|
||||
if err := srv.checkpoint(c, srv.addpeer); err != nil {
|
||||
if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil {
|
||||
t.Fatalf("could not add conn %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
// Try inserting a non-trusted connection.
|
||||
anotherID := randomID()
|
||||
c := newconn(anotherID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
|
||||
t.Error("wrong error for insert:", err)
|
||||
}
|
||||
// Try inserting a trusted connection.
|
||||
c = newconn(trustedID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
|
||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||
}
|
||||
if !c.is(trustedConn) {
|
||||
@ -399,14 +403,14 @@ func TestServerAtCap(t *testing.T) {
|
||||
// Remove from trusted set and try again
|
||||
srv.RemoveTrustedPeer(newNode(trustedID, nil))
|
||||
c = newconn(trustedID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers {
|
||||
t.Error("wrong error for insert:", err)
|
||||
}
|
||||
|
||||
// Add anotherID to trusted set and try again
|
||||
srv.AddTrustedPeer(newNode(anotherID, nil))
|
||||
c = newconn(anotherID)
|
||||
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
|
||||
if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil {
|
||||
t.Error("unexpected error for trusted conn @posthandshake:", err)
|
||||
}
|
||||
if !c.is(trustedConn) {
|
||||
@ -430,10 +434,11 @@ func TestServerPeerLimits(t *testing.T) {
|
||||
|
||||
srv := &Server{
|
||||
Config: Config{
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 0,
|
||||
NoDial: true,
|
||||
Protocols: []Protocol{discard},
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 0,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
Protocols: []Protocol{discard},
|
||||
},
|
||||
newTransport: func(fd net.Conn) transport { return tp },
|
||||
log: log.New(),
|
||||
@ -541,29 +546,35 @@ func TestServerSetupConn(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
srv := &Server{
|
||||
Config: Config{
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
Protocols: []Protocol{discard},
|
||||
},
|
||||
newTransport: func(fd net.Conn) transport { return test.tt },
|
||||
log: log.New(),
|
||||
}
|
||||
if !test.dontstart {
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatalf("couldn't start server: %v", err)
|
||||
t.Run(test.wantCalls, func(t *testing.T) {
|
||||
cfg := Config{
|
||||
PrivateKey: srvkey,
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
Protocols: []Protocol{discard},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
}
|
||||
}
|
||||
p1, _ := net.Pipe()
|
||||
srv.SetupConn(p1, test.flags, test.dialDest)
|
||||
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
|
||||
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
|
||||
}
|
||||
if test.tt.calls != test.wantCalls {
|
||||
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
||||
}
|
||||
srv := &Server{
|
||||
Config: cfg,
|
||||
newTransport: func(fd net.Conn) transport { return test.tt },
|
||||
log: cfg.Logger,
|
||||
}
|
||||
if !test.dontstart {
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatalf("couldn't start server: %v", err)
|
||||
}
|
||||
defer srv.Stop()
|
||||
}
|
||||
p1, _ := net.Pipe()
|
||||
srv.SetupConn(p1, test.flags, test.dialDest)
|
||||
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
|
||||
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
|
||||
}
|
||||
if test.tt.calls != test.wantCalls {
|
||||
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -616,3 +627,100 @@ func randomID() (id enode.ID) {
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// This test checks that inbound connections are throttled by IP.
|
||||
func TestServerInboundThrottle(t *testing.T) {
|
||||
const timeout = 5 * time.Second
|
||||
newTransportCalled := make(chan struct{})
|
||||
srv := &Server{
|
||||
Config: Config{
|
||||
PrivateKey: newkey(),
|
||||
ListenAddr: "127.0.0.1:0",
|
||||
MaxPeers: 10,
|
||||
NoDial: true,
|
||||
NoDiscovery: true,
|
||||
Protocols: []Protocol{discard},
|
||||
Logger: testlog.Logger(t, log.LvlTrace),
|
||||
},
|
||||
newTransport: func(fd net.Conn) transport {
|
||||
newTransportCalled <- struct{}{}
|
||||
return newRLPX(fd)
|
||||
},
|
||||
listenFunc: func(network, laddr string) (net.Listener, error) {
|
||||
fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444}
|
||||
return listenFakeAddr(network, laddr, fakeAddr)
|
||||
},
|
||||
}
|
||||
if err := srv.Start(); err != nil {
|
||||
t.Fatal("can't start: ", err)
|
||||
}
|
||||
defer srv.Stop()
|
||||
|
||||
// Dial the test server.
|
||||
conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout)
|
||||
if err != nil {
|
||||
t.Fatalf("could not dial: %v", err)
|
||||
}
|
||||
select {
|
||||
case <-newTransportCalled:
|
||||
// OK
|
||||
case <-time.After(timeout):
|
||||
t.Error("newTransport not called")
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
// Dial again. This time the server should close the connection immediately.
|
||||
connClosed := make(chan struct{})
|
||||
conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout)
|
||||
if err != nil {
|
||||
t.Fatalf("could not dial: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
go func() {
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
buf := make([]byte, 10)
|
||||
if n, err := conn.Read(buf); err != io.EOF || n != 0 {
|
||||
t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n)
|
||||
}
|
||||
connClosed <- struct{}{}
|
||||
}()
|
||||
select {
|
||||
case <-connClosed:
|
||||
// OK
|
||||
case <-newTransportCalled:
|
||||
t.Error("newTransport called for second attempt")
|
||||
case <-time.After(timeout):
|
||||
t.Error("connection not closed within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) {
|
||||
l, err := net.Listen(network, laddr)
|
||||
if err == nil {
|
||||
l = &fakeAddrListener{l, remoteAddr}
|
||||
}
|
||||
return l, err
|
||||
}
|
||||
|
||||
// fakeAddrListener is a listener that creates connections with a mocked remote address.
|
||||
type fakeAddrListener struct {
|
||||
net.Listener
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
type fakeAddrConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (l *fakeAddrListener) Accept() (net.Conn, error) {
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fakeAddrConn{c, l.remoteAddr}, nil
|
||||
}
|
||||
|
||||
func (c *fakeAddrConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
Reference in New Issue
Block a user