| 
									
										
										
										
											2015-07-07 02:54:22 +02:00
										 |  |  | // Copyright 2014 The go-ethereum Authors | 
					
						
							| 
									
										
										
										
											2015-07-22 18:48:40 +02:00
										 |  |  | // This file is part of the go-ethereum library. | 
					
						
							| 
									
										
										
										
											2015-07-07 02:54:22 +02:00
										 |  |  | // | 
					
						
							| 
									
										
										
										
											2015-07-23 18:35:11 +02:00
										 |  |  | // The go-ethereum library is free software: you can redistribute it and/or modify | 
					
						
							| 
									
										
										
										
											2015-07-07 02:54:22 +02:00
										 |  |  | // it under the terms of the GNU Lesser General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							| 
									
										
										
										
											2015-07-22 18:48:40 +02:00
										 |  |  | // The go-ethereum library is distributed in the hope that it will be useful, | 
					
						
							| 
									
										
										
										
											2015-07-07 02:54:22 +02:00
										 |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							| 
									
										
										
										
											2015-07-22 18:48:40 +02:00
										 |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | 
					
						
							| 
									
										
										
										
											2015-07-07 02:54:22 +02:00
										 |  |  | // GNU Lesser General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Lesser General Public License | 
					
						
							| 
									
										
										
										
											2015-07-22 18:48:40 +02:00
										 |  |  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | 
					
						
							| 
									
										
										
										
											2015-07-07 02:54:22 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | package p2p | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 	"crypto/ecdsa" | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	"errors" | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 	"math/rand" | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | 	"net" | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	"reflect" | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | 	"testing" | 
					
						
							|  |  |  | 	"time" | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	"github.com/ethereum/go-ethereum/crypto" | 
					
						
							| 
									
										
										
										
											2015-02-27 03:06:55 +00:00
										 |  |  | 	"github.com/ethereum/go-ethereum/crypto/sha3" | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 	"github.com/ethereum/go-ethereum/p2p/discover" | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | func init() { | 
					
						
							|  |  |  | 	// glog.SetV(6) | 
					
						
							|  |  |  | 	// glog.SetToStderr(true) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | type testTransport struct { | 
					
						
							|  |  |  | 	id discover.NodeID | 
					
						
							|  |  |  | 	*rlpx | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	closeErr error | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func newTestTransport(id discover.NodeID, fd net.Conn) transport { | 
					
						
							|  |  |  | 	wrapped := newRLPX(fd).(*rlpx) | 
					
						
							|  |  |  | 	wrapped.rw = newRLPXFrameRW(fd, secrets{ | 
					
						
							|  |  |  | 		MAC:        zero16, | 
					
						
							|  |  |  | 		AES:        zero16, | 
					
						
							|  |  |  | 		IngressMAC: sha3.NewKeccak256(), | 
					
						
							|  |  |  | 		EgressMAC:  sha3.NewKeccak256(), | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	return &testTransport{id: id, rlpx: wrapped} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { | 
					
						
							|  |  |  | 	return c.id, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { | 
					
						
							|  |  |  | 	return &protoHandshake{ID: c.id, Name: "test"}, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func (c *testTransport) close(err error) { | 
					
						
							|  |  |  | 	c.rlpx.fd.Close() | 
					
						
							|  |  |  | 	c.closeErr = err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	server := &Server{ | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 		Name:         "test", | 
					
						
							|  |  |  | 		MaxPeers:     10, | 
					
						
							|  |  |  | 		ListenAddr:   "127.0.0.1:0", | 
					
						
							|  |  |  | 		PrivateKey:   newkey(), | 
					
						
							|  |  |  | 		newPeerHook:  pf, | 
					
						
							|  |  |  | 		newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	if err := server.Start(); err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("Could not start server: %v", err) | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	return server | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | func TestServerListen(t *testing.T) { | 
					
						
							|  |  |  | 	// start the test server | 
					
						
							|  |  |  | 	connected := make(chan *Peer) | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	remid := randomID() | 
					
						
							|  |  |  | 	srv := startTestServer(t, remid, func(p *Peer) { | 
					
						
							|  |  |  | 		if p.ID() != remid { | 
					
						
							|  |  |  | 			t.Error("peer func called with wrong node id") | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 		if p == nil { | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 			t.Error("peer func called with nil conn") | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 		connected <- p | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	}) | 
					
						
							|  |  |  | 	defer close(connected) | 
					
						
							|  |  |  | 	defer srv.Stop() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// dial the test server | 
					
						
							|  |  |  | 	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("could not dial: %v", err) | 
					
						
							| 
									
										
										
										
											2014-11-04 13:21:44 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	defer conn.Close() | 
					
						
							| 
									
										
										
										
											2014-11-04 13:21:44 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	select { | 
					
						
							|  |  |  | 	case peer := <-connected: | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 		if peer.LocalAddr().String() != conn.RemoteAddr().String() { | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 			t.Errorf("peer started with wrong conn: got %v, want %v", | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 				peer.LocalAddr(), conn.RemoteAddr()) | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 		peers := srv.Peers() | 
					
						
							|  |  |  | 		if !reflect.DeepEqual(peers, []*Peer{peer}) { | 
					
						
							|  |  |  | 			t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	case <-time.After(1 * time.Second): | 
					
						
							|  |  |  | 		t.Error("server did not accept within one second") | 
					
						
							| 
									
										
										
										
											2014-11-04 13:21:44 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | func TestServerDial(t *testing.T) { | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 	// run a one-shot TCP server to handle the connection. | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	listener, err := net.Listen("tcp", "127.0.0.1:0") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("could not setup listener: %v") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	defer listener.Close() | 
					
						
							|  |  |  | 	accepted := make(chan net.Conn) | 
					
						
							|  |  |  | 	go func() { | 
					
						
							|  |  |  | 		conn, err := listener.Accept() | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 			t.Error("accept error:", err) | 
					
						
							|  |  |  | 			return | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 		accepted <- conn | 
					
						
							|  |  |  | 	}() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 	// start the server | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	connected := make(chan *Peer) | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	remid := randomID() | 
					
						
							|  |  |  | 	srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	defer close(connected) | 
					
						
							|  |  |  | 	defer srv.Stop() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 	// tell the server to connect | 
					
						
							|  |  |  | 	tcpAddr := listener.Addr().(*net.TCPAddr) | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	select { | 
					
						
							|  |  |  | 	case conn := <-accepted: | 
					
						
							| 
									
										
										
										
											2015-06-09 22:26:26 +03:00
										 |  |  | 		defer conn.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-11-04 13:21:44 +01:00
										 |  |  | 		select { | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 		case peer := <-connected: | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 			if peer.ID() != remid { | 
					
						
							|  |  |  | 				t.Errorf("peer has wrong id") | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			if peer.Name() != "test" { | 
					
						
							|  |  |  | 				t.Errorf("peer has wrong name") | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 			if peer.RemoteAddr().String() != conn.LocalAddr().String() { | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 				t.Errorf("peer started with wrong conn: got %v, want %v", | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | 					peer.RemoteAddr(), conn.LocalAddr()) | 
					
						
							| 
									
										
										
										
											2014-11-04 13:21:44 +01:00
										 |  |  | 			} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 			peers := srv.Peers() | 
					
						
							|  |  |  | 			if !reflect.DeepEqual(peers, []*Peer{peer}) { | 
					
						
							|  |  |  | 				t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) | 
					
						
							|  |  |  | 			} | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 		case <-time.After(1 * time.Second): | 
					
						
							|  |  |  | 			t.Error("server did not launch peer within one second") | 
					
						
							| 
									
										
										
										
											2014-11-04 13:21:44 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2014-11-21 21:48:49 +01:00
										 |  |  | 	case <-time.After(1 * time.Second): | 
					
						
							|  |  |  | 		t.Error("server did not connect within one second") | 
					
						
							| 
									
										
										
										
											2014-10-23 16:57:54 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | // This test checks that tasks generated by dialstate are | 
					
						
							|  |  |  | // actually executed and taskdone is called for them. | 
					
						
							|  |  |  | func TestServerTaskScheduling(t *testing.T) { | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		done           = make(chan *testTask) | 
					
						
							|  |  |  | 		quit, returned = make(chan struct{}), make(chan struct{}) | 
					
						
							|  |  |  | 		tc             = 0 | 
					
						
							|  |  |  | 		tg             = taskgen{ | 
					
						
							|  |  |  | 			newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { | 
					
						
							|  |  |  | 				tc++ | 
					
						
							|  |  |  | 				return []task{&testTask{index: tc - 1}} | 
					
						
							|  |  |  | 			}, | 
					
						
							|  |  |  | 			doneFunc: func(t task) { | 
					
						
							|  |  |  | 				select { | 
					
						
							|  |  |  | 				case done <- t.(*testTask): | 
					
						
							|  |  |  | 				case <-quit: | 
					
						
							|  |  |  | 				} | 
					
						
							|  |  |  | 			}, | 
					
						
							| 
									
										
										
										
											2015-04-10 13:25:35 +02:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	) | 
					
						
							| 
									
										
										
										
											2015-04-10 13:25:35 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	// The Server in this test isn't actually running | 
					
						
							|  |  |  | 	// because we're only interested in what run does. | 
					
						
							|  |  |  | 	srv := &Server{ | 
					
						
							|  |  |  | 		MaxPeers: 10, | 
					
						
							|  |  |  | 		quit:     make(chan struct{}), | 
					
						
							|  |  |  | 		ntab:     fakeTable{}, | 
					
						
							|  |  |  | 		running:  true, | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	srv.loopWG.Add(1) | 
					
						
							|  |  |  | 	go func() { | 
					
						
							|  |  |  | 		srv.run(tg) | 
					
						
							|  |  |  | 		close(returned) | 
					
						
							|  |  |  | 	}() | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	var gotdone []*testTask | 
					
						
							|  |  |  | 	for i := 0; i < 100; i++ { | 
					
						
							|  |  |  | 		gotdone = append(gotdone, <-done) | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	for i, task := range gotdone { | 
					
						
							|  |  |  | 		if task.index != i { | 
					
						
							|  |  |  | 			t.Errorf("task %d has wrong index, got %d", i, task.index) | 
					
						
							|  |  |  | 			break | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if !task.called { | 
					
						
							|  |  |  | 			t.Errorf("task %d was not called", i) | 
					
						
							|  |  |  | 			break | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	close(quit) | 
					
						
							|  |  |  | 	srv.Stop() | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 	select { | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	case <-returned: | 
					
						
							|  |  |  | 	case <-time.After(500 * time.Millisecond): | 
					
						
							|  |  |  | 		t.Error("Server.run did not return within 500ms") | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | type taskgen struct { | 
					
						
							|  |  |  | 	newFunc  func(running int, peers map[discover.NodeID]*Peer) []task | 
					
						
							|  |  |  | 	doneFunc func(task) | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { | 
					
						
							|  |  |  | 	return tg.newFunc(running, peers) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | func (tg taskgen) taskDone(t task, now time.Time) { | 
					
						
							|  |  |  | 	tg.doneFunc(t) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | func (tg taskgen) addStatic(*discover.Node) { | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | type testTask struct { | 
					
						
							|  |  |  | 	index  int | 
					
						
							|  |  |  | 	called bool | 
					
						
							| 
									
										
										
										
											2015-05-04 13:08:42 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | func (t *testTask) Do(srv *Server) { | 
					
						
							|  |  |  | 	t.called = true | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | // This test checks that connections are disconnected | 
					
						
							|  |  |  | // just after the encryption handshake when the server is | 
					
						
							|  |  |  | // at capacity. Trusted connections should still be accepted. | 
					
						
							|  |  |  | func TestServerAtCap(t *testing.T) { | 
					
						
							|  |  |  | 	trustedID := randomID() | 
					
						
							|  |  |  | 	srv := &Server{ | 
					
						
							| 
									
										
										
										
											2015-05-04 13:59:51 +03:00
										 |  |  | 		PrivateKey:   newkey(), | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 		MaxPeers:     10, | 
					
						
							| 
									
										
										
										
											2015-05-04 13:59:51 +03:00
										 |  |  | 		NoDial:       true, | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 		TrustedNodes: []*discover.Node{{ID: trustedID}}, | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	if err := srv.Start(); err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("could not start: %v", err) | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	defer srv.Stop() | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	newconn := func(id discover.NodeID) *conn { | 
					
						
							|  |  |  | 		fd, _ := net.Pipe() | 
					
						
							|  |  |  | 		tx := newTestTransport(id, fd) | 
					
						
							|  |  |  | 		return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	// Inject a few connections to fill up the peer set. | 
					
						
							|  |  |  | 	for i := 0; i < 10; i++ { | 
					
						
							|  |  |  | 		c := newconn(randomID()) | 
					
						
							|  |  |  | 		if err := srv.checkpoint(c, srv.addpeer); err != nil { | 
					
						
							|  |  |  | 			t.Fatalf("could not add conn %d: %v", i, err) | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	// Try inserting a non-trusted connection. | 
					
						
							|  |  |  | 	c := newconn(randomID()) | 
					
						
							|  |  |  | 	if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { | 
					
						
							|  |  |  | 		t.Error("wrong error for insert:", err) | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	// Try inserting a trusted connection. | 
					
						
							|  |  |  | 	c = newconn(trustedID) | 
					
						
							|  |  |  | 	if err := srv.checkpoint(c, srv.posthandshake); err != nil { | 
					
						
							|  |  |  | 		t.Error("unexpected error for trusted conn @posthandshake:", err) | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	if !c.is(trustedConn) { | 
					
						
							|  |  |  | 		t.Error("Server did not set trusted flag") | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-04-30 12:41:27 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | func TestServerSetupConn(t *testing.T) { | 
					
						
							|  |  |  | 	id := randomID() | 
					
						
							|  |  |  | 	srvkey := newkey() | 
					
						
							|  |  |  | 	srvid := discover.PubkeyID(&srvkey.PublicKey) | 
					
						
							|  |  |  | 	tests := []struct { | 
					
						
							|  |  |  | 		dontstart bool | 
					
						
							|  |  |  | 		tt        *setupTransport | 
					
						
							|  |  |  | 		flags     connFlag | 
					
						
							|  |  |  | 		dialDest  *discover.Node | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		wantCloseErr error | 
					
						
							|  |  |  | 		wantCalls    string | 
					
						
							|  |  |  | 	}{ | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			dontstart:    true, | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: id}, | 
					
						
							|  |  |  | 			wantCalls:    "close,", | 
					
						
							|  |  |  | 			wantCloseErr: errServerStopped, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, | 
					
						
							|  |  |  | 			flags:        inboundConn, | 
					
						
							|  |  |  | 			wantCalls:    "doEncHandshake,close,", | 
					
						
							|  |  |  | 			wantCloseErr: errors.New("read error"), | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: id}, | 
					
						
							|  |  |  | 			dialDest:     &discover.Node{ID: randomID()}, | 
					
						
							|  |  |  | 			flags:        dynDialedConn, | 
					
						
							|  |  |  | 			wantCalls:    "doEncHandshake,close,", | 
					
						
							|  |  |  | 			wantCloseErr: DiscUnexpectedIdentity, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, | 
					
						
							|  |  |  | 			dialDest:     &discover.Node{ID: id}, | 
					
						
							|  |  |  | 			flags:        dynDialedConn, | 
					
						
							|  |  |  | 			wantCalls:    "doEncHandshake,doProtoHandshake,close,", | 
					
						
							|  |  |  | 			wantCloseErr: DiscUnexpectedIdentity, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, | 
					
						
							|  |  |  | 			dialDest:     &discover.Node{ID: id}, | 
					
						
							|  |  |  | 			flags:        dynDialedConn, | 
					
						
							|  |  |  | 			wantCalls:    "doEncHandshake,doProtoHandshake,close,", | 
					
						
							|  |  |  | 			wantCloseErr: errors.New("foo"), | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, | 
					
						
							|  |  |  | 			flags:        inboundConn, | 
					
						
							|  |  |  | 			wantCalls:    "doEncHandshake,close,", | 
					
						
							|  |  |  | 			wantCloseErr: DiscSelf, | 
					
						
							|  |  |  | 		}, | 
					
						
							|  |  |  | 		{ | 
					
						
							|  |  |  | 			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: id}}, | 
					
						
							|  |  |  | 			flags:        inboundConn, | 
					
						
							|  |  |  | 			wantCalls:    "doEncHandshake,doProtoHandshake,close,", | 
					
						
							|  |  |  | 			wantCloseErr: DiscUselessPeer, | 
					
						
							|  |  |  | 		}, | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	for i, test := range tests { | 
					
						
							|  |  |  | 		srv := &Server{ | 
					
						
							|  |  |  | 			PrivateKey:   srvkey, | 
					
						
							|  |  |  | 			MaxPeers:     10, | 
					
						
							|  |  |  | 			NoDial:       true, | 
					
						
							|  |  |  | 			Protocols:    []Protocol{discard}, | 
					
						
							|  |  |  | 			newTransport: func(fd net.Conn) transport { return test.tt }, | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 		if !test.dontstart { | 
					
						
							|  |  |  | 			if err := srv.Start(); err != nil { | 
					
						
							|  |  |  | 				t.Fatalf("couldn't start server: %v", err) | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 			} | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 		p1, _ := net.Pipe() | 
					
						
							|  |  |  | 		srv.setupConn(p1, test.flags, test.dialDest) | 
					
						
							|  |  |  | 		if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { | 
					
						
							|  |  |  | 			t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if test.tt.calls != test.wantCalls { | 
					
						
							|  |  |  | 			t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | type setupTransport struct { | 
					
						
							|  |  |  | 	id              discover.NodeID | 
					
						
							|  |  |  | 	encHandshakeErr error | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	phs               *protoHandshake | 
					
						
							|  |  |  | 	protoHandshakeErr error | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	calls    string | 
					
						
							|  |  |  | 	closeErr error | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { | 
					
						
							|  |  |  | 	c.calls += "doEncHandshake," | 
					
						
							|  |  |  | 	return c.id, c.encHandshakeErr | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { | 
					
						
							|  |  |  | 	c.calls += "doProtoHandshake," | 
					
						
							|  |  |  | 	if c.protoHandshakeErr != nil { | 
					
						
							|  |  |  | 		return nil, c.protoHandshakeErr | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-05-16 00:38:28 +02:00
										 |  |  | 	return c.phs, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | func (c *setupTransport) close(err error) { | 
					
						
							|  |  |  | 	c.calls += "close," | 
					
						
							|  |  |  | 	c.closeErr = err | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // setupConn shouldn't write to/read from the connection. | 
					
						
							|  |  |  | func (c *setupTransport) WriteMsg(Msg) error { | 
					
						
							|  |  |  | 	panic("WriteMsg called on setupTransport") | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | func (c *setupTransport) ReadMsg() (Msg, error) { | 
					
						
							|  |  |  | 	panic("ReadMsg called on setupTransport") | 
					
						
							| 
									
										
										
										
											2015-05-04 17:35:49 +03:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2015-02-05 03:07:58 +01:00
										 |  |  | func newkey() *ecdsa.PrivateKey { | 
					
						
							|  |  |  | 	key, err := crypto.GenerateKey() | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		panic("couldn't generate key: " + err.Error()) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return key | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func randomID() (id discover.NodeID) { | 
					
						
							|  |  |  | 	for i := range id { | 
					
						
							|  |  |  | 		id[i] = byte(rand.Intn(255)) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return id | 
					
						
							|  |  |  | } |