264 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			264 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| 
								 | 
							
								// Copyright 2020 The go-ethereum Authors
							 | 
						||
| 
								 | 
							
								// This file is part of go-ethereum.
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								// go-ethereum is free software: you can redistribute it and/or modify
							 | 
						||
| 
								 | 
							
								// it under the terms of the GNU General Public License as published by
							 | 
						||
| 
								 | 
							
								// the Free Software Foundation, either version 3 of the License, or
							 | 
						||
| 
								 | 
							
								// (at your option) any later version.
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								// go-ethereum is distributed in the hope that it will be useful,
							 | 
						||
| 
								 | 
							
								// but WITHOUT ANY WARRANTY; without even the implied warranty of
							 | 
						||
| 
								 | 
							
								// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
							 | 
						||
| 
								 | 
							
								// GNU General Public License for more details.
							 | 
						||
| 
								 | 
							
								//
							 | 
						||
| 
								 | 
							
								// You should have received a copy of the GNU General Public License
							 | 
						||
| 
								 | 
							
								// along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								package v5test
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"bytes"
							 | 
						||
| 
								 | 
							
									"crypto/ecdsa"
							 | 
						||
| 
								 | 
							
									"encoding/binary"
							 | 
						||
| 
								 | 
							
									"fmt"
							 | 
						||
| 
								 | 
							
									"net"
							 | 
						||
| 
								 | 
							
									"time"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"github.com/ethereum/go-ethereum/common/mclock"
							 | 
						||
| 
								 | 
							
									"github.com/ethereum/go-ethereum/crypto"
							 | 
						||
| 
								 | 
							
									"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
							 | 
						||
| 
								 | 
							
									"github.com/ethereum/go-ethereum/p2p/enode"
							 | 
						||
| 
								 | 
							
									"github.com/ethereum/go-ethereum/p2p/enr"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// readError wraps a failure that happened while reading or decoding a
								// packet. It satisfies v5wire.Packet so that conn.read can hand it back in
								// place of a real packet, letting callers type-switch on the result.
								type readError struct {
									err error // underlying cause
								}

								// Kind returns the pseudo packet-type byte used for read errors.
								func (e *readError) Kind() byte { return 99 }

								// Name describes the pseudo packet, embedding the underlying cause.
								func (e *readError) Name() string { return fmt.Sprintf("error: %v", e.err) }

								// Error returns the text of the underlying cause.
								func (e *readError) Error() string { return e.err.Error() }

								// Unwrap exposes the underlying cause to errors.Is / errors.As.
								func (e *readError) Unwrap() error { return e.err }

								// RequestID always returns nil: a read error carries no request ID.
								func (e *readError) RequestID() []byte { return nil }

								// SetRequestID is a no-op: a read error carries no request ID.
								func (e *readError) SetRequestID([]byte) {}

								// readErrorf builds a readError whose cause is fmt.Errorf(format, args...).
								func readErrorf(format string, args ...interface{}) *readError {
									cause := fmt.Errorf(format, args...)
									return &readError{err: cause}
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// waitTime bounds how long the tests wait for a response packet; conn.read
								// uses it as the socket read deadline.
								const waitTime time.Duration = 300 * time.Millisecond
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// conn is a connection to the node under test.
							 | 
						||
| 
								 | 
							
								type conn struct {
							 | 
						||
| 
								 | 
							
									localNode  *enode.LocalNode
							 | 
						||
| 
								 | 
							
									localKey   *ecdsa.PrivateKey
							 | 
						||
| 
								 | 
							
									remote     *enode.Node
							 | 
						||
| 
								 | 
							
									remoteAddr *net.UDPAddr
							 | 
						||
| 
								 | 
							
									listeners  []net.PacketConn
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									log           logger
							 | 
						||
| 
								 | 
							
									codec         *v5wire.Codec
							 | 
						||
| 
								 | 
							
									lastRequest   v5wire.Packet
							 | 
						||
| 
								 | 
							
									lastChallenge *v5wire.Whoareyou
							 | 
						||
| 
								 | 
							
									idCounter     uint32
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// logger is the minimal logging sink conn writes to: any value with a
								// testing.T-style Logf method.
								type logger interface {
									Logf(format string, args ...interface{})
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// newConn sets up a connection to the given node.
							 | 
						||
| 
								 | 
							
								func newConn(dest *enode.Node, log logger) *conn {
							 | 
						||
| 
								 | 
							
									key, err := crypto.GenerateKey()
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										panic(err)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									db, err := enode.OpenDB("")
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										panic(err)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									ln := enode.NewLocalNode(db, key)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									return &conn{
							 | 
						||
| 
								 | 
							
										localKey:   key,
							 | 
						||
| 
								 | 
							
										localNode:  ln,
							 | 
						||
| 
								 | 
							
										remote:     dest,
							 | 
						||
| 
								 | 
							
										remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()},
							 | 
						||
| 
								 | 
							
										codec:      v5wire.NewCodec(ln, key, mclock.System{}),
							 | 
						||
| 
								 | 
							
										log:        log,
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (tc *conn) setEndpoint(c net.PacketConn) {
							 | 
						||
| 
								 | 
							
									tc.localNode.SetStaticIP(laddr(c).IP)
							 | 
						||
| 
								 | 
							
									tc.localNode.SetFallbackUDP(laddr(c).Port)
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (tc *conn) listen(ip string) net.PacketConn {
							 | 
						||
| 
								 | 
							
									l, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", ip))
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										panic(err)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									tc.listeners = append(tc.listeners, l)
							 | 
						||
| 
								 | 
							
									return l
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// close shuts down all listeners and the local node.
							 | 
						||
| 
								 | 
							
								func (tc *conn) close() {
							 | 
						||
| 
								 | 
							
									for _, l := range tc.listeners {
							 | 
						||
| 
								 | 
							
										l.Close()
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									tc.localNode.Database().Close()
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// nextReqID creates a request id.
							 | 
						||
| 
								 | 
							
								func (tc *conn) nextReqID() []byte {
							 | 
						||
| 
								 | 
							
									id := make([]byte, 4)
							 | 
						||
| 
								 | 
							
									tc.idCounter++
							 | 
						||
| 
								 | 
							
									binary.BigEndian.PutUint32(id, tc.idCounter)
							 | 
						||
| 
								 | 
							
									return id
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// reqresp performs a request/response interaction on the given connection.
							 | 
						||
| 
								 | 
							
								// The request is retried if a handshake is requested.
							 | 
						||
| 
								 | 
							
								func (tc *conn) reqresp(c net.PacketConn, req v5wire.Packet) v5wire.Packet {
							 | 
						||
| 
								 | 
							
									reqnonce := tc.write(c, req, nil)
							 | 
						||
| 
								 | 
							
									switch resp := tc.read(c).(type) {
							 | 
						||
| 
								 | 
							
									case *v5wire.Whoareyou:
							 | 
						||
| 
								 | 
							
										if resp.Nonce != reqnonce {
							 | 
						||
| 
								 | 
							
											return readErrorf("wrong nonce %x in WHOAREYOU (want %x)", resp.Nonce[:], reqnonce[:])
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										resp.Node = tc.remote
							 | 
						||
| 
								 | 
							
										tc.write(c, req, resp)
							 | 
						||
| 
								 | 
							
										return tc.read(c)
							 | 
						||
| 
								 | 
							
									default:
							 | 
						||
| 
								 | 
							
										return resp
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// findnode sends a FINDNODE request and waits for its responses.
							 | 
						||
| 
								 | 
							
								func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) {
							 | 
						||
| 
								 | 
							
									var (
							 | 
						||
| 
								 | 
							
										findnode = &v5wire.Findnode{ReqID: tc.nextReqID(), Distances: dists}
							 | 
						||
| 
								 | 
							
										reqnonce = tc.write(c, findnode, nil)
							 | 
						||
| 
								 | 
							
										first    = true
							 | 
						||
| 
								 | 
							
										total    uint8
							 | 
						||
| 
								 | 
							
										results  []*enode.Node
							 | 
						||
| 
								 | 
							
									)
							 | 
						||
| 
								 | 
							
									for n := 1; n > 0; {
							 | 
						||
| 
								 | 
							
										switch resp := tc.read(c).(type) {
							 | 
						||
| 
								 | 
							
										case *v5wire.Whoareyou:
							 | 
						||
| 
								 | 
							
											// Handle handshake.
							 | 
						||
| 
								 | 
							
											if resp.Nonce == reqnonce {
							 | 
						||
| 
								 | 
							
												resp.Node = tc.remote
							 | 
						||
| 
								 | 
							
												tc.write(c, findnode, resp)
							 | 
						||
| 
								 | 
							
											} else {
							 | 
						||
| 
								 | 
							
												return nil, fmt.Errorf("unexpected WHOAREYOU (nonce %x), waiting for NODES", resp.Nonce[:])
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
										case *v5wire.Ping:
							 | 
						||
| 
								 | 
							
											// Handle ping from remote.
							 | 
						||
| 
								 | 
							
											tc.write(c, &v5wire.Pong{
							 | 
						||
| 
								 | 
							
												ReqID:  resp.ReqID,
							 | 
						||
| 
								 | 
							
												ENRSeq: tc.localNode.Seq(),
							 | 
						||
| 
								 | 
							
											}, nil)
							 | 
						||
| 
								 | 
							
										case *v5wire.Nodes:
							 | 
						||
| 
								 | 
							
											// Got NODES! Check request ID.
							 | 
						||
| 
								 | 
							
											if !bytes.Equal(resp.ReqID, findnode.ReqID) {
							 | 
						||
| 
								 | 
							
												return nil, fmt.Errorf("NODES response has wrong request id %x", resp.ReqID)
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											// Check total count. It should be greater than one
							 | 
						||
| 
								 | 
							
											// and needs to be the same across all responses.
							 | 
						||
| 
								 | 
							
											if first {
							 | 
						||
| 
								 | 
							
												if resp.Total == 0 || resp.Total > 6 {
							 | 
						||
| 
								 | 
							
													return nil, fmt.Errorf("invalid NODES response 'total' %d (not in (0,7))", resp.Total)
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
												total = resp.Total
							 | 
						||
| 
								 | 
							
												n = int(total) - 1
							 | 
						||
| 
								 | 
							
												first = false
							 | 
						||
| 
								 | 
							
											} else {
							 | 
						||
| 
								 | 
							
												n--
							 | 
						||
| 
								 | 
							
												if resp.Total != total {
							 | 
						||
| 
								 | 
							
													return nil, fmt.Errorf("invalid NODES response 'total' %d (!= %d)", resp.Total, total)
							 | 
						||
| 
								 | 
							
												}
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											// Check nodes.
							 | 
						||
| 
								 | 
							
											nodes, err := checkRecords(resp.Nodes)
							 | 
						||
| 
								 | 
							
											if err != nil {
							 | 
						||
| 
								 | 
							
												return nil, fmt.Errorf("invalid node in NODES response: %v", err)
							 | 
						||
| 
								 | 
							
											}
							 | 
						||
| 
								 | 
							
											results = append(results, nodes...)
							 | 
						||
| 
								 | 
							
										default:
							 | 
						||
| 
								 | 
							
											return nil, fmt.Errorf("expected NODES, got %v", resp)
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return results, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// write sends a packet on the given connection.
							 | 
						||
| 
								 | 
							
								func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou) v5wire.Nonce {
							 | 
						||
| 
								 | 
							
									packet, nonce, err := tc.codec.Encode(tc.remote.ID(), tc.remoteAddr.String(), p, challenge)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										panic(fmt.Errorf("can't encode %v packet: %v", p.Name(), err))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									if _, err := c.WriteTo(packet, tc.remoteAddr); err != nil {
							 | 
						||
| 
								 | 
							
										tc.logf("Can't send %s: %v", p.Name(), err)
							 | 
						||
| 
								 | 
							
									} else {
							 | 
						||
| 
								 | 
							
										tc.logf(">> %s", p.Name())
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return nonce
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// read waits for an incoming packet on the given connection.
							 | 
						||
| 
								 | 
							
								func (tc *conn) read(c net.PacketConn) v5wire.Packet {
							 | 
						||
| 
								 | 
							
									buf := make([]byte, 1280)
							 | 
						||
| 
								 | 
							
									if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
							 | 
						||
| 
								 | 
							
										return &readError{err}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									n, fromAddr, err := c.ReadFrom(buf)
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return &readError{err}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									_, _, p, err := tc.codec.Decode(buf[:n], fromAddr.String())
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return &readError{err}
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									tc.logf("<< %s", p.Name())
							 | 
						||
| 
								 | 
							
									return p
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// logf prints to the test log.
							 | 
						||
| 
								 | 
							
								func (tc *conn) logf(format string, args ...interface{}) {
							 | 
						||
| 
								 | 
							
									if tc.log != nil {
							 | 
						||
| 
								 | 
							
										tc.log.Logf("(%s) %s", tc.localNode.ID().TerminalString(), fmt.Sprintf(format, args...))
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// laddr returns the local address of c as a *net.UDPAddr. It panics if c's
								// local address is not a *net.UDPAddr, i.e. if c is not a UDP socket.
								func laddr(c net.PacketConn) *net.UDPAddr {
									addr := c.LocalAddr()
									return addr.(*net.UDPAddr)
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func checkRecords(records []*enr.Record) ([]*enode.Node, error) {
							 | 
						||
| 
								 | 
							
									nodes := make([]*enode.Node, len(records))
							 | 
						||
| 
								 | 
							
									for i := range records {
							 | 
						||
| 
								 | 
							
										n, err := enode.New(enode.ValidSchemes, records[i])
							 | 
						||
| 
								 | 
							
										if err != nil {
							 | 
						||
| 
								 | 
							
											return nil, err
							 | 
						||
| 
								 | 
							
										}
							 | 
						||
| 
								 | 
							
										nodes[i] = n
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return nodes, nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// containsUint reports whether x occurs anywhere in ints.
								func containsUint(ints []uint, x uint) bool {
									for _, v := range ints {
										if v == x {
											return true
										}
									}
									return false
								}
							 |