Merge pull request #3325 from fjl/p2p-netrestrict
Prevent relay of invalid IPs, add --netrestrict
This commit is contained in:
		| @@ -29,6 +29,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discv5" | 	"github.com/ethereum/go-ethereum/p2p/discv5" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func main() { | func main() { | ||||||
| @@ -39,6 +40,7 @@ func main() { | |||||||
| 		nodeKeyFile = flag.String("nodekey", "", "private key filename") | 		nodeKeyFile = flag.String("nodekey", "", "private key filename") | ||||||
| 		nodeKeyHex  = flag.String("nodekeyhex", "", "private key as hex (for testing)") | 		nodeKeyHex  = flag.String("nodekeyhex", "", "private key as hex (for testing)") | ||||||
| 		natdesc     = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)") | 		natdesc     = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:<IP>)") | ||||||
|  | 		netrestrict = flag.String("netrestrict", "", "restrict network communication to the given IP networks (CIDR masks)") | ||||||
| 		runv5       = flag.Bool("v5", false, "run a v5 topic discovery bootnode") | 		runv5       = flag.Bool("v5", false, "run a v5 topic discovery bootnode") | ||||||
|  |  | ||||||
| 		nodeKey *ecdsa.PrivateKey | 		nodeKey *ecdsa.PrivateKey | ||||||
| @@ -81,12 +83,20 @@ func main() { | |||||||
| 		os.Exit(0) | 		os.Exit(0) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	var restrictList *netutil.Netlist | ||||||
|  | 	if *netrestrict != "" { | ||||||
|  | 		restrictList, err = netutil.ParseNetlist(*netrestrict) | ||||||
|  | 		if err != nil { | ||||||
|  | 			utils.Fatalf("-netrestrict: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if *runv5 { | 	if *runv5 { | ||||||
| 		if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { | 		if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil { | ||||||
| 			utils.Fatalf("%v", err) | 			utils.Fatalf("%v", err) | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { | 		if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil { | ||||||
| 			utils.Fatalf("%v", err) | 			utils.Fatalf("%v", err) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -96,6 +96,7 @@ func init() { | |||||||
| 		utils.BootnodesFlag, | 		utils.BootnodesFlag, | ||||||
| 		utils.KeyStoreDirFlag, | 		utils.KeyStoreDirFlag, | ||||||
| 		utils.ListenPortFlag, | 		utils.ListenPortFlag, | ||||||
|  | 		utils.NetrestrictFlag, | ||||||
| 		utils.MaxPeersFlag, | 		utils.MaxPeersFlag, | ||||||
| 		utils.NATFlag, | 		utils.NATFlag, | ||||||
| 		utils.NodeKeyFileFlag, | 		utils.NodeKeyFileFlag, | ||||||
|   | |||||||
| @@ -148,6 +148,7 @@ participating. | |||||||
| 		utils.NatspecEnabledFlag, | 		utils.NatspecEnabledFlag, | ||||||
| 		utils.NoDiscoverFlag, | 		utils.NoDiscoverFlag, | ||||||
| 		utils.DiscoveryV5Flag, | 		utils.DiscoveryV5Flag, | ||||||
|  | 		utils.NetrestrictFlag, | ||||||
| 		utils.NodeKeyFileFlag, | 		utils.NodeKeyFileFlag, | ||||||
| 		utils.NodeKeyHexFlag, | 		utils.NodeKeyHexFlag, | ||||||
| 		utils.RPCEnabledFlag, | 		utils.RPCEnabledFlag, | ||||||
|   | |||||||
| @@ -45,6 +45,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discv5" | 	"github.com/ethereum/go-ethereum/p2p/discv5" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| 	"github.com/ethereum/go-ethereum/params" | 	"github.com/ethereum/go-ethereum/params" | ||||||
| 	"github.com/ethereum/go-ethereum/pow" | 	"github.com/ethereum/go-ethereum/pow" | ||||||
| 	"github.com/ethereum/go-ethereum/rpc" | 	"github.com/ethereum/go-ethereum/rpc" | ||||||
| @@ -366,10 +367,16 @@ var ( | |||||||
| 		Name:  "v5disc", | 		Name:  "v5disc", | ||||||
| 		Usage: "Enables the experimental RLPx V5 (Topic Discovery) mechanism", | 		Usage: "Enables the experimental RLPx V5 (Topic Discovery) mechanism", | ||||||
| 	} | 	} | ||||||
|  | 	NetrestrictFlag = cli.StringFlag{ | ||||||
|  | 		Name:  "netrestrict", | ||||||
|  | 		Usage: "Restricts network communication to the given IP networks (CIDR masks)", | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	WhisperEnabledFlag = cli.BoolFlag{ | 	WhisperEnabledFlag = cli.BoolFlag{ | ||||||
| 		Name:  "shh", | 		Name:  "shh", | ||||||
| 		Usage: "Enable Whisper", | 		Usage: "Enable Whisper", | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// ATM the url is left to the user and deployment to | 	// ATM the url is left to the user and deployment to | ||||||
| 	JSpathFlag = cli.StringFlag{ | 	JSpathFlag = cli.StringFlag{ | ||||||
| 		Name:  "jspath", | 		Name:  "jspath", | ||||||
| @@ -693,6 +700,14 @@ func MakeNode(ctx *cli.Context, name, gitCommit string) *node.Node { | |||||||
| 		config.MaxPeers = 0 | 		config.MaxPeers = 0 | ||||||
| 		config.ListenAddr = ":0" | 		config.ListenAddr = ":0" | ||||||
| 	} | 	} | ||||||
|  | 	if netrestrict := ctx.GlobalString(NetrestrictFlag.Name); netrestrict != "" { | ||||||
|  | 		list, err := netutil.ParseNetlist(netrestrict) | ||||||
|  | 		if err != nil { | ||||||
|  | 			Fatalf("Option %q: %v", NetrestrictFlag.Name, err) | ||||||
|  | 		} | ||||||
|  | 		config.NetRestrict = list | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	stack, err := node.New(config) | 	stack, err := node.New(config) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		Fatalf("Failed to create the protocol stack: %v", err) | 		Fatalf("Failed to create the protocol stack: %v", err) | ||||||
|   | |||||||
| @@ -34,6 +34,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discv5" | 	"github.com/ethereum/go-ethereum/p2p/discv5" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| @@ -103,6 +104,10 @@ type Config struct { | |||||||
| 	// Listener address for the V5 discovery protocol UDP traffic. | 	// Listener address for the V5 discovery protocol UDP traffic. | ||||||
| 	DiscoveryV5Addr string | 	DiscoveryV5Addr string | ||||||
|  |  | ||||||
|  | 	// Restrict communication to white listed IP networks. | ||||||
|  | 	// The whitelist only applies when non-nil. | ||||||
|  | 	NetRestrict *netutil.Netlist | ||||||
|  |  | ||||||
| 	// BootstrapNodes used to establish connectivity with the rest of the network. | 	// BootstrapNodes used to establish connectivity with the rest of the network. | ||||||
| 	BootstrapNodes []*discover.Node | 	BootstrapNodes []*discover.Node | ||||||
|  |  | ||||||
|   | |||||||
| @@ -165,6 +165,7 @@ func (n *Node) Start() error { | |||||||
| 		TrustedNodes:     n.config.TrusterNodes(), | 		TrustedNodes:     n.config.TrusterNodes(), | ||||||
| 		NodeDatabase:     n.config.NodeDB(), | 		NodeDatabase:     n.config.NodeDB(), | ||||||
| 		ListenAddr:       n.config.ListenAddr, | 		ListenAddr:       n.config.ListenAddr, | ||||||
|  | 		NetRestrict:      n.config.NetRestrict, | ||||||
| 		NAT:              n.config.NAT, | 		NAT:              n.config.NAT, | ||||||
| 		Dialer:           n.config.Dialer, | 		Dialer:           n.config.Dialer, | ||||||
| 		NoDial:           n.config.NoDial, | 		NoDial:           n.config.NoDial, | ||||||
|   | |||||||
							
								
								
									
										45
									
								
								p2p/dial.go
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								p2p/dial.go
									
									
									
									
									
								
							| @@ -19,6 +19,7 @@ package p2p | |||||||
| import ( | import ( | ||||||
| 	"container/heap" | 	"container/heap" | ||||||
| 	"crypto/rand" | 	"crypto/rand" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"time" | 	"time" | ||||||
| @@ -26,6 +27,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -48,6 +50,7 @@ const ( | |||||||
| type dialstate struct { | type dialstate struct { | ||||||
| 	maxDynDials int | 	maxDynDials int | ||||||
| 	ntab        discoverTable | 	ntab        discoverTable | ||||||
|  | 	netrestrict *netutil.Netlist | ||||||
|  |  | ||||||
| 	lookupRunning bool | 	lookupRunning bool | ||||||
| 	dialing       map[discover.NodeID]connFlag | 	dialing       map[discover.NodeID]connFlag | ||||||
| @@ -100,10 +103,11 @@ type waitExpireTask struct { | |||||||
| 	time.Duration | 	time.Duration | ||||||
| } | } | ||||||
|  |  | ||||||
| func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { | func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { | ||||||
| 	s := &dialstate{ | 	s := &dialstate{ | ||||||
| 		maxDynDials: maxdyn, | 		maxDynDials: maxdyn, | ||||||
| 		ntab:        ntab, | 		ntab:        ntab, | ||||||
|  | 		netrestrict: netrestrict, | ||||||
| 		static:      make(map[discover.NodeID]*dialTask), | 		static:      make(map[discover.NodeID]*dialTask), | ||||||
| 		dialing:     make(map[discover.NodeID]connFlag), | 		dialing:     make(map[discover.NodeID]connFlag), | ||||||
| 		randomNodes: make([]*discover.Node, maxdyn/2), | 		randomNodes: make([]*discover.Node, maxdyn/2), | ||||||
| @@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) { | |||||||
|  |  | ||||||
| func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { | func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { | ||||||
| 	var newtasks []task | 	var newtasks []task | ||||||
| 	isDialing := func(id discover.NodeID) bool { |  | ||||||
| 		_, found := s.dialing[id] |  | ||||||
| 		return found || peers[id] != nil || s.hist.contains(id) |  | ||||||
| 	} |  | ||||||
| 	addDial := func(flag connFlag, n *discover.Node) bool { | 	addDial := func(flag connFlag, n *discover.Node) bool { | ||||||
| 		if isDialing(n.ID) { | 		if err := s.checkDial(n, peers); err != nil { | ||||||
|  | 			glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err) | ||||||
| 			return false | 			return false | ||||||
| 		} | 		} | ||||||
| 		s.dialing[n.ID] = flag | 		s.dialing[n.ID] = flag | ||||||
| @@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now | |||||||
|  |  | ||||||
| 	// Create dials for static nodes if they are not connected. | 	// Create dials for static nodes if they are not connected. | ||||||
| 	for id, t := range s.static { | 	for id, t := range s.static { | ||||||
| 		if !isDialing(id) { | 		err := s.checkDial(t.dest, peers) | ||||||
|  | 		switch err { | ||||||
|  | 		case errNotWhitelisted, errSelf: | ||||||
|  | 			glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err) | ||||||
|  | 			delete(s.static, t.dest.ID) | ||||||
|  | 		case nil: | ||||||
| 			s.dialing[id] = t.flags | 			s.dialing[id] = t.flags | ||||||
| 			newtasks = append(newtasks, t) | 			newtasks = append(newtasks, t) | ||||||
| 		} | 		} | ||||||
| @@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now | |||||||
| 	return newtasks | 	return newtasks | ||||||
| } | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	errSelf             = errors.New("is self") | ||||||
|  | 	errAlreadyDialing   = errors.New("already dialing") | ||||||
|  | 	errAlreadyConnected = errors.New("already connected") | ||||||
|  | 	errRecentlyDialed   = errors.New("recently dialed") | ||||||
|  | 	errNotWhitelisted   = errors.New("not contained in netrestrict whitelist") | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error { | ||||||
|  | 	_, dialing := s.dialing[n.ID] | ||||||
|  | 	switch { | ||||||
|  | 	case dialing: | ||||||
|  | 		return errAlreadyDialing | ||||||
|  | 	case peers[n.ID] != nil: | ||||||
|  | 		return errAlreadyConnected | ||||||
|  | 	case s.ntab != nil && n.ID == s.ntab.Self().ID: | ||||||
|  | 		return errSelf | ||||||
|  | 	case s.netrestrict != nil && !s.netrestrict.Contains(n.IP): | ||||||
|  | 		return errNotWhitelisted | ||||||
|  | 	case s.hist.contains(n.ID): | ||||||
|  | 		return errRecentlyDialed | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *dialstate) taskDone(t task, now time.Time) { | func (s *dialstate) taskDone(t task, now time.Time) { | ||||||
| 	switch t := t.(type) { | 	switch t := t.(type) { | ||||||
| 	case *dialTask: | 	case *dialTask: | ||||||
|   | |||||||
| @@ -25,6 +25,7 @@ import ( | |||||||
|  |  | ||||||
| 	"github.com/davecgh/go-spew/spew" | 	"github.com/davecgh/go-spew/spew" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func init() { | func init() { | ||||||
| @@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, | |||||||
| // This test checks that dynamic dials are launched from discovery results. | // This test checks that dynamic dials are launched from discovery results. | ||||||
| func TestDialStateDynDial(t *testing.T) { | func TestDialStateDynDial(t *testing.T) { | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(nil, fakeTable{}, 5), | 		init: newDialState(nil, fakeTable{}, 5, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// A discovery query is launched. | 			// A discovery query is launched. | ||||||
| 			{ | 			{ | ||||||
| @@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(nil, table, 10), | 		init: newDialState(nil, table, 10, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. | 			// 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. | ||||||
| 			{ | 			{ | ||||||
| @@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // This test checks that candidates that do not match the netrestrict list are not dialed. | ||||||
|  | func TestDialStateNetRestrict(t *testing.T) { | ||||||
|  | 	// This table always returns the same random nodes | ||||||
|  | 	// in the order given below. | ||||||
|  | 	table := fakeTable{ | ||||||
|  | 		{ID: uintID(1), IP: net.ParseIP("127.0.0.1")}, | ||||||
|  | 		{ID: uintID(2), IP: net.ParseIP("127.0.0.2")}, | ||||||
|  | 		{ID: uintID(3), IP: net.ParseIP("127.0.0.3")}, | ||||||
|  | 		{ID: uintID(4), IP: net.ParseIP("127.0.0.4")}, | ||||||
|  | 		{ID: uintID(5), IP: net.ParseIP("127.0.2.5")}, | ||||||
|  | 		{ID: uintID(6), IP: net.ParseIP("127.0.2.6")}, | ||||||
|  | 		{ID: uintID(7), IP: net.ParseIP("127.0.2.7")}, | ||||||
|  | 		{ID: uintID(8), IP: net.ParseIP("127.0.2.8")}, | ||||||
|  | 	} | ||||||
|  | 	restrict := new(netutil.Netlist) | ||||||
|  | 	restrict.Add("127.0.2.0/24") | ||||||
|  |  | ||||||
|  | 	runDialTest(t, dialtest{ | ||||||
|  | 		init: newDialState(nil, table, 10, restrict), | ||||||
|  | 		rounds: []round{ | ||||||
|  | 			{ | ||||||
|  | 				new: []task{ | ||||||
|  | 					&dialTask{flags: dynDialedConn, dest: table[4]}, | ||||||
|  | 					&discoverTask{}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
| // This test checks that static dials are launched. | // This test checks that static dials are launched. | ||||||
| func TestDialStateStaticDial(t *testing.T) { | func TestDialStateStaticDial(t *testing.T) { | ||||||
| 	wantStatic := []*discover.Node{ | 	wantStatic := []*discover.Node{ | ||||||
| @@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(wantStatic, fakeTable{}, 0), | 		init: newDialState(wantStatic, fakeTable{}, 0, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// Static dials are launched for the nodes that | 			// Static dials are launched for the nodes that | ||||||
| 			// aren't yet connected. | 			// aren't yet connected. | ||||||
| @@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	runDialTest(t, dialtest{ | 	runDialTest(t, dialtest{ | ||||||
| 		init: newDialState(wantStatic, fakeTable{}, 0), | 		init: newDialState(wantStatic, fakeTable{}, 0, nil), | ||||||
| 		rounds: []round{ | 		rounds: []round{ | ||||||
| 			// Static dials are launched for the nodes that | 			// Static dials are launched for the nodes that | ||||||
| 			// aren't yet connected. | 			// aren't yet connected. | ||||||
| @@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) { | |||||||
| func TestDialResolve(t *testing.T) { | func TestDialResolve(t *testing.T) { | ||||||
| 	resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) | 	resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) | ||||||
| 	table := &resolveMock{answer: resolved} | 	table := &resolveMock{answer: resolved} | ||||||
| 	state := newDialState(nil, table, 0) | 	state := newDialState(nil, table, 0, nil) | ||||||
|  |  | ||||||
| 	// Check that the task is generated with an incomplete ID. | 	// Check that the task is generated with an incomplete ID. | ||||||
| 	dest := discover.NewNode(uintID(1), nil, 0, 0) | 	dest := discover.NewNode(uintID(1), nil, 0, 0) | ||||||
|   | |||||||
| @@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) { | |||||||
| func nodeAtDistance(base common.Hash, ld int) (n *Node) { | func nodeAtDistance(base common.Hash, ld int) (n *Node) { | ||||||
| 	n = new(Node) | 	n = new(Node) | ||||||
| 	n.sha = hashAtDistance(base, ld) | 	n.sha = hashAtDistance(base, ld) | ||||||
|  | 	n.IP = net.IP{10, 0, 2, byte(ld)} | ||||||
| 	copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID | 	copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID | ||||||
| 	return n | 	return n | ||||||
| } | } | ||||||
|   | |||||||
| @@ -29,6 +29,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| 	"github.com/ethereum/go-ethereum/rlp" | 	"github.com/ethereum/go-ethereum/rlp" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -126,8 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { | |||||||
| 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} | 	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} | ||||||
| } | } | ||||||
|  |  | ||||||
| func nodeFromRPC(rn rpcNode) (*Node, error) { | func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { | ||||||
| 	// TODO: don't accept localhost, LAN addresses from internet hosts | 	if rn.UDP <= 1024 { | ||||||
|  | 		return nil, errors.New("low port") | ||||||
|  | 	} | ||||||
|  | 	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { | ||||||
|  | 		return nil, errors.New("not contained in netrestrict whitelist") | ||||||
|  | 	} | ||||||
| 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) | 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) | ||||||
| 	err := n.validateComplete() | 	err := n.validateComplete() | ||||||
| 	return n, err | 	return n, err | ||||||
| @@ -151,6 +160,7 @@ type conn interface { | |||||||
| // udp implements the RPC protocol. | // udp implements the RPC protocol. | ||||||
| type udp struct { | type udp struct { | ||||||
| 	conn        conn | 	conn        conn | ||||||
|  | 	netrestrict *netutil.Netlist | ||||||
| 	priv        *ecdsa.PrivateKey | 	priv        *ecdsa.PrivateKey | ||||||
| 	ourEndpoint rpcEndpoint | 	ourEndpoint rpcEndpoint | ||||||
|  |  | ||||||
| @@ -201,7 +211,7 @@ type reply struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // ListenUDP returns a new table that listens for UDP packets on laddr. | // ListenUDP returns a new table that listens for UDP packets on laddr. | ||||||
| func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { | func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { | ||||||
| 	addr, err := net.ResolveUDPAddr("udp", laddr) | 	addr, err := net.ResolveUDPAddr("udp", laddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -210,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	tab, _, err := newUDP(priv, conn, natm, nodeDBPath) | 	tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -218,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP | |||||||
| 	return tab, nil | 	return tab, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) { | func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { | ||||||
| 	udp := &udp{ | 	udp := &udp{ | ||||||
| 		conn:       c, | 		conn:        c, | ||||||
| 		priv:       priv, | 		priv:        priv, | ||||||
| 		closing:    make(chan struct{}), | 		netrestrict: netrestrict, | ||||||
| 		gotreply:   make(chan reply), | 		closing:     make(chan struct{}), | ||||||
| 		addpending: make(chan *pending), | 		gotreply:    make(chan reply), | ||||||
|  | 		addpending:  make(chan *pending), | ||||||
| 	} | 	} | ||||||
| 	realaddr := c.LocalAddr().(*net.UDPAddr) | 	realaddr := c.LocalAddr().(*net.UDPAddr) | ||||||
| 	if natm != nil { | 	if natm != nil { | ||||||
| @@ -281,9 +292,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node | |||||||
| 		reply := r.(*neighbors) | 		reply := r.(*neighbors) | ||||||
| 		for _, rn := range reply.Nodes { | 		for _, rn := range reply.Nodes { | ||||||
| 			nreceived++ | 			nreceived++ | ||||||
| 			if n, err := nodeFromRPC(rn); err == nil { | 			n, err := t.nodeFromRPC(toaddr, rn) | ||||||
| 				nodes = append(nodes, n) | 			if err != nil { | ||||||
|  | 				glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) | ||||||
|  | 				continue | ||||||
| 			} | 			} | ||||||
|  | 			nodes = append(nodes, n) | ||||||
| 		} | 		} | ||||||
| 		return nreceived >= bucketSize | 		return nreceived >= bucketSize | ||||||
| 	}) | 	}) | ||||||
| @@ -479,13 +493,6 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, | |||||||
| 	return packet, nil | 	return packet, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func isTemporaryError(err error) bool { |  | ||||||
| 	tempErr, ok := err.(interface { |  | ||||||
| 		Temporary() bool |  | ||||||
| 	}) |  | ||||||
| 	return ok && tempErr.Temporary() || isPacketTooBig(err) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // readLoop runs in its own goroutine. it handles incoming UDP packets. | // readLoop runs in its own goroutine. it handles incoming UDP packets. | ||||||
| func (t *udp) readLoop() { | func (t *udp) readLoop() { | ||||||
| 	defer t.conn.Close() | 	defer t.conn.Close() | ||||||
| @@ -495,7 +502,7 @@ func (t *udp) readLoop() { | |||||||
| 	buf := make([]byte, 1280) | 	buf := make([]byte, 1280) | ||||||
| 	for { | 	for { | ||||||
| 		nbytes, from, err := t.conn.ReadFromUDP(buf) | 		nbytes, from, err := t.conn.ReadFromUDP(buf) | ||||||
| 		if isTemporaryError(err) { | 		if netutil.IsTemporaryError(err) { | ||||||
| 			// Ignore temporary read errors. | 			// Ignore temporary read errors. | ||||||
| 			glog.V(logger.Debug).Infof("Temporary read error: %v", err) | 			glog.V(logger.Debug).Infof("Temporary read error: %v", err) | ||||||
| 			continue | 			continue | ||||||
| @@ -602,6 +609,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte | |||||||
| 	// Send neighbors in chunks with at most maxNeighbors per packet | 	// Send neighbors in chunks with at most maxNeighbors per packet | ||||||
| 	// to stay below the 1280 byte limit. | 	// to stay below the 1280 byte limit. | ||||||
| 	for i, n := range closest { | 	for i, n := range closest { | ||||||
|  | 		if netutil.CheckRelayIP(from.IP, n.IP) != nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
| 		p.Nodes = append(p.Nodes, nodeToRPC(n)) | 		p.Nodes = append(p.Nodes, nodeToRPC(n)) | ||||||
| 		if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { | 		if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { | ||||||
| 			t.send(from, neighborsPacket, p) | 			t.send(from, neighborsPacket, p) | ||||||
|   | |||||||
| @@ -43,56 +43,6 @@ func init() { | |||||||
| 	spew.Config.DisableMethods = true | 	spew.Config.DisableMethods = true | ||||||
| } | } | ||||||
|  |  | ||||||
| // This test checks that isPacketTooBig correctly identifies |  | ||||||
| // errors that result from receiving a UDP packet larger |  | ||||||
| // than the supplied receive buffer. |  | ||||||
| func TestIsPacketTooBig(t *testing.T) { |  | ||||||
| 	listener, err := net.ListenPacket("udp", "127.0.0.1:0") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer listener.Close() |  | ||||||
| 	sender, err := net.Dial("udp", listener.LocalAddr().String()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer sender.Close() |  | ||||||
|  |  | ||||||
| 	sendN := 1800 |  | ||||||
| 	recvN := 300 |  | ||||||
| 	for i := 0; i < 20; i++ { |  | ||||||
| 		go func() { |  | ||||||
| 			buf := make([]byte, sendN) |  | ||||||
| 			for i := range buf { |  | ||||||
| 				buf[i] = byte(i) |  | ||||||
| 			} |  | ||||||
| 			sender.Write(buf) |  | ||||||
| 		}() |  | ||||||
|  |  | ||||||
| 		buf := make([]byte, recvN) |  | ||||||
| 		listener.SetDeadline(time.Now().Add(1 * time.Second)) |  | ||||||
| 		n, _, err := listener.ReadFrom(buf) |  | ||||||
| 		if err != nil { |  | ||||||
| 			if nerr, ok := err.(net.Error); ok && nerr.Timeout() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if !isPacketTooBig(err) { |  | ||||||
| 				t.Fatal("unexpected read error:", spew.Sdump(err)) |  | ||||||
| 			} |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if n != recvN { |  | ||||||
| 			t.Fatalf("short read: %d, want %d", n, recvN) |  | ||||||
| 		} |  | ||||||
| 		for i := range buf { |  | ||||||
| 			if buf[i] != byte(i) { |  | ||||||
| 				t.Fatalf("error in pattern") |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // shared test variables | // shared test variables | ||||||
| var ( | var ( | ||||||
| 	futureExp          = uint64(time.Now().Add(10 * time.Hour).Unix()) | 	futureExp          = uint64(time.Now().Add(10 * time.Hour).Unix()) | ||||||
| @@ -118,9 +68,9 @@ func newUDPTest(t *testing.T) *udpTest { | |||||||
| 		pipe:       newpipe(), | 		pipe:       newpipe(), | ||||||
| 		localkey:   newkey(), | 		localkey:   newkey(), | ||||||
| 		remotekey:  newkey(), | 		remotekey:  newkey(), | ||||||
| 		remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303}, | 		remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, | ||||||
| 	} | 	} | ||||||
| 	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") | 	test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil) | ||||||
| 	return test | 	return test | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -362,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) { | |||||||
| 	// check that the sent neighbors are all returned by findnode | 	// check that the sent neighbors are all returned by findnode | ||||||
| 	select { | 	select { | ||||||
| 	case result := <-resultc: | 	case result := <-resultc: | ||||||
| 		if !reflect.DeepEqual(result, list) { | 		want := append(list[:2], list[3:]...) | ||||||
| 			t.Errorf("neighbors mismatch:\n  got:  %v\n  want: %v", result, list) | 		if !reflect.DeepEqual(result, want) { | ||||||
|  | 			t.Errorf("neighbors mismatch:\n  got:  %v\n  want: %v", result, want) | ||||||
| 		} | 		} | ||||||
| 	case err := <-errc: | 	case err := <-errc: | ||||||
| 		t.Errorf("findnode error: %v", err) | 		t.Errorf("findnode error: %v", err) | ||||||
|   | |||||||
| @@ -31,6 +31,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| 	"github.com/ethereum/go-ethereum/rlp" | 	"github.com/ethereum/go-ethereum/rlp" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -45,6 +46,7 @@ const ( | |||||||
| 	bucketRefreshInterval = 1 * time.Minute | 	bucketRefreshInterval = 1 * time.Minute | ||||||
| 	seedCount             = 30 | 	seedCount             = 30 | ||||||
| 	seedMaxAge            = 5 * 24 * time.Hour | 	seedMaxAge            = 5 * 24 * time.Hour | ||||||
|  | 	lowPort               = 1024 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const testTopic = "foo" | const testTopic = "foo" | ||||||
| @@ -62,8 +64,9 @@ func debugLog(s string) { | |||||||
|  |  | ||||||
| // Network manages the table and all protocol interaction. | // Network manages the table and all protocol interaction. | ||||||
| type Network struct { | type Network struct { | ||||||
| 	db   *nodeDB // database of known nodes | 	db          *nodeDB // database of known nodes | ||||||
| 	conn transport | 	conn        transport | ||||||
|  | 	netrestrict *netutil.Netlist | ||||||
|  |  | ||||||
| 	closed           chan struct{}          // closed when loop is done | 	closed           chan struct{}          // closed when loop is done | ||||||
| 	closeReq         chan struct{}          // 'request to close' | 	closeReq         chan struct{}          // 'request to close' | ||||||
| @@ -132,7 +135,7 @@ type timeoutEvent struct { | |||||||
| 	node *Node | 	node *Node | ||||||
| } | } | ||||||
|  |  | ||||||
| func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) { | func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { | ||||||
| 	ourID := PubkeyID(&ourPubkey) | 	ourID := PubkeyID(&ourPubkey) | ||||||
|  |  | ||||||
| 	var db *nodeDB | 	var db *nodeDB | ||||||
| @@ -147,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d | |||||||
| 	net := &Network{ | 	net := &Network{ | ||||||
| 		db:               db, | 		db:               db, | ||||||
| 		conn:             conn, | 		conn:             conn, | ||||||
|  | 		netrestrict:      netrestrict, | ||||||
| 		tab:              tab, | 		tab:              tab, | ||||||
| 		topictab:         newTopicTable(db, tab.self), | 		topictab:         newTopicTable(db, tab.self), | ||||||
| 		ticketStore:      newTicketStore(), | 		ticketStore:      newTicketStore(), | ||||||
| @@ -684,16 +688,22 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node { | |||||||
| 	return n | 	return n | ||||||
| } | } | ||||||
|  |  | ||||||
| func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) { | func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) { | ||||||
| 	if rn.ID == net.tab.self.ID { | 	if rn.ID == net.tab.self.ID { | ||||||
| 		return nil, errors.New("is self") | 		return nil, errors.New("is self") | ||||||
| 	} | 	} | ||||||
|  | 	if rn.UDP <= lowPort { | ||||||
|  | 		return nil, errors.New("low port") | ||||||
|  | 	} | ||||||
| 	n = net.nodes[rn.ID] | 	n = net.nodes[rn.ID] | ||||||
| 	if n == nil { | 	if n == nil { | ||||||
| 		// We haven't seen this node before. | 		// We haven't seen this node before. | ||||||
| 		n, err = nodeFromRPC(rn) | 		n, err = nodeFromRPC(sender, rn) | ||||||
| 		n.state = unknown | 		if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) { | ||||||
|  | 			return n, errors.New("not contained in netrestrict whitelist") | ||||||
|  | 		} | ||||||
| 		if err == nil { | 		if err == nil { | ||||||
|  | 			n.state = unknown | ||||||
| 			net.nodes[n.ID] = n | 			net.nodes[n.ID] = n | ||||||
| 		} | 		} | ||||||
| 		return n, err | 		return n, err | ||||||
| @@ -1095,7 +1105,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) | |||||||
| 		net.conn.sendNeighbours(n, results) | 		net.conn.sendNeighbours(n, results) | ||||||
| 		return n.state, nil | 		return n.state, nil | ||||||
| 	case neighborsPacket: | 	case neighborsPacket: | ||||||
| 		err := net.handleNeighboursPacket(n, pkt.data.(*neighbors)) | 		err := net.handleNeighboursPacket(n, pkt) | ||||||
| 		return n.state, err | 		return n.state, err | ||||||
| 	case neighboursTimeout: | 	case neighboursTimeout: | ||||||
| 		if n.pendingNeighbours != nil { | 		if n.pendingNeighbours != nil { | ||||||
| @@ -1182,17 +1192,18 @@ func rlpHash(x interface{}) (h common.Hash) { | |||||||
| 	return h | 	return h | ||||||
| } | } | ||||||
|  |  | ||||||
| func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error { | func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { | ||||||
| 	if n.pendingNeighbours == nil { | 	if n.pendingNeighbours == nil { | ||||||
| 		return errNoQuery | 		return errNoQuery | ||||||
| 	} | 	} | ||||||
| 	net.abortTimedEvent(n, neighboursTimeout) | 	net.abortTimedEvent(n, neighboursTimeout) | ||||||
|  |  | ||||||
|  | 	req := pkt.data.(*neighbors) | ||||||
| 	nodes := make([]*Node, len(req.Nodes)) | 	nodes := make([]*Node, len(req.Nodes)) | ||||||
| 	for i, rn := range req.Nodes { | 	for i, rn := range req.Nodes { | ||||||
| 		nn, err := net.internNodeFromNeighbours(rn) | 		nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err) | 			glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		nodes[i] = nn | 		nodes[i] = nn | ||||||
|   | |||||||
| @@ -28,7 +28,7 @@ import ( | |||||||
|  |  | ||||||
| func TestNetwork_Lookup(t *testing.T) { | func TestNetwork_Lookup(t *testing.T) { | ||||||
| 	key, _ := crypto.GenerateKey() | 	key, _ := crypto.GenerateKey() | ||||||
| 	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "") | 	network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) { | |||||||
| 	// 	t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) | 	// 	t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) | ||||||
| 	// } | 	// } | ||||||
| 	// seed table with initial node (otherwise lookup will terminate immediately) | 	// seed table with initial node (otherwise lookup will terminate immediately) | ||||||
| 	seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)} | 	seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)} | ||||||
| 	if err := network.SetFallbackNodes(seeds); err != nil { | 	if err := network.SetFallbackNodes(seeds); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| @@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) { | |||||||
| func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) { | func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) { | ||||||
| 	// current log distance is encoded in port number | 	// current log distance is encoded in port number | ||||||
| 	// fmt.Println("findnode query at dist", toaddr.Port) | 	// fmt.Println("findnode query at dist", toaddr.Port) | ||||||
| 	if to.UDP == 0 { | 	if to.UDP <= lowPort { | ||||||
| 		panic("query to node at distance 0") | 		panic("query to node at or below distance 0") | ||||||
| 	} | 	} | ||||||
| 	next := to.UDP - 1 | 	next := to.UDP - 1 | ||||||
| 	var result []rpcNode | 	var result []rpcNode | ||||||
| 	for i, id := range tn.dists[to.UDP] { | 	for i, id := range tn.dists[to.UDP-lowPort] { | ||||||
| 		result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) | 		result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort))) | ||||||
| 	} | 	} | ||||||
| 	injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) | 	injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) | ||||||
| } | } | ||||||
| @@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha | |||||||
| 		// ignored | 		// ignored | ||||||
| 	case findnodeHashPacket: | 	case findnodeHashPacket: | ||||||
| 		// current log distance is encoded in port number | 		// current log distance is encoded in port number | ||||||
| 		// fmt.Println("findnode query at dist", toaddr.Port) | 		// fmt.Println("findnode query at dist", toaddr.Port-lowPort) | ||||||
| 		if to.UDP == 0 { | 		if to.UDP <= lowPort { | ||||||
| 			panic("query to node at distance 0") | 			panic("query to node at or below  distance 0") | ||||||
| 		} | 		} | ||||||
| 		next := to.UDP - 1 | 		next := to.UDP - 1 | ||||||
| 		var result []rpcNode | 		var result []rpcNode | ||||||
| 		for i, id := range tn.dists[to.UDP] { | 		for i, id := range tn.dists[to.UDP-lowPort] { | ||||||
| 			result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) | 			result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort))) | ||||||
| 		} | 		} | ||||||
| 		injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) | 		injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) | ||||||
| 	default: | 	default: | ||||||
| @@ -328,8 +328,11 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int, | |||||||
| 	panic("sendTopicRegister called") | 	panic("sendTopicRegister called") | ||||||
| } | } | ||||||
|  |  | ||||||
| func (*preminedTestnet) Close()                  {} | func (*preminedTestnet) Close() {} | ||||||
| func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) } |  | ||||||
|  | func (*preminedTestnet) localAddr() *net.UDPAddr { | ||||||
|  | 	return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000} | ||||||
|  | } | ||||||
|  |  | ||||||
| // mine generates a testnet struct literal with nodes at | // mine generates a testnet struct literal with nodes at | ||||||
| // various distances to the given target. | // various distances to the given target. | ||||||
|   | |||||||
| @@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network { | |||||||
| 	addr := &net.UDPAddr{IP: ip, Port: 30303} | 	addr := &net.UDPAddr{IP: ip, Port: 30303} | ||||||
|  |  | ||||||
| 	transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} | 	transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} | ||||||
| 	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>") | 	net, err := newNetwork(transport, key.PublicKey, nil, "<no database>", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		panic("cannot launch new node: " + err.Error()) | 		panic("cannot launch new node: " + err.Error()) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -29,6 +29,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| 	"github.com/ethereum/go-ethereum/rlp" | 	"github.com/ethereum/go-ethereum/rlp" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool { | |||||||
| 	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP) | 	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP) | ||||||
| } | } | ||||||
|  |  | ||||||
| func nodeFromRPC(rn rpcNode) (*Node, error) { | func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { | ||||||
| 	// TODO: don't accept localhost, LAN addresses from internet hosts | 	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
| 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) | 	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) | ||||||
| 	err := n.validateComplete() | 	err := n.validateComplete() | ||||||
| 	return n, err | 	return n, err | ||||||
| @@ -235,12 +238,12 @@ type udp struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // ListenUDP returns a new table that listens for UDP packets on laddr. | // ListenUDP returns a new table that listens for UDP packets on laddr. | ||||||
| func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) { | func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { | ||||||
| 	transport, err := listenUDP(priv, laddr) | 	transport, err := listenUDP(priv, laddr) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath) | 	net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| @@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	for i, result := range nodes { | 	for i, result := range nodes { | ||||||
|  | 		if netutil.CheckRelayIP(remote.IP, result.IP) != nil { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
| 		p.Nodes = append(p.Nodes, nodeToRPC(result)) | 		p.Nodes = append(p.Nodes, nodeToRPC(result)) | ||||||
| 		if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { | 		if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { | ||||||
| 			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) | 			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) | ||||||
| @@ -385,7 +391,7 @@ func (t *udp) readLoop() { | |||||||
| 	buf := make([]byte, 1280) | 	buf := make([]byte, 1280) | ||||||
| 	for { | 	for { | ||||||
| 		nbytes, from, err := t.conn.ReadFromUDP(buf) | 		nbytes, from, err := t.conn.ReadFromUDP(buf) | ||||||
| 		if isTemporaryError(err) { | 		if netutil.IsTemporaryError(err) { | ||||||
| 			// Ignore temporary read errors. | 			// Ignore temporary read errors. | ||||||
| 			glog.V(logger.Debug).Infof("Temporary read error: %v", err) | 			glog.V(logger.Debug).Infof("Temporary read error: %v", err) | ||||||
| 			continue | 			continue | ||||||
| @@ -398,13 +404,6 @@ func (t *udp) readLoop() { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func isTemporaryError(err error) bool { |  | ||||||
| 	tempErr, ok := err.(interface { |  | ||||||
| 		Temporary() bool |  | ||||||
| 	}) |  | ||||||
| 	return ok && tempErr.Temporary() || isPacketTooBig(err) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { | func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { | ||||||
| 	pkt := ingressPacket{remoteAddr: from} | 	pkt := ingressPacket{remoteAddr: from} | ||||||
| 	if err := decodePacket(buf, &pkt); err != nil { | 	if err := decodePacket(buf, &pkt); err != nil { | ||||||
|   | |||||||
| @@ -36,56 +36,6 @@ func init() { | |||||||
| 	spew.Config.DisableMethods = true | 	spew.Config.DisableMethods = true | ||||||
| } | } | ||||||
|  |  | ||||||
| // This test checks that isPacketTooBig correctly identifies |  | ||||||
| // errors that result from receiving a UDP packet larger |  | ||||||
| // than the supplied receive buffer. |  | ||||||
| func TestIsPacketTooBig(t *testing.T) { |  | ||||||
| 	listener, err := net.ListenPacket("udp", "127.0.0.1:0") |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer listener.Close() |  | ||||||
| 	sender, err := net.Dial("udp", listener.LocalAddr().String()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer sender.Close() |  | ||||||
|  |  | ||||||
| 	sendN := 1800 |  | ||||||
| 	recvN := 300 |  | ||||||
| 	for i := 0; i < 20; i++ { |  | ||||||
| 		go func() { |  | ||||||
| 			buf := make([]byte, sendN) |  | ||||||
| 			for i := range buf { |  | ||||||
| 				buf[i] = byte(i) |  | ||||||
| 			} |  | ||||||
| 			sender.Write(buf) |  | ||||||
| 		}() |  | ||||||
|  |  | ||||||
| 		buf := make([]byte, recvN) |  | ||||||
| 		listener.SetDeadline(time.Now().Add(1 * time.Second)) |  | ||||||
| 		n, _, err := listener.ReadFrom(buf) |  | ||||||
| 		if err != nil { |  | ||||||
| 			if nerr, ok := err.(net.Error); ok && nerr.Timeout() { |  | ||||||
| 				continue |  | ||||||
| 			} |  | ||||||
| 			if !isPacketTooBig(err) { |  | ||||||
| 				t.Fatal("unexpected read error:", spew.Sdump(err)) |  | ||||||
| 			} |  | ||||||
| 			continue |  | ||||||
| 		} |  | ||||||
| 		if n != recvN { |  | ||||||
| 			t.Fatalf("short read: %d, want %d", n, recvN) |  | ||||||
| 		} |  | ||||||
| 		for i := range buf { |  | ||||||
| 			if buf[i] != byte(i) { |  | ||||||
| 				t.Fatalf("error in pattern") |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // shared test variables | // shared test variables | ||||||
| var ( | var ( | ||||||
| 	futureExp          = uint64(time.Now().Add(10 * time.Hour).Unix()) | 	futureExp          = uint64(time.Now().Add(10 * time.Hour).Unix()) | ||||||
|   | |||||||
| @@ -1,40 +0,0 @@ | |||||||
| // Copyright 2016 The go-ethereum Authors |  | ||||||
| // This file is part of the go-ethereum library. |  | ||||||
| // |  | ||||||
| // The go-ethereum library is free software: you can redistribute it and/or modify |  | ||||||
| // it under the terms of the GNU Lesser General Public License as published by |  | ||||||
| // the Free Software Foundation, either version 3 of the License, or |  | ||||||
| // (at your option) any later version. |  | ||||||
| // |  | ||||||
| // The go-ethereum library is distributed in the hope that it will be useful, |  | ||||||
| // but WITHOUT ANY WARRANTY; without even the implied warranty of |  | ||||||
| // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |  | ||||||
| // GNU Lesser General Public License for more details. |  | ||||||
| // |  | ||||||
| // You should have received a copy of the GNU Lesser General Public License |  | ||||||
| // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. |  | ||||||
|  |  | ||||||
| //+build windows |  | ||||||
|  |  | ||||||
| package discv5 |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"net" |  | ||||||
| 	"os" |  | ||||||
| 	"syscall" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const _WSAEMSGSIZE = syscall.Errno(10040) |  | ||||||
|  |  | ||||||
| // reports whether err indicates that a UDP packet didn't |  | ||||||
| // fit the receive buffer. On Windows, WSARecvFrom returns |  | ||||||
| // code WSAEMSGSIZE and no data if this happens. |  | ||||||
| func isPacketTooBig(err error) bool { |  | ||||||
| 	if opErr, ok := err.(*net.OpError); ok { |  | ||||||
| 		if scErr, ok := opErr.Err.(*os.SyscallError); ok { |  | ||||||
| 			return scErr.Err == _WSAEMSGSIZE |  | ||||||
| 		} |  | ||||||
| 		return opErr.Err == _WSAEMSGSIZE |  | ||||||
| 	} |  | ||||||
| 	return false |  | ||||||
| } |  | ||||||
| @@ -14,13 +14,12 @@ | |||||||
| // You should have received a copy of the GNU Lesser General Public License | // You should have received a copy of the GNU Lesser General Public License | ||||||
| // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | ||||||
| 
 | 
 | ||||||
| //+build !windows | package netutil | ||||||
| 
 | 
 | ||||||
| package discv5 | // IsTemporaryError checks whether the given error should be considered temporary. | ||||||
| 
 | func IsTemporaryError(err error) bool { | ||||||
| // reports whether err indicates that a UDP packet didn't | 	tempErr, ok := err.(interface { | ||||||
| // fit the receive buffer. There is no such error on | 		Temporary() bool | ||||||
| // non-Windows platforms. | 	}) | ||||||
| func isPacketTooBig(err error) bool { | 	return ok && tempErr.Temporary() || isPacketTooBig(err) | ||||||
| 	return false |  | ||||||
| } | } | ||||||
							
								
								
									
										73
									
								
								p2p/netutil/error_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										73
									
								
								p2p/netutil/error_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,73 @@ | |||||||
|  | // Copyright 2016 The go-ethereum Authors | ||||||
|  | // This file is part of the go-ethereum library. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Lesser General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||||
|  | // GNU Lesser General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Lesser General Public License | ||||||
|  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | ||||||
|  |  | ||||||
|  | package netutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // This test checks that isPacketTooBig correctly identifies | ||||||
|  | // errors that result from receiving a UDP packet larger | ||||||
|  | // than the supplied receive buffer. | ||||||
|  | func TestIsPacketTooBig(t *testing.T) { | ||||||
|  | 	listener, err := net.ListenPacket("udp", "127.0.0.1:0") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer listener.Close() | ||||||
|  | 	sender, err := net.Dial("udp", listener.LocalAddr().String()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer sender.Close() | ||||||
|  |  | ||||||
|  | 	sendN := 1800 | ||||||
|  | 	recvN := 300 | ||||||
|  | 	for i := 0; i < 20; i++ { | ||||||
|  | 		go func() { | ||||||
|  | 			buf := make([]byte, sendN) | ||||||
|  | 			for i := range buf { | ||||||
|  | 				buf[i] = byte(i) | ||||||
|  | 			} | ||||||
|  | 			sender.Write(buf) | ||||||
|  | 		}() | ||||||
|  |  | ||||||
|  | 		buf := make([]byte, recvN) | ||||||
|  | 		listener.SetDeadline(time.Now().Add(1 * time.Second)) | ||||||
|  | 		n, _, err := listener.ReadFrom(buf) | ||||||
|  | 		if err != nil { | ||||||
|  | 			if nerr, ok := err.(net.Error); ok && nerr.Timeout() { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 			if !isPacketTooBig(err) { | ||||||
|  | 				t.Fatalf("unexpected read error: %v", err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if n != recvN { | ||||||
|  | 			t.Fatalf("short read: %d, want %d", n, recvN) | ||||||
|  | 		} | ||||||
|  | 		for i := range buf { | ||||||
|  | 			if buf[i] != byte(i) { | ||||||
|  | 				t.Fatalf("error in pattern") | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										166
									
								
								p2p/netutil/net.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								p2p/netutil/net.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,166 @@ | |||||||
|  | // Copyright 2016 The go-ethereum Authors | ||||||
|  | // This file is part of the go-ethereum library. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Lesser General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||||
|  | // GNU Lesser General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Lesser General Public License | ||||||
|  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | ||||||
|  |  | ||||||
|  | // Package netutil contains extensions to the net package. | ||||||
|  | package netutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"net" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var lan4, lan6, special4, special6 Netlist | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	// Lists from RFC 5735, RFC 5156, | ||||||
|  | 	// https://www.iana.org/assignments/iana-ipv4-special-registry/ | ||||||
|  | 	lan4.Add("0.0.0.0/8")              // "This" network | ||||||
|  | 	lan4.Add("10.0.0.0/8")             // Private Use | ||||||
|  | 	lan4.Add("172.16.0.0/12")          // Private Use | ||||||
|  | 	lan4.Add("192.168.0.0/16")         // Private Use | ||||||
|  | 	lan6.Add("fe80::/10")              // Link-Local | ||||||
|  | 	lan6.Add("fc00::/7")               // Unique-Local | ||||||
|  | 	special4.Add("192.0.0.0/29")       // IPv4 Service Continuity | ||||||
|  | 	special4.Add("192.0.0.9/32")       // PCP Anycast | ||||||
|  | 	special4.Add("192.0.0.170/32")     // NAT64/DNS64 Discovery | ||||||
|  | 	special4.Add("192.0.0.171/32")     // NAT64/DNS64 Discovery | ||||||
|  | 	special4.Add("192.0.2.0/24")       // TEST-NET-1 | ||||||
|  | 	special4.Add("192.31.196.0/24")    // AS112 | ||||||
|  | 	special4.Add("192.52.193.0/24")    // AMT | ||||||
|  | 	special4.Add("192.88.99.0/24")     // 6to4 Relay Anycast | ||||||
|  | 	special4.Add("192.175.48.0/24")    // AS112 | ||||||
|  | 	special4.Add("198.18.0.0/15")      // Device Benchmark Testing | ||||||
|  | 	special4.Add("198.51.100.0/24")    // TEST-NET-2 | ||||||
|  | 	special4.Add("203.0.113.0/24")     // TEST-NET-3 | ||||||
|  | 	special4.Add("255.255.255.255/32") // Limited Broadcast | ||||||
|  |  | ||||||
|  | 	// http://www.iana.org/assignments/iana-ipv6-special-registry/ | ||||||
|  | 	special6.Add("100::/64") | ||||||
|  | 	special6.Add("2001::/32") | ||||||
|  | 	special6.Add("2001:1::1/128") | ||||||
|  | 	special6.Add("2001:2::/48") | ||||||
|  | 	special6.Add("2001:3::/32") | ||||||
|  | 	special6.Add("2001:4:112::/48") | ||||||
|  | 	special6.Add("2001:5::/32") | ||||||
|  | 	special6.Add("2001:10::/28") | ||||||
|  | 	special6.Add("2001:20::/28") | ||||||
|  | 	special6.Add("2001:db8::/32") | ||||||
|  | 	special6.Add("2002::/16") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Netlist is a list of IP networks. | ||||||
|  | type Netlist []net.IPNet | ||||||
|  |  | ||||||
|  | // ParseNetlist parses a comma-separated list of CIDR masks. | ||||||
|  | // Whitespace and extra commas are ignored. | ||||||
|  | func ParseNetlist(s string) (*Netlist, error) { | ||||||
|  | 	ws := strings.NewReplacer(" ", "", "\n", "", "\t", "") | ||||||
|  | 	masks := strings.Split(ws.Replace(s), ",") | ||||||
|  | 	l := make(Netlist, 0) | ||||||
|  | 	for _, mask := range masks { | ||||||
|  | 		if mask == "" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		_, n, err := net.ParseCIDR(mask) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		l = append(l, *n) | ||||||
|  | 	} | ||||||
|  | 	return &l, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is | ||||||
|  | // intended to be used for setting up static lists. | ||||||
|  | func (l *Netlist) Add(cidr string) { | ||||||
|  | 	_, n, err := net.ParseCIDR(cidr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	*l = append(*l, *n) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Contains reports whether the given IP is contained in the list. | ||||||
|  | func (l *Netlist) Contains(ip net.IP) bool { | ||||||
|  | 	if l == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  | 	for _, net := range *l { | ||||||
|  | 		if net.Contains(ip) { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // IsLAN reports whether an IP is a local network address. | ||||||
|  | func IsLAN(ip net.IP) bool { | ||||||
|  | 	if ip.IsLoopback() { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if v4 := ip.To4(); v4 != nil { | ||||||
|  | 		return lan4.Contains(v4) | ||||||
|  | 	} | ||||||
|  | 	return lan6.Contains(ip) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // IsSpecialNetwork reports whether an IP is located in a special-use network range | ||||||
|  | // This includes broadcast, multicast and documentation addresses. | ||||||
|  | func IsSpecialNetwork(ip net.IP) bool { | ||||||
|  | 	if ip.IsMulticast() { | ||||||
|  | 		return true | ||||||
|  | 	} | ||||||
|  | 	if v4 := ip.To4(); v4 != nil { | ||||||
|  | 		return special4.Contains(v4) | ||||||
|  | 	} | ||||||
|  | 	return special6.Contains(ip) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	errInvalid     = errors.New("invalid IP") | ||||||
|  | 	errUnspecified = errors.New("zero address") | ||||||
|  | 	errSpecial     = errors.New("special network") | ||||||
|  | 	errLoopback    = errors.New("loopback address from non-loopback host") | ||||||
|  | 	errLAN         = errors.New("LAN address from WAN host") | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // CheckRelayIP reports whether an IP relayed from the given sender IP | ||||||
|  | // is a valid connection target. | ||||||
|  | // | ||||||
|  | // There are four rules: | ||||||
|  | //   - Special network addresses are never valid. | ||||||
|  | //   - Loopback addresses are OK if relayed by a loopback host. | ||||||
|  | //   - LAN addresses are OK if relayed by a LAN host. | ||||||
|  | //   - All other addresses are always acceptable. | ||||||
|  | func CheckRelayIP(sender, addr net.IP) error { | ||||||
|  | 	if len(addr) != net.IPv4len && len(addr) != net.IPv6len { | ||||||
|  | 		return errInvalid | ||||||
|  | 	} | ||||||
|  | 	if addr.IsUnspecified() { | ||||||
|  | 		return errUnspecified | ||||||
|  | 	} | ||||||
|  | 	if IsSpecialNetwork(addr) { | ||||||
|  | 		return errSpecial | ||||||
|  | 	} | ||||||
|  | 	if addr.IsLoopback() && !sender.IsLoopback() { | ||||||
|  | 		return errLoopback | ||||||
|  | 	} | ||||||
|  | 	if IsLAN(addr) && !IsLAN(sender) { | ||||||
|  | 		return errLAN | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										173
									
								
								p2p/netutil/net_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								p2p/netutil/net_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,173 @@ | |||||||
|  | // Copyright 2016 The go-ethereum Authors | ||||||
|  | // This file is part of the go-ethereum library. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is free software: you can redistribute it and/or modify | ||||||
|  | // it under the terms of the GNU Lesser General Public License as published by | ||||||
|  | // the Free Software Foundation, either version 3 of the License, or | ||||||
|  | // (at your option) any later version. | ||||||
|  | // | ||||||
|  | // The go-ethereum library is distributed in the hope that it will be useful, | ||||||
|  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||||
|  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||||
|  | // GNU Lesser General Public License for more details. | ||||||
|  | // | ||||||
|  | // You should have received a copy of the GNU Lesser General Public License | ||||||
|  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | ||||||
|  |  | ||||||
|  | package netutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/davecgh/go-spew/spew" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestParseNetlist(t *testing.T) { | ||||||
|  | 	var tests = []struct { | ||||||
|  | 		input    string | ||||||
|  | 		wantErr  error | ||||||
|  | 		wantList *Netlist | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			input:    "", | ||||||
|  | 			wantList: &Netlist{}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:    "127.0.0.0/8", | ||||||
|  | 			wantErr:  nil, | ||||||
|  | 			wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input:   "127.0.0.0/44", | ||||||
|  | 			wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			input: "127.0.0.0/16, 23.23.23.23/24,", | ||||||
|  | 			wantList: &Netlist{ | ||||||
|  | 				{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)}, | ||||||
|  | 				{IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		l, err := ParseNetlist(test.input) | ||||||
|  | 		if !reflect.DeepEqual(err, test.wantErr) { | ||||||
|  | 			t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		if !reflect.DeepEqual(l, test.wantList) { | ||||||
|  | 			spew.Dump(l) | ||||||
|  | 			spew.Dump(test.wantList) | ||||||
|  | 			t.Errorf("%q: got %v, want %v", test.input, l, test.wantList) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNilNetListContains(t *testing.T) { | ||||||
|  | 	var list *Netlist | ||||||
|  | 	checkContains(t, list.Contains, nil, []string{"1.2.3.4"}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIsLAN(t *testing.T) { | ||||||
|  | 	checkContains(t, IsLAN, | ||||||
|  | 		[]string{ // included | ||||||
|  | 			"0.0.0.0", | ||||||
|  | 			"0.2.0.8", | ||||||
|  | 			"127.0.0.1", | ||||||
|  | 			"10.0.1.1", | ||||||
|  | 			"10.22.0.3", | ||||||
|  | 			"172.31.252.251", | ||||||
|  | 			"192.168.1.4", | ||||||
|  | 			"fe80::f4a1:8eff:fec5:9d9d", | ||||||
|  | 			"febf::ab32:2233", | ||||||
|  | 			"fc00::4", | ||||||
|  | 		}, | ||||||
|  | 		[]string{ // excluded | ||||||
|  | 			"192.0.2.1", | ||||||
|  | 			"1.0.0.0", | ||||||
|  | 			"172.32.0.1", | ||||||
|  | 			"fec0::2233", | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestIsSpecialNetwork(t *testing.T) { | ||||||
|  | 	checkContains(t, IsSpecialNetwork, | ||||||
|  | 		[]string{ // included | ||||||
|  | 			"192.0.2.1", | ||||||
|  | 			"192.0.2.44", | ||||||
|  | 			"2001:db8:85a3:8d3:1319:8a2e:370:7348", | ||||||
|  | 			"255.255.255.255", | ||||||
|  | 			"224.0.0.22", // IPv4 multicast | ||||||
|  | 			"ff05::1:3",  // IPv6 multicast | ||||||
|  | 		}, | ||||||
|  | 		[]string{ // excluded | ||||||
|  | 			"192.0.3.1", | ||||||
|  | 			"1.0.0.0", | ||||||
|  | 			"172.32.0.1", | ||||||
|  | 			"fec0::2233", | ||||||
|  | 		}, | ||||||
|  | 	) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) { | ||||||
|  | 	for _, s := range inc { | ||||||
|  | 		if !fn(parseIP(s)) { | ||||||
|  | 			t.Error("returned false for included address", s) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	for _, s := range exc { | ||||||
|  | 		if fn(parseIP(s)) { | ||||||
|  | 			t.Error("returned true for excluded address", s) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseIP(s string) net.IP { | ||||||
|  | 	ip := net.ParseIP(s) | ||||||
|  | 	if ip == nil { | ||||||
|  | 		panic("invalid " + s) | ||||||
|  | 	} | ||||||
|  | 	return ip | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestCheckRelayIP(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		sender, addr string | ||||||
|  | 		want         error | ||||||
|  | 	}{ | ||||||
|  | 		{"127.0.0.1", "0.0.0.0", errUnspecified}, | ||||||
|  | 		{"192.168.0.1", "0.0.0.0", errUnspecified}, | ||||||
|  | 		{"23.55.1.242", "0.0.0.0", errUnspecified}, | ||||||
|  | 		{"127.0.0.1", "255.255.255.255", errSpecial}, | ||||||
|  | 		{"192.168.0.1", "255.255.255.255", errSpecial}, | ||||||
|  | 		{"23.55.1.242", "255.255.255.255", errSpecial}, | ||||||
|  | 		{"192.168.0.1", "127.0.2.19", errLoopback}, | ||||||
|  | 		{"23.55.1.242", "192.168.0.1", errLAN}, | ||||||
|  |  | ||||||
|  | 		{"127.0.0.1", "127.0.2.19", nil}, | ||||||
|  | 		{"127.0.0.1", "192.168.0.1", nil}, | ||||||
|  | 		{"127.0.0.1", "23.55.1.242", nil}, | ||||||
|  | 		{"192.168.0.1", "192.168.0.1", nil}, | ||||||
|  | 		{"192.168.0.1", "23.55.1.242", nil}, | ||||||
|  | 		{"23.55.1.242", "23.55.1.242", nil}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, test := range tests { | ||||||
|  | 		err := CheckRelayIP(parseIP(test.sender), parseIP(test.addr)) | ||||||
|  | 		if err != test.want { | ||||||
|  | 			t.Errorf("%s from %s: got %q, want %q", test.addr, test.sender, err, test.want) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkCheckRelayIP(b *testing.B) { | ||||||
|  | 	sender := parseIP("23.55.1.242") | ||||||
|  | 	addr := parseIP("23.55.1.2") | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		CheckRelayIP(sender, addr) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -16,9 +16,9 @@ | |||||||
| 
 | 
 | ||||||
| //+build !windows | //+build !windows | ||||||
| 
 | 
 | ||||||
| package discover | package netutil | ||||||
| 
 | 
 | ||||||
| // reports whether err indicates that a UDP packet didn't | // isPacketTooBig reports whether err indicates that a UDP packet didn't | ||||||
| // fit the receive buffer. There is no such error on | // fit the receive buffer. There is no such error on | ||||||
| // non-Windows platforms. | // non-Windows platforms. | ||||||
| func isPacketTooBig(err error) bool { | func isPacketTooBig(err error) bool { | ||||||
| @@ -16,7 +16,7 @@ | |||||||
| 
 | 
 | ||||||
| //+build windows | //+build windows | ||||||
| 
 | 
 | ||||||
| package discover | package netutil | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"net" | 	"net" | ||||||
| @@ -26,7 +26,7 @@ import ( | |||||||
| 
 | 
 | ||||||
| const _WSAEMSGSIZE = syscall.Errno(10040) | const _WSAEMSGSIZE = syscall.Errno(10040) | ||||||
| 
 | 
 | ||||||
| // reports whether err indicates that a UDP packet didn't | // isPacketTooBig reports whether err indicates that a UDP packet didn't | ||||||
| // fit the receive buffer. On Windows, WSARecvFrom returns | // fit the receive buffer. On Windows, WSARecvFrom returns | ||||||
| // code WSAEMSGSIZE and no data if this happens. | // code WSAEMSGSIZE and no data if this happens. | ||||||
| func isPacketTooBig(err error) bool { | func isPacketTooBig(err error) bool { | ||||||
| @@ -30,6 +30,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discv5" | 	"github.com/ethereum/go-ethereum/p2p/discv5" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/nat" | 	"github.com/ethereum/go-ethereum/p2p/nat" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -101,6 +102,11 @@ type Config struct { | |||||||
| 	// allowed to connect, even above the peer limit. | 	// allowed to connect, even above the peer limit. | ||||||
| 	TrustedNodes []*discover.Node | 	TrustedNodes []*discover.Node | ||||||
|  |  | ||||||
|  | 	// Connectivity can be restricted to certain IP networks. | ||||||
|  | 	// If this option is set to a non-nil value, only hosts which match one of the | ||||||
|  | 	// IP networks contained in the list are considered. | ||||||
|  | 	NetRestrict *netutil.Netlist | ||||||
|  |  | ||||||
| 	// NodeDatabase is the path to the database containing the previously seen | 	// NodeDatabase is the path to the database containing the previously seen | ||||||
| 	// live nodes in the network. | 	// live nodes in the network. | ||||||
| 	NodeDatabase string | 	NodeDatabase string | ||||||
| @@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) { | |||||||
|  |  | ||||||
| 	// node table | 	// node table | ||||||
| 	if srv.Discovery { | 	if srv.Discovery { | ||||||
| 		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) | 		ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if srv.DiscoveryV5 { | 	if srv.DiscoveryV5 { | ||||||
| 		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase) | 		ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return err | 			return err | ||||||
| 		} | 		} | ||||||
| @@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) { | |||||||
| 	if !srv.Discovery { | 	if !srv.Discovery { | ||||||
| 		dynPeers = 0 | 		dynPeers = 0 | ||||||
| 	} | 	} | ||||||
| 	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) | 	dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict) | ||||||
|  |  | ||||||
| 	// handshake | 	// handshake | ||||||
| 	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} | 	srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} | ||||||
| @@ -634,8 +640,19 @@ func (srv *Server) listenLoop() { | |||||||
| 			} | 			} | ||||||
| 			break | 			break | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		// Reject connections that do not match NetRestrict. | ||||||
|  | 		if srv.NetRestrict != nil { | ||||||
|  | 			if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { | ||||||
|  | 				glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr()) | ||||||
|  | 				fd.Close() | ||||||
|  | 				slots <- struct{}{} | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		fd = newMeteredConn(fd, true) | 		fd = newMeteredConn(fd, true) | ||||||
| 		glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) | 		glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr()) | ||||||
|  |  | ||||||
| 		// Spawn the handler. It will give the slot back when the connection | 		// Spawn the handler. It will give the slot back when the connection | ||||||
| 		// has been established. | 		// has been established. | ||||||
|   | |||||||
| @@ -26,6 +26,7 @@ import ( | |||||||
| 	"github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| 	"github.com/ethereum/go-ethereum/logger/glog" | 	"github.com/ethereum/go-ethereum/logger/glog" | ||||||
| 	"github.com/ethereum/go-ethereum/p2p/discover" | 	"github.com/ethereum/go-ethereum/p2p/discover" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p/netutil" | ||||||
| 	"github.com/ethereum/go-ethereum/swarm/network/kademlia" | 	"github.com/ethereum/go-ethereum/swarm/network/kademlia" | ||||||
| 	"github.com/ethereum/go-ethereum/swarm/storage" | 	"github.com/ethereum/go-ethereum/swarm/storage" | ||||||
| ) | ) | ||||||
| @@ -288,6 +289,10 @@ func newNodeRecord(addr *peerAddr) *kademlia.NodeRecord { | |||||||
| func (self *Hive) HandlePeersMsg(req *peersMsgData, from *peer) { | func (self *Hive) HandlePeersMsg(req *peersMsgData, from *peer) { | ||||||
| 	var nrs []*kademlia.NodeRecord | 	var nrs []*kademlia.NodeRecord | ||||||
| 	for _, p := range req.Peers { | 	for _, p := range req.Peers { | ||||||
|  | 		if err := netutil.CheckRelayIP(from.remoteAddr.IP, p.IP); err != nil { | ||||||
|  | 			glog.V(logger.Detail).Infof("invalid peer IP %v from %v: %v", from.remoteAddr.IP, p.IP, err) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
| 		nrs = append(nrs, newNodeRecord(p)) | 		nrs = append(nrs, newNodeRecord(p)) | ||||||
| 	} | 	} | ||||||
| 	self.kad.Add(nrs) | 	self.kad.Add(nrs) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user