Merge branch 'feature/p2p-protocol-interface' of https://github.com/fjl/go-ethereum into fjl-feature/p2p-protocol-interface
This commit is contained in:
		| @@ -5,10 +5,10 @@ import ( | |||||||
| 	"runtime" | 	"runtime" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // should be used in Peer handleHandshake, incorporate Caps, ProtocolVersion, Pubkey etc. | // ClientIdentity represents the identity of a peer. | ||||||
| type ClientIdentity interface { | type ClientIdentity interface { | ||||||
| 	String() string | 	String() string // human readable identity | ||||||
| 	Pubkey() []byte | 	Pubkey() []byte // 512-bit public key | ||||||
| } | } | ||||||
|  |  | ||||||
| type SimpleClientIdentity struct { | type SimpleClientIdentity struct { | ||||||
|   | |||||||
| @@ -1,275 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	// "fmt" |  | ||||||
| 	"net" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/ethereum/go-ethereum/ethutil" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type Connection struct { |  | ||||||
| 	conn net.Conn |  | ||||||
| 	// conn       NetworkConnection |  | ||||||
| 	timeout    time.Duration |  | ||||||
| 	in         chan []byte |  | ||||||
| 	out        chan []byte |  | ||||||
| 	err        chan *PeerError |  | ||||||
| 	closingIn  chan chan bool |  | ||||||
| 	closingOut chan chan bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // const readBufferLength = 2 //for testing |  | ||||||
|  |  | ||||||
| const readBufferLength = 1440 |  | ||||||
| const partialsQueueSize = 10 |  | ||||||
| const maxPendingQueueSize = 1 |  | ||||||
| const defaultTimeout = 500 |  | ||||||
|  |  | ||||||
| var magicToken = []byte{34, 64, 8, 145} |  | ||||||
|  |  | ||||||
| func (self *Connection) Open() { |  | ||||||
| 	go self.startRead() |  | ||||||
| 	go self.startWrite() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) Close() { |  | ||||||
| 	self.closeIn() |  | ||||||
| 	self.closeOut() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) closeIn() { |  | ||||||
| 	errc := make(chan bool) |  | ||||||
| 	self.closingIn <- errc |  | ||||||
| 	<-errc |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) closeOut() { |  | ||||||
| 	errc := make(chan bool) |  | ||||||
| 	self.closingOut <- errc |  | ||||||
| 	<-errc |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewConnection(conn net.Conn, errchan chan *PeerError) *Connection { |  | ||||||
| 	return &Connection{ |  | ||||||
| 		conn:       conn, |  | ||||||
| 		timeout:    defaultTimeout, |  | ||||||
| 		in:         make(chan []byte), |  | ||||||
| 		out:        make(chan []byte), |  | ||||||
| 		err:        errchan, |  | ||||||
| 		closingIn:  make(chan chan bool, 1), |  | ||||||
| 		closingOut: make(chan chan bool, 1), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) Read() <-chan []byte { |  | ||||||
| 	return self.in |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) Write() chan<- []byte { |  | ||||||
| 	return self.out |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) Error() <-chan *PeerError { |  | ||||||
| 	return self.err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) startRead() { |  | ||||||
| 	payloads := make(chan []byte) |  | ||||||
| 	done := make(chan *PeerError) |  | ||||||
| 	pending := [][]byte{} |  | ||||||
| 	var head []byte |  | ||||||
| 	var wait time.Duration // initally 0 (no delay) |  | ||||||
| 	read := time.After(wait * time.Millisecond) |  | ||||||
|  |  | ||||||
| 	for { |  | ||||||
| 		// if pending empty, nil channel blocks |  | ||||||
| 		var in chan []byte |  | ||||||
| 		if len(pending) > 0 { |  | ||||||
| 			in = self.in // enable send case |  | ||||||
| 			head = pending[0] |  | ||||||
| 		} else { |  | ||||||
| 			in = nil |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		select { |  | ||||||
| 		case <-read: |  | ||||||
| 			go self.read(payloads, done) |  | ||||||
| 		case err := <-done: |  | ||||||
| 			if err == nil { // no error but nothing to read |  | ||||||
| 				if len(pending) < maxPendingQueueSize { |  | ||||||
| 					wait = 100 |  | ||||||
| 				} else if wait == 0 { |  | ||||||
| 					wait = 100 |  | ||||||
| 				} else { |  | ||||||
| 					wait = 2 * wait |  | ||||||
| 				} |  | ||||||
| 			} else { |  | ||||||
| 				self.err <- err // report error |  | ||||||
| 				wait = 100 |  | ||||||
| 			} |  | ||||||
| 			read = time.After(wait * time.Millisecond) |  | ||||||
| 		case payload := <-payloads: |  | ||||||
| 			pending = append(pending, payload) |  | ||||||
| 			if len(pending) < maxPendingQueueSize { |  | ||||||
| 				wait = 0 |  | ||||||
| 			} else { |  | ||||||
| 				wait = 100 |  | ||||||
| 			} |  | ||||||
| 			read = time.After(wait * time.Millisecond) |  | ||||||
| 		case in <- head: |  | ||||||
| 			pending = pending[1:] |  | ||||||
| 		case errc := <-self.closingIn: |  | ||||||
| 			errc <- true |  | ||||||
| 			close(self.in) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) startWrite() { |  | ||||||
| 	pending := [][]byte{} |  | ||||||
| 	done := make(chan *PeerError) |  | ||||||
| 	writing := false |  | ||||||
| 	for { |  | ||||||
| 		if len(pending) > 0 && !writing { |  | ||||||
| 			writing = true |  | ||||||
| 			go self.write(pending[0], done) |  | ||||||
| 		} |  | ||||||
| 		select { |  | ||||||
| 		case payload := <-self.out: |  | ||||||
| 			pending = append(pending, payload) |  | ||||||
| 		case err := <-done: |  | ||||||
| 			if err == nil { |  | ||||||
| 				pending = pending[1:] |  | ||||||
| 				writing = false |  | ||||||
| 			} else { |  | ||||||
| 				self.err <- err // report error |  | ||||||
| 			} |  | ||||||
| 		case errc := <-self.closingOut: |  | ||||||
| 			errc <- true |  | ||||||
| 			close(self.out) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func pack(payload []byte) (packet []byte) { |  | ||||||
| 	length := ethutil.NumberToBytes(uint32(len(payload)), 32) |  | ||||||
| 	// return error if too long? |  | ||||||
| 	// Write magic token and payload length (first 8 bytes) |  | ||||||
| 	packet = append(magicToken, length...) |  | ||||||
| 	packet = append(packet, payload...) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func avoidPanic(done chan *PeerError) { |  | ||||||
| 	if rec := recover(); rec != nil { |  | ||||||
| 		err := NewPeerError(MiscError, " %v", rec) |  | ||||||
| 		logger.Debugln(err) |  | ||||||
| 		done <- err |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) write(payload []byte, done chan *PeerError) { |  | ||||||
| 	defer avoidPanic(done) |  | ||||||
| 	var err *PeerError |  | ||||||
| 	_, ok := self.conn.Write(pack(payload)) |  | ||||||
| 	if ok != nil { |  | ||||||
| 		err = NewPeerError(WriteError, " %v", ok) |  | ||||||
| 		logger.Debugln(err) |  | ||||||
| 	} |  | ||||||
| 	done <- err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) read(payloads chan []byte, done chan *PeerError) { |  | ||||||
| 	//defer avoidPanic(done) |  | ||||||
|  |  | ||||||
| 	partials := make(chan []byte, partialsQueueSize) |  | ||||||
| 	errc := make(chan *PeerError) |  | ||||||
| 	go self.readPartials(partials, errc) |  | ||||||
|  |  | ||||||
| 	packet := []byte{} |  | ||||||
| 	length := 8 |  | ||||||
| 	start := true |  | ||||||
| 	var err *PeerError |  | ||||||
| out: |  | ||||||
| 	for { |  | ||||||
| 		// appends partials read via connection until packet is |  | ||||||
| 		// - either parseable (>=8bytes) |  | ||||||
| 		// - or complete (payload fully consumed) |  | ||||||
| 		for len(packet) < length { |  | ||||||
| 			partial, ok := <-partials |  | ||||||
| 			if !ok { // partials channel is closed |  | ||||||
| 				err = <-errc |  | ||||||
| 				if err == nil && len(packet) > 0 { |  | ||||||
| 					if start { |  | ||||||
| 						err = NewPeerError(PacketTooShort, "%v", packet) |  | ||||||
| 					} else { |  | ||||||
| 						err = NewPeerError(PayloadTooShort, "%d < %d", len(packet), length) |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 				break out |  | ||||||
| 			} |  | ||||||
| 			packet = append(packet, partial...) |  | ||||||
| 		} |  | ||||||
| 		if start { |  | ||||||
| 			// at least 8 bytes read, can validate packet |  | ||||||
| 			if bytes.Compare(magicToken, packet[:4]) != 0 { |  | ||||||
| 				err = NewPeerError(MagicTokenMismatch, " received %v", packet[:4]) |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 			length = int(ethutil.BytesToNumber(packet[4:8])) |  | ||||||
| 			packet = packet[8:] |  | ||||||
|  |  | ||||||
| 			if length > 0 { |  | ||||||
| 				start = false // now consuming payload |  | ||||||
| 			} else { //penalize peer but read on |  | ||||||
| 				self.err <- NewPeerError(EmptyPayload, "") |  | ||||||
| 				length = 8 |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			// packet complete (payload fully consumed) |  | ||||||
| 			payloads <- packet[:length] |  | ||||||
| 			packet = packet[length:] // resclice packet |  | ||||||
| 			start = true |  | ||||||
| 			length = 8 |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// this stops partials read via the connection, should we? |  | ||||||
| 	//if err != nil { |  | ||||||
| 	//  select { |  | ||||||
| 	//    case errc <- err |  | ||||||
| 	//  default: |  | ||||||
| 	//} |  | ||||||
| 	done <- err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Connection) readPartials(partials chan []byte, errc chan *PeerError) { |  | ||||||
| 	defer close(partials) |  | ||||||
| 	for { |  | ||||||
| 		// Give buffering some time |  | ||||||
| 		self.conn.SetReadDeadline(time.Now().Add(self.timeout * time.Millisecond)) |  | ||||||
| 		buffer := make([]byte, readBufferLength) |  | ||||||
| 		// read partial from connection |  | ||||||
| 		bytesRead, err := self.conn.Read(buffer) |  | ||||||
| 		if err == nil || err.Error() == "EOF" { |  | ||||||
| 			if bytesRead > 0 { |  | ||||||
| 				partials <- buffer[:bytesRead] |  | ||||||
| 			} |  | ||||||
| 			if err != nil && err.Error() == "EOF" { |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 		} else { |  | ||||||
| 			// unexpected error, report to errc |  | ||||||
| 			err := NewPeerError(ReadError, " %v", err) |  | ||||||
| 			logger.Debugln(err) |  | ||||||
| 			errc <- err |  | ||||||
| 			return // will close partials channel |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	close(errc) |  | ||||||
| } |  | ||||||
| @@ -1,222 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"bytes" |  | ||||||
| 	"fmt" |  | ||||||
| 	"io" |  | ||||||
| 	"net" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type TestNetworkConnection struct { |  | ||||||
| 	in      chan []byte |  | ||||||
| 	current []byte |  | ||||||
| 	Out     [][]byte |  | ||||||
| 	addr    net.Addr |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewTestNetworkConnection(addr net.Addr) *TestNetworkConnection { |  | ||||||
| 	return &TestNetworkConnection{ |  | ||||||
| 		in:      make(chan []byte), |  | ||||||
| 		current: []byte{}, |  | ||||||
| 		Out:     [][]byte{}, |  | ||||||
| 		addr:    addr, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) In(latency time.Duration, packets ...[]byte) { |  | ||||||
| 	time.Sleep(latency) |  | ||||||
| 	for _, s := range packets { |  | ||||||
| 		self.in <- s |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) Read(buff []byte) (n int, err error) { |  | ||||||
| 	if len(self.current) == 0 { |  | ||||||
| 		select { |  | ||||||
| 		case self.current = <-self.in: |  | ||||||
| 		default: |  | ||||||
| 			return 0, io.EOF |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	length := len(self.current) |  | ||||||
| 	if length > len(buff) { |  | ||||||
| 		copy(buff[:], self.current[:len(buff)]) |  | ||||||
| 		self.current = self.current[len(buff):] |  | ||||||
| 		return len(buff), nil |  | ||||||
| 	} else { |  | ||||||
| 		copy(buff[:length], self.current[:]) |  | ||||||
| 		self.current = []byte{} |  | ||||||
| 		return length, io.EOF |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) Write(buff []byte) (n int, err error) { |  | ||||||
| 	self.Out = append(self.Out, buff) |  | ||||||
| 	fmt.Printf("net write %v\n%v\n", len(self.Out), buff) |  | ||||||
| 	return len(buff), nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) Close() (err error) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) LocalAddr() (addr net.Addr) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) RemoteAddr() (addr net.Addr) { |  | ||||||
| 	return self.addr |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) SetDeadline(t time.Time) (err error) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) SetReadDeadline(t time.Time) (err error) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetworkConnection) SetWriteDeadline(t time.Time) (err error) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func setupConnection() (*Connection, *TestNetworkConnection) { |  | ||||||
| 	addr := &TestAddr{"test:30303"} |  | ||||||
| 	net := NewTestNetworkConnection(addr) |  | ||||||
| 	conn := NewConnection(net, NewPeerErrorChannel()) |  | ||||||
| 	conn.Open() |  | ||||||
| 	return conn, net |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingNilPacket(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{}) |  | ||||||
| 	// time.Sleep(10 * time.Millisecond) |  | ||||||
| 	select { |  | ||||||
| 	case packet := <-conn.Read(): |  | ||||||
| 		t.Errorf("read %v", packet) |  | ||||||
| 	case err := <-conn.Error(): |  | ||||||
| 		t.Errorf("incorrect error %v", err) |  | ||||||
| 	default: |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingShortPacket(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{0}) |  | ||||||
| 	select { |  | ||||||
| 	case packet := <-conn.Read(): |  | ||||||
| 		t.Errorf("read %v", packet) |  | ||||||
| 	case err := <-conn.Error(): |  | ||||||
| 		if err.Code != PacketTooShort { |  | ||||||
| 			t.Errorf("incorrect error %v, expected %v", err.Code, PacketTooShort) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingInvalidPacket(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{1, 0, 0, 0, 0, 0, 0, 0}) |  | ||||||
| 	select { |  | ||||||
| 	case packet := <-conn.Read(): |  | ||||||
| 		t.Errorf("read %v", packet) |  | ||||||
| 	case err := <-conn.Error(): |  | ||||||
| 		if err.Code != MagicTokenMismatch { |  | ||||||
| 			t.Errorf("incorrect error %v, expected %v", err.Code, MagicTokenMismatch) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingInvalidPayload(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 2, 0}) |  | ||||||
| 	select { |  | ||||||
| 	case packet := <-conn.Read(): |  | ||||||
| 		t.Errorf("read %v", packet) |  | ||||||
| 	case err := <-conn.Error(): |  | ||||||
| 		if err.Code != PayloadTooShort { |  | ||||||
| 			t.Errorf("incorrect error %v, expected %v", err.Code, PayloadTooShort) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingEmptyPayload(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 0}) |  | ||||||
| 	time.Sleep(10 * time.Millisecond) |  | ||||||
| 	select { |  | ||||||
| 	case packet := <-conn.Read(): |  | ||||||
| 		t.Errorf("read %v", packet) |  | ||||||
| 	default: |  | ||||||
| 	} |  | ||||||
| 	select { |  | ||||||
| 	case err := <-conn.Error(): |  | ||||||
| 		code := err.Code |  | ||||||
| 		if code != EmptyPayload { |  | ||||||
| 			t.Errorf("incorrect error, expected EmptyPayload, got %v", code) |  | ||||||
| 		} |  | ||||||
| 	default: |  | ||||||
| 		t.Errorf("no error, expected EmptyPayload") |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingCompletePacket(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 1}) |  | ||||||
| 	time.Sleep(10 * time.Millisecond) |  | ||||||
| 	select { |  | ||||||
| 	case packet := <-conn.Read(): |  | ||||||
| 		if bytes.Compare(packet, []byte{1}) != 0 { |  | ||||||
| 			t.Errorf("incorrect payload read") |  | ||||||
| 		} |  | ||||||
| 	case err := <-conn.Error(): |  | ||||||
| 		t.Errorf("incorrect error %v", err) |  | ||||||
| 	default: |  | ||||||
| 		t.Errorf("nothing read") |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestReadingTwoCompletePackets(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	go net.In(0, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0, 34, 64, 8, 145, 0, 0, 0, 1, 1}) |  | ||||||
|  |  | ||||||
| 	for i := 0; i < 2; i++ { |  | ||||||
| 		time.Sleep(10 * time.Millisecond) |  | ||||||
| 		select { |  | ||||||
| 		case packet := <-conn.Read(): |  | ||||||
| 			if bytes.Compare(packet, []byte{byte(i)}) != 0 { |  | ||||||
| 				t.Errorf("incorrect payload read") |  | ||||||
| 			} |  | ||||||
| 		case err := <-conn.Error(): |  | ||||||
| 			t.Errorf("incorrect error %v", err) |  | ||||||
| 		default: |  | ||||||
| 			t.Errorf("nothing read") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestWriting(t *testing.T) { |  | ||||||
| 	conn, net := setupConnection() |  | ||||||
| 	conn.Write() <- []byte{0} |  | ||||||
| 	time.Sleep(10 * time.Millisecond) |  | ||||||
| 	if len(net.Out) == 0 { |  | ||||||
| 		t.Errorf("no output") |  | ||||||
| 	} else { |  | ||||||
| 		out := net.Out[0] |  | ||||||
| 		if bytes.Compare(out, []byte{34, 64, 8, 145, 0, 0, 0, 1, 0}) != 0 { |  | ||||||
| 			t.Errorf("incorrect packet %v", out) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // hello packet with client id ABC: 0x22 40 08 91 00 00 00 08 84 00 00 00 43414243 |  | ||||||
							
								
								
									
										202
									
								
								p2p/message.go
									
									
									
									
									
								
							
							
						
						
									
										202
									
								
								p2p/message.go
									
									
									
									
									
								
							| @@ -1,75 +1,155 @@ | |||||||
| package p2p | package p2p | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	// "fmt" | 	"bytes" | ||||||
|  | 	"encoding/binary" | ||||||
|  | 	"io" | ||||||
|  | 	"io/ioutil" | ||||||
|  | 	"math/big" | ||||||
|  |  | ||||||
| 	"github.com/ethereum/go-ethereum/ethutil" | 	"github.com/ethereum/go-ethereum/ethutil" | ||||||
|  | 	"github.com/ethereum/go-ethereum/rlp" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type MsgCode uint8 | // Msg defines the structure of a p2p message. | ||||||
|  | // | ||||||
|  | // Note that a Msg can only be sent once since the Payload reader is | ||||||
|  | // consumed during sending. It is not possible to create a Msg and | ||||||
|  | // send it any number of times. If you want to reuse an encoded | ||||||
|  | // structure, encode the payload into a byte array and create a | ||||||
|  | // separate Msg with a bytes.Reader as Payload for each send. | ||||||
| type Msg struct { | type Msg struct { | ||||||
| 	code    MsgCode // this is the raw code as per adaptive msg code scheme | 	Code    uint64 | ||||||
| 	data    *ethutil.Value | 	Size    uint32 // size of the paylod | ||||||
| 	encoded []byte | 	Payload io.Reader | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Msg) Code() MsgCode { | // NewMsg creates an RLP-encoded message with the given code. | ||||||
| 	return self.code | func NewMsg(code uint64, params ...interface{}) Msg { | ||||||
| } | 	buf := new(bytes.Buffer) | ||||||
|  | 	for _, p := range params { | ||||||
| func (self *Msg) Data() *ethutil.Value { | 		buf.Write(ethutil.Encode(p)) | ||||||
| 	return self.data |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewMsg(code MsgCode, params ...interface{}) (msg *Msg, err error) { |  | ||||||
|  |  | ||||||
| 	// // data := [][]interface{}{} |  | ||||||
| 	// data := []interface{}{} |  | ||||||
| 	// for _, value := range params { |  | ||||||
| 	// 	if encodable, ok := value.(ethutil.RlpEncodeDecode); ok { |  | ||||||
| 	// 		data = append(data, encodable.RlpValue()) |  | ||||||
| 	// 	} else if raw, ok := value.([]interface{}); ok { |  | ||||||
| 	// 		data = append(data, raw) |  | ||||||
| 	// 	} else { |  | ||||||
| 	// 		// data = append(data, interface{}(raw)) |  | ||||||
| 	// 		err = fmt.Errorf("Unable to encode object of type %T", value) |  | ||||||
| 	// 		return |  | ||||||
| 	// 	} |  | ||||||
| 	// } |  | ||||||
| 	return &Msg{ |  | ||||||
| 		code: code, |  | ||||||
| 		data: ethutil.NewValue(interface{}(params)), |  | ||||||
| 	}, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewMsgFromBytes(encoded []byte) (msg *Msg, err error) { |  | ||||||
| 	value := ethutil.NewValueFromBytes(encoded) |  | ||||||
| 	// Type of message |  | ||||||
| 	code := value.Get(0).Uint() |  | ||||||
| 	// Actual data |  | ||||||
| 	data := value.SliceFrom(1) |  | ||||||
|  |  | ||||||
| 	msg = &Msg{ |  | ||||||
| 		code: MsgCode(code), |  | ||||||
| 		data: data, |  | ||||||
| 		// data:    ethutil.NewValue(data), |  | ||||||
| 		encoded: encoded, |  | ||||||
| 	} | 	} | ||||||
| 	return | 	return Msg{Code: code, Size: uint32(buf.Len()), Payload: buf} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Msg) Decode(offset MsgCode) { | func encodePayload(params ...interface{}) []byte { | ||||||
| 	self.code = self.code - offset | 	buf := new(bytes.Buffer) | ||||||
| } | 	for _, p := range params { | ||||||
|  | 		buf.Write(ethutil.Encode(p)) | ||||||
| // encode takes an offset argument to implement adaptive message coding |  | ||||||
| // the encoded message is memoized to make msgs relayed to several peers more efficient |  | ||||||
| func (self *Msg) Encode(offset MsgCode) (res []byte) { |  | ||||||
| 	if len(self.encoded) == 0 { |  | ||||||
| 		res = ethutil.NewValue(append([]interface{}{byte(self.code + offset)}, self.data.Slice()...)).Encode() |  | ||||||
| 		self.encoded = res |  | ||||||
| 	} else { |  | ||||||
| 		res = self.encoded |  | ||||||
| 	} | 	} | ||||||
| 	return | 	return buf.Bytes() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Decode parse the RLP content of a message into | ||||||
|  | // the given value, which must be a pointer. | ||||||
|  | // | ||||||
|  | // For the decoding rules, please see package rlp. | ||||||
|  | func (msg Msg) Decode(val interface{}) error { | ||||||
|  | 	s := rlp.NewListStream(msg.Payload, uint64(msg.Size)) | ||||||
|  | 	return s.Decode(val) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Discard reads any remaining payload data into a black hole. | ||||||
|  | func (msg Msg) Discard() error { | ||||||
|  | 	_, err := io.Copy(ioutil.Discard, msg.Payload) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type MsgReader interface { | ||||||
|  | 	ReadMsg() (Msg, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type MsgWriter interface { | ||||||
|  | 	// WriteMsg sends an existing message. | ||||||
|  | 	// The Payload reader of the message is consumed. | ||||||
|  | 	// Note that messages can be sent only once. | ||||||
|  | 	WriteMsg(Msg) error | ||||||
|  |  | ||||||
|  | 	// EncodeMsg writes an RLP-encoded message with the given | ||||||
|  | 	// code and data elements. | ||||||
|  | 	EncodeMsg(code uint64, data ...interface{}) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // MsgReadWriter provides reading and writing of encoded messages. | ||||||
|  | type MsgReadWriter interface { | ||||||
|  | 	MsgReader | ||||||
|  | 	MsgWriter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var magicToken = []byte{34, 64, 8, 145} | ||||||
|  |  | ||||||
|  | func writeMsg(w io.Writer, msg Msg) error { | ||||||
|  | 	// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 | ||||||
|  | 	code := ethutil.Encode(uint32(msg.Code)) | ||||||
|  | 	listhdr := makeListHeader(msg.Size + uint32(len(code))) | ||||||
|  | 	payloadLen := uint32(len(listhdr)) + uint32(len(code)) + msg.Size | ||||||
|  |  | ||||||
|  | 	start := make([]byte, 8) | ||||||
|  | 	copy(start, magicToken) | ||||||
|  | 	binary.BigEndian.PutUint32(start[4:], payloadLen) | ||||||
|  |  | ||||||
|  | 	for _, b := range [][]byte{start, listhdr, code} { | ||||||
|  | 		if _, err := w.Write(b); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	_, err := io.CopyN(w, msg.Payload, int64(msg.Size)) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func makeListHeader(length uint32) []byte { | ||||||
|  | 	if length < 56 { | ||||||
|  | 		return []byte{byte(length + 0xc0)} | ||||||
|  | 	} | ||||||
|  | 	enc := big.NewInt(int64(length)).Bytes() | ||||||
|  | 	lenb := byte(len(enc)) + 0xf7 | ||||||
|  | 	return append([]byte{lenb}, enc...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // readMsg reads a message header from r. | ||||||
|  | // It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. | ||||||
|  | func readMsg(r rlp.ByteReader) (msg Msg, err error) { | ||||||
|  | 	// read magic and payload size | ||||||
|  | 	start := make([]byte, 8) | ||||||
|  | 	if _, err = io.ReadFull(r, start); err != nil { | ||||||
|  | 		return msg, newPeerError(errRead, "%v", err) | ||||||
|  | 	} | ||||||
|  | 	if !bytes.HasPrefix(start, magicToken) { | ||||||
|  | 		return msg, newPeerError(errMagicTokenMismatch, "got %x, want %x", start[:4], magicToken) | ||||||
|  | 	} | ||||||
|  | 	size := binary.BigEndian.Uint32(start[4:]) | ||||||
|  |  | ||||||
|  | 	// decode start of RLP message to get the message code | ||||||
|  | 	posr := &postrack{r, 0} | ||||||
|  | 	s := rlp.NewStream(posr) | ||||||
|  | 	if _, err := s.List(); err != nil { | ||||||
|  | 		return msg, err | ||||||
|  | 	} | ||||||
|  | 	code, err := s.Uint() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return msg, err | ||||||
|  | 	} | ||||||
|  | 	payloadsize := size - posr.p | ||||||
|  | 	return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // postrack wraps an rlp.ByteReader with a position counter. | ||||||
|  | type postrack struct { | ||||||
|  | 	r rlp.ByteReader | ||||||
|  | 	p uint32 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *postrack) Read(buf []byte) (int, error) { | ||||||
|  | 	n, err := r.r.Read(buf) | ||||||
|  | 	r.p += uint32(n) | ||||||
|  | 	return n, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *postrack) ReadByte() (byte, error) { | ||||||
|  | 	b, err := r.r.ReadByte() | ||||||
|  | 	if err == nil { | ||||||
|  | 		r.p++ | ||||||
|  | 	} | ||||||
|  | 	return b, err | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,38 +1,70 @@ | |||||||
| package p2p | package p2p | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"io/ioutil" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/ethereum/go-ethereum/ethutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestNewMsg(t *testing.T) { | func TestNewMsg(t *testing.T) { | ||||||
| 	msg, _ := NewMsg(3, 1, "000") | 	msg := NewMsg(3, 1, "000") | ||||||
| 	if msg.Code() != 3 { | 	if msg.Code != 3 { | ||||||
| 		t.Errorf("incorrect code %v", msg.Code()) | 		t.Errorf("incorrect code %d, want %d", msg.Code) | ||||||
| 	} | 	} | ||||||
| 	data0 := msg.Data().Get(0).Uint() | 	if msg.Size != 5 { | ||||||
| 	data1 := string(msg.Data().Get(1).Bytes()) | 		t.Errorf("incorrect size %d, want %d", msg.Size, 5) | ||||||
| 	if data0 != 1 { |  | ||||||
| 		t.Errorf("incorrect data %v", data0) |  | ||||||
| 	} | 	} | ||||||
| 	if data1 != "000" { | 	pl, _ := ioutil.ReadAll(msg.Payload) | ||||||
| 		t.Errorf("incorrect data %v", data1) | 	expect := []byte{0x01, 0x83, 0x30, 0x30, 0x30} | ||||||
|  | 	if !bytes.Equal(pl, expect) { | ||||||
|  | 		t.Errorf("incorrect payload content, got %x, want %x", pl, expect) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestEncodeDecodeMsg(t *testing.T) { | func TestEncodeDecodeMsg(t *testing.T) { | ||||||
| 	msg, _ := NewMsg(3, 1, "000") | 	msg := NewMsg(3, 1, "000") | ||||||
| 	encoded := msg.Encode(3) | 	buf := new(bytes.Buffer) | ||||||
| 	msg, _ = NewMsgFromBytes(encoded) | 	if err := writeMsg(buf, msg); err != nil { | ||||||
| 	msg.Decode(3) | 		t.Fatalf("encodeMsg error: %v", err) | ||||||
| 	if msg.Code() != 3 { |  | ||||||
| 		t.Errorf("incorrect code %v", msg.Code()) |  | ||||||
| 	} | 	} | ||||||
| 	data0 := msg.Data().Get(0).Uint() | 	// t.Logf("encoded: %x", buf.Bytes()) | ||||||
| 	data1 := msg.Data().Get(1).Str() |  | ||||||
| 	if data0 != 1 { | 	decmsg, err := readMsg(buf) | ||||||
| 		t.Errorf("incorrect data %v", data0) | 	if err != nil { | ||||||
|  | 		t.Fatalf("readMsg error: %v", err) | ||||||
| 	} | 	} | ||||||
| 	if data1 != "000" { | 	if decmsg.Code != 3 { | ||||||
| 		t.Errorf("incorrect data %v", data1) | 		t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) | ||||||
|  | 	} | ||||||
|  | 	if decmsg.Size != 5 { | ||||||
|  | 		t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var data struct { | ||||||
|  | 		I int | ||||||
|  | 		S string | ||||||
|  | 	} | ||||||
|  | 	if err := decmsg.Decode(&data); err != nil { | ||||||
|  | 		t.Fatalf("Decode error: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if data.I != 1 { | ||||||
|  | 		t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) | ||||||
|  | 	} | ||||||
|  | 	if data.S != "000" { | ||||||
|  | 		t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecodeRealMsg(t *testing.T) { | ||||||
|  | 	data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") | ||||||
|  | 	msg, err := readMsg(bytes.NewReader(data)) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("unexpected error: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if msg.Code != 0 { | ||||||
|  | 		t.Errorf("incorrect code %d, want %d", msg.Code, 0) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										220
									
								
								p2p/messenger.go
									
									
									
									
									
								
							
							
						
						
									
										220
									
								
								p2p/messenger.go
									
									
									
									
									
								
							| @@ -1,220 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	handlerTimeout = 1000 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type Handlers map[string](func(p *Peer) Protocol) |  | ||||||
|  |  | ||||||
| type Messenger struct { |  | ||||||
| 	conn          *Connection |  | ||||||
| 	peer          *Peer |  | ||||||
| 	handlers      Handlers |  | ||||||
| 	protocolLock  sync.RWMutex |  | ||||||
| 	protocols     []Protocol |  | ||||||
| 	offsets       []MsgCode // offsets for adaptive message idss |  | ||||||
| 	protocolTable map[string]int |  | ||||||
| 	quit          chan chan bool |  | ||||||
| 	err           chan *PeerError |  | ||||||
| 	pulse         chan bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewMessenger(peer *Peer, conn *Connection, errchan chan *PeerError, handlers Handlers) *Messenger { |  | ||||||
| 	baseProtocol := NewBaseProtocol(peer) |  | ||||||
| 	return &Messenger{ |  | ||||||
| 		conn:          conn, |  | ||||||
| 		peer:          peer, |  | ||||||
| 		offsets:       []MsgCode{baseProtocol.Offset()}, |  | ||||||
| 		handlers:      handlers, |  | ||||||
| 		protocols:     []Protocol{baseProtocol}, |  | ||||||
| 		protocolTable: make(map[string]int), |  | ||||||
| 		err:           errchan, |  | ||||||
| 		pulse:         make(chan bool, 1), |  | ||||||
| 		quit:          make(chan chan bool, 1), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Messenger) Start() { |  | ||||||
| 	self.conn.Open() |  | ||||||
| 	go self.messenger() |  | ||||||
| 	self.protocolLock.RLock() |  | ||||||
| 	defer self.protocolLock.RUnlock() |  | ||||||
| 	self.protocols[0].Start() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Messenger) Stop() { |  | ||||||
| 	// close pulse to stop ping pong monitoring |  | ||||||
| 	close(self.pulse) |  | ||||||
| 	self.protocolLock.RLock() |  | ||||||
| 	defer self.protocolLock.RUnlock() |  | ||||||
| 	for _, protocol := range self.protocols { |  | ||||||
| 		protocol.Stop() // could be parallel |  | ||||||
| 	} |  | ||||||
| 	q := make(chan bool) |  | ||||||
| 	self.quit <- q |  | ||||||
| 	<-q |  | ||||||
| 	self.conn.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Messenger) messenger() { |  | ||||||
| 	in := self.conn.Read() |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case payload, ok := <-in: |  | ||||||
| 			//dispatches message to the protocol asynchronously |  | ||||||
| 			if ok { |  | ||||||
| 				go self.handle(payload) |  | ||||||
| 			} else { |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		case q := <-self.quit: |  | ||||||
| 			q <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // handles each message by dispatching to the appropriate protocol |  | ||||||
| // using adaptive message codes |  | ||||||
| // this function is started as a separate go routine for each message |  | ||||||
| // it waits for the protocol response |  | ||||||
| // then encodes and sends outgoing messages to the connection's write channel |  | ||||||
| func (self *Messenger) handle(payload []byte) { |  | ||||||
| 	// send ping to heartbeat channel signalling time of last message |  | ||||||
| 	// select { |  | ||||||
| 	// case self.pulse <- true: |  | ||||||
| 	// default: |  | ||||||
| 	// } |  | ||||||
| 	self.pulse <- true |  | ||||||
| 	// initialise message from payload |  | ||||||
| 	msg, err := NewMsgFromBytes(payload) |  | ||||||
| 	if err != nil { |  | ||||||
| 		self.err <- NewPeerError(MiscError, " %v", err) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	// retrieves protocol based on message Code |  | ||||||
| 	protocol, offset, peerErr := self.getProtocol(msg.Code()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		self.err <- peerErr |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	// reset message code based on adaptive offset |  | ||||||
| 	msg.Decode(offset) |  | ||||||
| 	// dispatches |  | ||||||
| 	response := make(chan *Msg) |  | ||||||
| 	go protocol.HandleIn(msg, response) |  | ||||||
| 	// protocol reponse timeout to prevent leaks |  | ||||||
| 	timer := time.After(handlerTimeout * time.Millisecond) |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case outgoing, ok := <-response: |  | ||||||
| 			// we check if response channel is not closed |  | ||||||
| 			if ok { |  | ||||||
| 				self.conn.Write() <- outgoing.Encode(offset) |  | ||||||
| 			} else { |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		case <-timer: |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // negotiated protocols |  | ||||||
| // stores offsets needed for adaptive message id scheme |  | ||||||
|  |  | ||||||
| // based on offsets set at handshake |  | ||||||
| // get the right protocol to handle the message |  | ||||||
| func (self *Messenger) getProtocol(code MsgCode) (Protocol, MsgCode, *PeerError) { |  | ||||||
| 	self.protocolLock.RLock() |  | ||||||
| 	defer self.protocolLock.RUnlock() |  | ||||||
| 	base := MsgCode(0) |  | ||||||
| 	for index, offset := range self.offsets { |  | ||||||
| 		if code < offset { |  | ||||||
| 			return self.protocols[index], base, nil |  | ||||||
| 		} |  | ||||||
| 		base = offset |  | ||||||
| 	} |  | ||||||
| 	return nil, MsgCode(0), NewPeerError(InvalidMsgCode, " %v", code) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Messenger) PingPong(timeout time.Duration, gracePeriod time.Duration, pingCallback func(), timeoutCallback func()) { |  | ||||||
| 	fmt.Printf("pingpong keepalive started at %v", time.Now()) |  | ||||||
|  |  | ||||||
| 	timer := time.After(timeout) |  | ||||||
| 	pinged := false |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case _, ok := <-self.pulse: |  | ||||||
| 			if ok { |  | ||||||
| 				pinged = false |  | ||||||
| 				timer = time.After(timeout) |  | ||||||
| 			} else { |  | ||||||
| 				// pulse is closed, stop monitoring |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		case <-timer: |  | ||||||
| 			if pinged { |  | ||||||
| 				fmt.Printf("timeout at %v", time.Now()) |  | ||||||
| 				timeoutCallback() |  | ||||||
| 				return |  | ||||||
| 			} else { |  | ||||||
| 				fmt.Printf("pinged at %v", time.Now()) |  | ||||||
| 				pingCallback() |  | ||||||
| 				timer = time.After(gracePeriod) |  | ||||||
| 				pinged = true |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Messenger) AddProtocols(protocols []string) { |  | ||||||
| 	self.protocolLock.Lock() |  | ||||||
| 	defer self.protocolLock.Unlock() |  | ||||||
| 	i := len(self.offsets) |  | ||||||
| 	offset := self.offsets[i-1] |  | ||||||
| 	for _, name := range protocols { |  | ||||||
| 		protocolFunc, ok := self.handlers[name] |  | ||||||
| 		if ok { |  | ||||||
| 			protocol := protocolFunc(self.peer) |  | ||||||
| 			self.protocolTable[name] = i |  | ||||||
| 			i++ |  | ||||||
| 			offset += protocol.Offset() |  | ||||||
| 			fmt.Println("offset ", name, offset) |  | ||||||
|  |  | ||||||
| 			self.offsets = append(self.offsets, offset) |  | ||||||
| 			self.protocols = append(self.protocols, protocol) |  | ||||||
| 			protocol.Start() |  | ||||||
| 		} else { |  | ||||||
| 			fmt.Println("no ", name) |  | ||||||
| 			// protocol not handled |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Messenger) Write(protocol string, msg *Msg) error { |  | ||||||
| 	self.protocolLock.RLock() |  | ||||||
| 	defer self.protocolLock.RUnlock() |  | ||||||
| 	i := 0 |  | ||||||
| 	offset := MsgCode(0) |  | ||||||
| 	if len(protocol) > 0 { |  | ||||||
| 		var ok bool |  | ||||||
| 		i, ok = self.protocolTable[protocol] |  | ||||||
| 		if !ok { |  | ||||||
| 			return fmt.Errorf("protocol %v not handled by peer", protocol) |  | ||||||
| 		} |  | ||||||
| 		offset = self.offsets[i-1] |  | ||||||
| 	} |  | ||||||
| 	handler := self.protocols[i] |  | ||||||
| 	// checking if protocol status/caps allows the message to be sent out |  | ||||||
| 	if handler.HandleOut(msg) { |  | ||||||
| 		self.conn.Write() <- msg.Encode(offset) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
| @@ -1,147 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	// "fmt" |  | ||||||
| 	"bytes" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/ethereum/go-ethereum/ethutil" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func setupMessenger(handlers Handlers) (*TestNetworkConnection, chan *PeerError, *Messenger) { |  | ||||||
| 	errchan := NewPeerErrorChannel() |  | ||||||
| 	addr := &TestAddr{"test:30303"} |  | ||||||
| 	net := NewTestNetworkConnection(addr) |  | ||||||
| 	conn := NewConnection(net, errchan) |  | ||||||
| 	mess := NewMessenger(nil, conn, errchan, handlers) |  | ||||||
| 	mess.Start() |  | ||||||
| 	return net, errchan, mess |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TestProtocol struct { |  | ||||||
| 	Msgs []*Msg |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestProtocol) Start() { |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestProtocol) Stop() { |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestProtocol) Offset() MsgCode { |  | ||||||
| 	return MsgCode(5) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestProtocol) HandleIn(msg *Msg, response chan *Msg) { |  | ||||||
| 	self.Msgs = append(self.Msgs, msg) |  | ||||||
| 	close(response) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestProtocol) HandleOut(msg *Msg) bool { |  | ||||||
| 	if msg.Code() > 3 { |  | ||||||
| 		return false |  | ||||||
| 	} else { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestProtocol) Name() string { |  | ||||||
| 	return "a" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func Packet(offset MsgCode, code MsgCode, params ...interface{}) []byte { |  | ||||||
| 	msg, _ := NewMsg(code, params...) |  | ||||||
| 	encoded := msg.Encode(offset) |  | ||||||
| 	packet := []byte{34, 64, 8, 145} |  | ||||||
| 	packet = append(packet, ethutil.NumberToBytes(uint32(len(encoded)), 32)...) |  | ||||||
| 	return append(packet, encoded...) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestRead(t *testing.T) { |  | ||||||
| 	handlers := make(Handlers) |  | ||||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} |  | ||||||
| 	handlers["a"] = func(p *Peer) Protocol { return testProtocol } |  | ||||||
| 	net, _, mess := setupMessenger(handlers) |  | ||||||
| 	mess.AddProtocols([]string{"a"}) |  | ||||||
| 	defer mess.Stop() |  | ||||||
| 	wait := 1 * time.Millisecond |  | ||||||
| 	packet := Packet(16, 1, uint32(1), "000") |  | ||||||
| 	go net.In(0, packet) |  | ||||||
| 	time.Sleep(wait) |  | ||||||
| 	if len(testProtocol.Msgs) != 1 { |  | ||||||
| 		t.Errorf("msg not relayed to correct protocol") |  | ||||||
| 	} else { |  | ||||||
| 		if testProtocol.Msgs[0].Code() != 1 { |  | ||||||
| 			t.Errorf("incorrect msg code relayed to protocol") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestWrite(t *testing.T) { |  | ||||||
| 	handlers := make(Handlers) |  | ||||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} |  | ||||||
| 	handlers["a"] = func(p *Peer) Protocol { return testProtocol } |  | ||||||
| 	net, _, mess := setupMessenger(handlers) |  | ||||||
| 	mess.AddProtocols([]string{"a"}) |  | ||||||
| 	defer mess.Stop() |  | ||||||
| 	wait := 1 * time.Millisecond |  | ||||||
| 	msg, _ := NewMsg(3, uint32(1), "000") |  | ||||||
| 	err := mess.Write("b", msg) |  | ||||||
| 	if err == nil { |  | ||||||
| 		t.Errorf("expect error for unknown protocol") |  | ||||||
| 	} |  | ||||||
| 	err = mess.Write("a", msg) |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Errorf("expect no error for known protocol: %v", err) |  | ||||||
| 	} else { |  | ||||||
| 		time.Sleep(wait) |  | ||||||
| 		if len(net.Out) != 1 { |  | ||||||
| 			t.Errorf("msg not written") |  | ||||||
| 		} else { |  | ||||||
| 			out := net.Out[0] |  | ||||||
| 			packet := Packet(16, 3, uint32(1), "000") |  | ||||||
| 			if bytes.Compare(out, packet) != 0 { |  | ||||||
| 				t.Errorf("incorrect packet %v", out) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestPulse(t *testing.T) { |  | ||||||
| 	net, _, mess := setupMessenger(make(Handlers)) |  | ||||||
| 	defer mess.Stop() |  | ||||||
| 	ping := false |  | ||||||
| 	timeout := false |  | ||||||
| 	pingTimeout := 10 * time.Millisecond |  | ||||||
| 	gracePeriod := 200 * time.Millisecond |  | ||||||
| 	go mess.PingPong(pingTimeout, gracePeriod, func() { ping = true }, func() { timeout = true }) |  | ||||||
| 	net.In(0, Packet(0, 1)) |  | ||||||
| 	if ping { |  | ||||||
| 		t.Errorf("ping sent too early") |  | ||||||
| 	} |  | ||||||
| 	time.Sleep(pingTimeout + 100*time.Millisecond) |  | ||||||
| 	if !ping { |  | ||||||
| 		t.Errorf("no ping sent after timeout") |  | ||||||
| 	} |  | ||||||
| 	if timeout { |  | ||||||
| 		t.Errorf("timeout too early") |  | ||||||
| 	} |  | ||||||
| 	ping = false |  | ||||||
| 	net.In(0, Packet(0, 1)) |  | ||||||
| 	time.Sleep(pingTimeout + 100*time.Millisecond) |  | ||||||
| 	if !ping { |  | ||||||
| 		t.Errorf("no ping sent after timeout") |  | ||||||
| 	} |  | ||||||
| 	if timeout { |  | ||||||
| 		t.Errorf("timeout too early") |  | ||||||
| 	} |  | ||||||
| 	ping = false |  | ||||||
| 	time.Sleep(gracePeriod) |  | ||||||
| 	if ping { |  | ||||||
| 		t.Errorf("ping called twice") |  | ||||||
| 	} |  | ||||||
| 	if !timeout { |  | ||||||
| 		t.Errorf("no timeout after grace period") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -3,6 +3,7 @@ package p2p | |||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	natpmp "github.com/jackpal/go-nat-pmp" | 	natpmp "github.com/jackpal/go-nat-pmp" | ||||||
| ) | ) | ||||||
| @@ -13,38 +14,37 @@ import ( | |||||||
| //  + Register for changes to the external address. | //  + Register for changes to the external address. | ||||||
| //  + Re-register port mapping when router reboots. | //  + Re-register port mapping when router reboots. | ||||||
| //  + A mechanism for keeping a port mapping registered. | //  + A mechanism for keeping a port mapping registered. | ||||||
|  | //  + Discover gateway address automatically. | ||||||
|  |  | ||||||
| type natPMPClient struct { | type natPMPClient struct { | ||||||
| 	client *natpmp.Client | 	client *natpmp.Client | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewNatPMP(gateway net.IP) (nat NAT) { | // PMP returns a NAT traverser that uses NAT-PMP. The provided gateway | ||||||
|  | // address should be the IP of your router. | ||||||
|  | func PMP(gateway net.IP) (nat NAT) { | ||||||
| 	return &natPMPClient{natpmp.NewClient(gateway)} | 	return &natPMPClient{natpmp.NewClient(gateway)} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *natPMPClient) GetExternalAddress() (addr net.IP, err error) { | func (*natPMPClient) String() string { | ||||||
| 	response, err := n.client.GetExternalAddress() | 	return "NAT-PMP" | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	ip := response.ExternalIPAddress |  | ||||||
| 	addr = net.IPv4(ip[0], ip[1], ip[2], ip[3]) |  | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *natPMPClient) AddPortMapping(protocol string, externalPort, internalPort int, | func (n *natPMPClient) GetExternalAddress() (net.IP, error) { | ||||||
| 	description string, timeout int) (mappedExternalPort int, err error) { | 	response, err := n.client.GetExternalAddress() | ||||||
| 	if timeout <= 0 { | 	if err != nil { | ||||||
| 		err = fmt.Errorf("timeout must not be <= 0") | 		return nil, err | ||||||
| 		return | 	} | ||||||
|  | 	return response.ExternalIPAddress[:], nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *natPMPClient) AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error { | ||||||
|  | 	if lifetime <= 0 { | ||||||
|  | 		return fmt.Errorf("lifetime must not be <= 0") | ||||||
| 	} | 	} | ||||||
| 	// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. | 	// Note order of port arguments is switched between our AddPortMapping and the client's AddPortMapping. | ||||||
| 	response, err := n.client.AddPortMapping(protocol, internalPort, externalPort, timeout) | 	_, err := n.client.AddPortMapping(protocol, intport, extport, int(lifetime/time.Second)) | ||||||
| 	if err != nil { | 	return err | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	mappedExternalPort = int(response.MappedExternalPort) |  | ||||||
| 	return |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { | func (n *natPMPClient) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { | ||||||
|   | |||||||
							
								
								
									
										200
									
								
								p2p/natupnp.go
									
									
									
									
									
								
							
							
						
						
									
										200
									
								
								p2p/natupnp.go
									
									
									
									
									
								
							| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/xml" | 	"encoding/xml" | ||||||
| 	"errors" | 	"errors" | ||||||
|  | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
| @@ -15,28 +16,46 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	upnpDiscoverAttempts = 3 | ||||||
|  | 	upnpDiscoverTimeout  = 5 * time.Second | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // UPNP returns a NAT port mapper that uses UPnP. It will attempt to | ||||||
|  | // discover the address of your router using UDP broadcasts. | ||||||
|  | func UPNP() NAT { | ||||||
|  | 	return &upnpNAT{} | ||||||
|  | } | ||||||
|  |  | ||||||
| type upnpNAT struct { | type upnpNAT struct { | ||||||
| 	serviceURL string | 	serviceURL string | ||||||
| 	ourIP      string | 	ourIP      string | ||||||
| } | } | ||||||
|  |  | ||||||
| func upnpDiscover(attempts int) (nat NAT, err error) { | func (n *upnpNAT) String() string { | ||||||
|  | 	return "UPNP" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *upnpNAT) discover() error { | ||||||
|  | 	if n.serviceURL != "" { | ||||||
|  | 		// already discovered | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") | 	ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} | 	} | ||||||
|  | 	// TODO: try on all network interfaces simultaneously. | ||||||
|  | 	// Broadcasting on 0.0.0.0 could select a random interface | ||||||
|  | 	// to send on (platform specific). | ||||||
| 	conn, err := net.ListenPacket("udp4", ":0") | 	conn, err := net.ListenPacket("udp4", ":0") | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return err | ||||||
| 	} |  | ||||||
| 	socket := conn.(*net.UDPConn) |  | ||||||
| 	defer socket.Close() |  | ||||||
|  |  | ||||||
| 	err = socket.SetDeadline(time.Now().Add(10 * time.Second)) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  |  | ||||||
|  | 	conn.SetDeadline(time.Now().Add(10 * time.Second)) | ||||||
| 	st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" | 	st := "ST: urn:schemas-upnp-org:device:InternetGatewayDevice:1\r\n" | ||||||
| 	buf := bytes.NewBufferString( | 	buf := bytes.NewBufferString( | ||||||
| 		"M-SEARCH * HTTP/1.1\r\n" + | 		"M-SEARCH * HTTP/1.1\r\n" + | ||||||
| @@ -46,19 +65,16 @@ func upnpDiscover(attempts int) (nat NAT, err error) { | |||||||
| 			"MX: 2\r\n\r\n") | 			"MX: 2\r\n\r\n") | ||||||
| 	message := buf.Bytes() | 	message := buf.Bytes() | ||||||
| 	answerBytes := make([]byte, 1024) | 	answerBytes := make([]byte, 1024) | ||||||
| 	for i := 0; i < attempts; i++ { | 	for i := 0; i < upnpDiscoverAttempts; i++ { | ||||||
| 		_, err = socket.WriteToUDP(message, ssdp) | 		_, err = conn.WriteTo(message, ssdp) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return | 			return err | ||||||
| 		} | 		} | ||||||
| 		var n int | 		nn, _, err := conn.ReadFrom(answerBytes) | ||||||
| 		n, _, err = socket.ReadFromUDP(answerBytes) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			continue | 			continue | ||||||
| 			// socket.Close() |  | ||||||
| 			// return |  | ||||||
| 		} | 		} | ||||||
| 		answer := string(answerBytes[0:n]) | 		answer := string(answerBytes[0:nn]) | ||||||
| 		if strings.Index(answer, "\r\n"+st) < 0 { | 		if strings.Index(answer, "\r\n"+st) < 0 { | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| @@ -79,17 +95,81 @@ func upnpDiscover(attempts int) (nat NAT, err error) { | |||||||
| 		var serviceURL string | 		var serviceURL string | ||||||
| 		serviceURL, err = getServiceURL(locURL) | 		serviceURL, err = getServiceURL(locURL) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return | 			return err | ||||||
| 		} | 		} | ||||||
| 		var ourIP string | 		var ourIP string | ||||||
| 		ourIP, err = getOurIP() | 		ourIP, err = getOurIP() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		n.serviceURL = serviceURL | ||||||
|  | 		n.ourIP = ourIP | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	return errors.New("UPnP port discovery failed.") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { | ||||||
|  | 	if err := n.discover(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	info, err := n.getStatusInfo() | ||||||
|  | 	return net.ParseIP(info.externalIpAddress), err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *upnpNAT) AddPortMapping(protocol string, extport, intport int, description string, lifetime time.Duration) error { | ||||||
|  | 	if err := n.discover(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// A single concatenation would break ARM compilation. | ||||||
|  | 	message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||||
|  | 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(extport) | ||||||
|  | 	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" | ||||||
|  | 	message += "<NewInternalPort>" + strconv.Itoa(extport) + "</NewInternalPort>" + | ||||||
|  | 		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + | ||||||
|  | 		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" | ||||||
|  | 	message += description + | ||||||
|  | 		"</NewPortMappingDescription><NewLeaseDuration>" + fmt.Sprint(lifetime/time.Second) + | ||||||
|  | 		"</NewLeaseDuration></u:AddPortMapping>" | ||||||
|  |  | ||||||
|  | 	// TODO: check response to see if the port was forwarded | ||||||
|  | 	_, err := soapRequest(n.serviceURL, "AddPortMapping", message) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) error { | ||||||
|  | 	if err := n.discover(); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||||
|  | 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + | ||||||
|  | 		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + | ||||||
|  | 		"</u:DeletePortMapping>" | ||||||
|  |  | ||||||
|  | 	// TODO: check response to see if the port was deleted | ||||||
|  | 	_, err := soapRequest(n.serviceURL, "DeletePortMapping", message) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type statusInfo struct { | ||||||
|  | 	externalIpAddress string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { | ||||||
|  | 	message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + | ||||||
|  | 		"</u:GetStatusInfo>" | ||||||
|  |  | ||||||
|  | 	var response *http.Response | ||||||
|  | 	response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 		nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP} |  | ||||||
| 		return | 	// TODO: Write a soap reply parser. It has to eat the Body and envelope tags... | ||||||
| 	} |  | ||||||
| 	err = errors.New("UPnP port discovery failed.") | 	response.Body.Close() | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -259,77 +339,3 @@ func soapRequest(url, function, message string) (r *http.Response, err error) { | |||||||
| 	} | 	} | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
|  |  | ||||||
| type statusInfo struct { |  | ||||||
| 	externalIpAddress string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (n *upnpNAT) getStatusInfo() (info statusInfo, err error) { |  | ||||||
|  |  | ||||||
| 	message := "<u:GetStatusInfo xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + |  | ||||||
| 		"</u:GetStatusInfo>" |  | ||||||
|  |  | ||||||
| 	var response *http.Response |  | ||||||
| 	response, err = soapRequest(n.serviceURL, "GetStatusInfo", message) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// TODO: Write a soap reply parser. It has to eat the Body and envelope tags... |  | ||||||
|  |  | ||||||
| 	response.Body.Close() |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { |  | ||||||
| 	info, err := n.getStatusInfo() |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	addr = net.ParseIP(info.externalIpAddress) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { |  | ||||||
| 	// A single concatenation would break ARM compilation. |  | ||||||
| 	message := "<u:AddPortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + |  | ||||||
| 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) |  | ||||||
| 	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" |  | ||||||
| 	message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + |  | ||||||
| 		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + |  | ||||||
| 		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" |  | ||||||
| 	message += description + |  | ||||||
| 		"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + |  | ||||||
| 		"</NewLeaseDuration></u:AddPortMapping>" |  | ||||||
|  |  | ||||||
| 	var response *http.Response |  | ||||||
| 	response, err = soapRequest(n.serviceURL, "AddPortMapping", message) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// TODO: check response to see if the port was forwarded |  | ||||||
| 	// log.Println(message, response) |  | ||||||
| 	mappedExternalPort = externalPort |  | ||||||
| 	_ = response |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { |  | ||||||
|  |  | ||||||
| 	message := "<u:DeletePortMapping xmlns:u=\"urn:schemas-upnp-org:service:WANIPConnection:1\">\r\n" + |  | ||||||
| 		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + |  | ||||||
| 		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + |  | ||||||
| 		"</u:DeletePortMapping>" |  | ||||||
|  |  | ||||||
| 	var response *http.Response |  | ||||||
| 	response, err = soapRequest(n.serviceURL, "DeletePortMapping", message) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// TODO: check response to see if the port was deleted |  | ||||||
| 	// log.Println(message, response) |  | ||||||
| 	_ = response |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|   | |||||||
							
								
								
									
										196
									
								
								p2p/network.go
									
									
									
									
									
								
							
							
						
						
									
										196
									
								
								p2p/network.go
									
									
									
									
									
								
							| @@ -1,196 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"math/rand" |  | ||||||
| 	"net" |  | ||||||
| 	"strconv" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	DialerTimeout             = 180 //seconds |  | ||||||
| 	KeepAlivePeriod           = 60  //minutes |  | ||||||
| 	portMappingUpdateInterval = 900 // seconds = 15 mins |  | ||||||
| 	upnpDiscoverAttempts      = 3 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // Dialer is not an interface in net, so we define one |  | ||||||
| // *net.Dialer conforms to this |  | ||||||
| type Dialer interface { |  | ||||||
| 	Dial(network, address string) (net.Conn, error) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type Network interface { |  | ||||||
| 	Start() error |  | ||||||
| 	Listener(net.Addr) (net.Listener, error) |  | ||||||
| 	Dialer(net.Addr) (Dialer, error) |  | ||||||
| 	NewAddr(string, int) (addr net.Addr, err error) |  | ||||||
| 	ParseAddr(string) (addr net.Addr, err error) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type NAT interface { |  | ||||||
| 	GetExternalAddress() (addr net.IP, err error) |  | ||||||
| 	AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) |  | ||||||
| 	DeletePortMapping(protocol string, externalPort, internalPort int) (err error) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TCPNetwork struct { |  | ||||||
| 	nat     NAT |  | ||||||
| 	natType NATType |  | ||||||
| 	quit    chan chan bool |  | ||||||
| 	ports   chan string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type NATType int |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	NONE = iota |  | ||||||
| 	UPNP |  | ||||||
| 	PMP |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	portMappingTimeout = 1200 // 20 mins |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func NewTCPNetwork(natType NATType) (net *TCPNetwork) { |  | ||||||
| 	return &TCPNetwork{ |  | ||||||
| 		natType: natType, |  | ||||||
| 		ports:   make(chan string), |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) Dialer(addr net.Addr) (Dialer, error) { |  | ||||||
| 	return &net.Dialer{ |  | ||||||
| 		Timeout: DialerTimeout * time.Second, |  | ||||||
| 		// KeepAlive: KeepAlivePeriod * time.Minute, |  | ||||||
| 		LocalAddr: addr, |  | ||||||
| 	}, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) Listener(addr net.Addr) (net.Listener, error) { |  | ||||||
| 	if self.natType == UPNP { |  | ||||||
| 		_, port, _ := net.SplitHostPort(addr.String()) |  | ||||||
| 		if self.quit == nil { |  | ||||||
| 			self.quit = make(chan chan bool) |  | ||||||
| 			go self.updatePortMappings() |  | ||||||
| 		} |  | ||||||
| 		self.ports <- port |  | ||||||
| 	} |  | ||||||
| 	return net.Listen(addr.Network(), addr.String()) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) Start() (err error) { |  | ||||||
| 	switch self.natType { |  | ||||||
| 	case NONE: |  | ||||||
| 	case UPNP: |  | ||||||
| 		nat, uerr := upnpDiscover(upnpDiscoverAttempts) |  | ||||||
| 		if uerr != nil { |  | ||||||
| 			err = fmt.Errorf("UPNP failed: ", uerr) |  | ||||||
| 		} else { |  | ||||||
| 			self.nat = nat |  | ||||||
| 		} |  | ||||||
| 	case PMP: |  | ||||||
| 		err = fmt.Errorf("PMP not implemented") |  | ||||||
| 	default: |  | ||||||
| 		err = fmt.Errorf("Invalid NAT type: %v", self.natType) |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) Stop() { |  | ||||||
| 	q := make(chan bool) |  | ||||||
| 	self.quit <- q |  | ||||||
| 	<-q |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) addPortMapping(lport int) (err error) { |  | ||||||
| 	_, err = self.nat.AddPortMapping("TCP", lport, lport, "p2p listen port", portMappingTimeout) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Errorf("unable to add port mapping on %v: %v", lport, err) |  | ||||||
| 	} else { |  | ||||||
| 		logger.Debugf("succesfully added port mapping on %v", lport) |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) updatePortMappings() { |  | ||||||
| 	timer := time.NewTimer(portMappingUpdateInterval * time.Second) |  | ||||||
| 	lports := []int{} |  | ||||||
| out: |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case port := <-self.ports: |  | ||||||
| 			int64lport, _ := strconv.ParseInt(port, 10, 16) |  | ||||||
| 			lport := int(int64lport) |  | ||||||
| 			if err := self.addPortMapping(lport); err != nil { |  | ||||||
| 				lports = append(lports, lport) |  | ||||||
| 			} |  | ||||||
| 		case <-timer.C: |  | ||||||
| 			for lport := range lports { |  | ||||||
| 				if err := self.addPortMapping(lport); err != nil { |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 		case errc := <-self.quit: |  | ||||||
| 			errc <- true |  | ||||||
| 			break out |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	timer.Stop() |  | ||||||
| 	for lport := range lports { |  | ||||||
| 		if err := self.nat.DeletePortMapping("TCP", lport, lport); err != nil { |  | ||||||
| 			logger.Debugf("unable to remove port mapping on %v: %v", lport, err) |  | ||||||
| 		} else { |  | ||||||
| 			logger.Debugf("succesfully removed port mapping on %v", lport) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) NewAddr(host string, port int) (net.Addr, error) { |  | ||||||
| 	ip, err := self.lookupIP(host) |  | ||||||
| 	if err == nil { |  | ||||||
| 		return &net.TCPAddr{ |  | ||||||
| 			IP:   ip, |  | ||||||
| 			Port: port, |  | ||||||
| 		}, nil |  | ||||||
| 	} |  | ||||||
| 	return nil, err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TCPNetwork) ParseAddr(address string) (net.Addr, error) { |  | ||||||
| 	host, port, err := net.SplitHostPort(address) |  | ||||||
| 	if err == nil { |  | ||||||
| 		iport, _ := strconv.Atoi(port) |  | ||||||
| 		addr, e := self.NewAddr(host, iport) |  | ||||||
| 		return addr, e |  | ||||||
| 	} |  | ||||||
| 	return nil, err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (*TCPNetwork) lookupIP(host string) (ip net.IP, err error) { |  | ||||||
| 	if ip = net.ParseIP(host); ip != nil { |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var ips []net.IP |  | ||||||
| 	ips, err = net.LookupIP(host) |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Warnln(err) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	if len(ips) == 0 { |  | ||||||
| 		err = fmt.Errorf("No IP addresses available for %v", host) |  | ||||||
| 		logger.Warnln(err) |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	if len(ips) > 1 { |  | ||||||
| 		// Pick a random IP address, simulating round-robin DNS. |  | ||||||
| 		rand.Seed(time.Now().UTC().UnixNano()) |  | ||||||
| 		ip = ips[rand.Intn(len(ips))] |  | ||||||
| 	} else { |  | ||||||
| 		ip = ips[0] |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
							
								
								
									
										490
									
								
								p2p/peer.go
									
									
									
									
									
								
							
							
						
						
									
										490
									
								
								p2p/peer.go
									
									
									
									
									
								
							| @@ -1,83 +1,455 @@ | |||||||
| package p2p | package p2p | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"bytes" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"io" | ||||||
|  | 	"io/ioutil" | ||||||
| 	"net" | 	"net" | ||||||
| 	"strconv" | 	"sort" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/ethereum/go-ethereum/event" | ||||||
|  | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Peer struct { | // peerAddr is the structure of a peer list element. | ||||||
| 	// quit      chan chan bool | // It is also a valid net.Addr. | ||||||
| 	Inbound          bool // inbound (via listener) or outbound (via dialout) | type peerAddr struct { | ||||||
| 	Address          net.Addr | 	IP     net.IP | ||||||
| 	Host             []byte | 	Port   uint64 | ||||||
| 	Port             uint16 | 	Pubkey []byte // optional | ||||||
| 	Pubkey           []byte |  | ||||||
| 	Id               string |  | ||||||
| 	Caps             []string |  | ||||||
| 	peerErrorChan    chan *PeerError |  | ||||||
| 	messenger        *Messenger |  | ||||||
| 	peerErrorHandler *PeerErrorHandler |  | ||||||
| 	server           *Server |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Peer) Messenger() *Messenger { | func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { | ||||||
| 	return self.messenger | 	n := addr.Network() | ||||||
| } | 	if n != "tcp" && n != "tcp4" && n != "tcp6" { | ||||||
|  | 		// for testing with non-TCP | ||||||
| func (self *Peer) PeerErrorChan() chan *PeerError { | 		return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey} | ||||||
| 	return self.peerErrorChan |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Peer) Server() *Server { |  | ||||||
| 	return self.server |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewPeer(conn net.Conn, address net.Addr, inbound bool, server *Server) *Peer { |  | ||||||
| 	peerErrorChan := NewPeerErrorChannel() |  | ||||||
| 	host, port, _ := net.SplitHostPort(address.String()) |  | ||||||
| 	intport, _ := strconv.Atoi(port) |  | ||||||
| 	peer := &Peer{ |  | ||||||
| 		Inbound:       inbound, |  | ||||||
| 		Address:       address, |  | ||||||
| 		Port:          uint16(intport), |  | ||||||
| 		Host:          net.ParseIP(host), |  | ||||||
| 		peerErrorChan: peerErrorChan, |  | ||||||
| 		server:        server, |  | ||||||
| 	} | 	} | ||||||
| 	connection := NewConnection(conn, peerErrorChan) | 	ta := addr.(*net.TCPAddr) | ||||||
| 	peer.messenger = NewMessenger(peer, connection, peerErrorChan, server.Handlers()) | 	return &peerAddr{ta.IP, uint64(ta.Port), pubkey} | ||||||
| 	peer.peerErrorHandler = NewPeerErrorHandler(address, server.PeerDisconnect(), peerErrorChan, server.Blacklist()) | } | ||||||
|  |  | ||||||
|  | func (d peerAddr) Network() string { | ||||||
|  | 	if d.IP.To4() != nil { | ||||||
|  | 		return "tcp4" | ||||||
|  | 	} else { | ||||||
|  | 		return "tcp6" | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d peerAddr) String() string { | ||||||
|  | 	return fmt.Sprintf("%v:%d", d.IP, d.Port) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d peerAddr) RlpData() interface{} { | ||||||
|  | 	return []interface{}{d.IP, d.Port, d.Pubkey} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Peer represents a remote peer. | ||||||
|  | type Peer struct { | ||||||
|  | 	// Peers have all the log methods. | ||||||
|  | 	// Use them to display messages related to the peer. | ||||||
|  | 	*logger.Logger | ||||||
|  |  | ||||||
|  | 	infolock   sync.Mutex | ||||||
|  | 	identity   ClientIdentity | ||||||
|  | 	caps       []Cap | ||||||
|  | 	listenAddr *peerAddr // what remote peer is listening on | ||||||
|  | 	dialAddr   *peerAddr // non-nil if dialing | ||||||
|  |  | ||||||
|  | 	// The mutex protects the connection | ||||||
|  | 	// so only one protocol can write at a time. | ||||||
|  | 	writeMu sync.Mutex | ||||||
|  | 	conn    net.Conn | ||||||
|  | 	bufconn *bufio.ReadWriter | ||||||
|  |  | ||||||
|  | 	// These fields maintain the running protocols. | ||||||
|  | 	protocols       []Protocol | ||||||
|  | 	runBaseProtocol bool // for testing | ||||||
|  |  | ||||||
|  | 	runlock sync.RWMutex // protects running | ||||||
|  | 	running map[string]*proto | ||||||
|  |  | ||||||
|  | 	protoWG  sync.WaitGroup | ||||||
|  | 	protoErr chan error | ||||||
|  | 	closed   chan struct{} | ||||||
|  | 	disc     chan DiscReason | ||||||
|  |  | ||||||
|  | 	activity event.TypeMux // for activity events | ||||||
|  |  | ||||||
|  | 	slot int // index into Server peer list | ||||||
|  |  | ||||||
|  | 	// These fields are kept so base protocol can access them. | ||||||
|  | 	// TODO: this should be one or more interfaces | ||||||
|  | 	ourID         ClientIdentity        // client id of the Server | ||||||
|  | 	ourListenAddr *peerAddr             // listen addr of Server, nil if not listening | ||||||
|  | 	newPeerAddr   chan<- *peerAddr      // tell server about received peers | ||||||
|  | 	otherPeers    func() []*Peer        // should return the list of all peers | ||||||
|  | 	pubkeyHook    func(*peerAddr) error // called at end of handshake to validate pubkey | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewPeer returns a peer for testing purposes. | ||||||
|  | func NewPeer(id ClientIdentity, caps []Cap) *Peer { | ||||||
|  | 	conn, _ := net.Pipe() | ||||||
|  | 	peer := newPeer(conn, nil, nil) | ||||||
|  | 	peer.setHandshakeInfo(id, nil, caps) | ||||||
|  | 	close(peer.closed) | ||||||
| 	return peer | 	return peer | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Peer) String() string { | func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { | ||||||
| 	var kind string | 	p := newPeer(conn, server.Protocols, dialAddr) | ||||||
| 	if self.Inbound { | 	p.ourID = server.Identity | ||||||
| 		kind = "inbound" | 	p.newPeerAddr = server.peerConnect | ||||||
| 	} else { | 	p.otherPeers = server.Peers | ||||||
|  | 	p.pubkeyHook = server.verifyPeer | ||||||
|  | 	p.runBaseProtocol = true | ||||||
|  |  | ||||||
|  | 	// laddr can be updated concurrently by NAT traversal. | ||||||
|  | 	// newServerPeer must be called with the server lock held. | ||||||
|  | 	if server.laddr != nil { | ||||||
|  | 		p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey()) | ||||||
|  | 	} | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { | ||||||
|  | 	p := &Peer{ | ||||||
|  | 		Logger:    logger.NewLogger("P2P " + conn.RemoteAddr().String()), | ||||||
|  | 		conn:      conn, | ||||||
|  | 		dialAddr:  dialAddr, | ||||||
|  | 		bufconn:   bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), | ||||||
|  | 		protocols: protocols, | ||||||
|  | 		running:   make(map[string]*proto), | ||||||
|  | 		disc:      make(chan DiscReason), | ||||||
|  | 		protoErr:  make(chan error), | ||||||
|  | 		closed:    make(chan struct{}), | ||||||
|  | 	} | ||||||
|  | 	return p | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Identity returns the client identity of the remote peer. The | ||||||
|  | // identity can be nil if the peer has not yet completed the | ||||||
|  | // handshake. | ||||||
|  | func (p *Peer) Identity() ClientIdentity { | ||||||
|  | 	p.infolock.Lock() | ||||||
|  | 	defer p.infolock.Unlock() | ||||||
|  | 	return p.identity | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Caps returns the capabilities (supported subprotocols) of the remote peer. | ||||||
|  | func (p *Peer) Caps() []Cap { | ||||||
|  | 	p.infolock.Lock() | ||||||
|  | 	defer p.infolock.Unlock() | ||||||
|  | 	return p.caps | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { | ||||||
|  | 	p.infolock.Lock() | ||||||
|  | 	p.identity = id | ||||||
|  | 	p.listenAddr = laddr | ||||||
|  | 	p.caps = caps | ||||||
|  | 	p.infolock.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RemoteAddr returns the remote address of the network connection. | ||||||
|  | func (p *Peer) RemoteAddr() net.Addr { | ||||||
|  | 	return p.conn.RemoteAddr() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // LocalAddr returns the local address of the network connection. | ||||||
|  | func (p *Peer) LocalAddr() net.Addr { | ||||||
|  | 	return p.conn.LocalAddr() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Disconnect terminates the peer connection with the given reason. | ||||||
|  | // It returns immediately and does not wait until the connection is closed. | ||||||
|  | func (p *Peer) Disconnect(reason DiscReason) { | ||||||
|  | 	select { | ||||||
|  | 	case p.disc <- reason: | ||||||
|  | 	case <-p.closed: | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // String implements fmt.Stringer. | ||||||
|  | func (p *Peer) String() string { | ||||||
|  | 	kind := "inbound" | ||||||
|  | 	p.infolock.Lock() | ||||||
|  | 	if p.dialAddr != nil { | ||||||
| 		kind = "outbound" | 		kind = "outbound" | ||||||
| 	} | 	} | ||||||
| 	return fmt.Sprintf("%v:%v (%s) v%v %v", self.Host, self.Port, kind, self.Id, self.Caps) | 	p.infolock.Unlock() | ||||||
|  | 	return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Peer) Write(protocol string, msg *Msg) error { | const ( | ||||||
| 	return self.messenger.Write(protocol, msg) | 	// maximum amount of time allowed for reading a message | ||||||
|  | 	msgReadTimeout = 5 * time.Second | ||||||
|  | 	// maximum amount of time allowed for writing a message | ||||||
|  | 	msgWriteTimeout = 5 * time.Second | ||||||
|  | 	// messages smaller than this many bytes will be read at | ||||||
|  | 	// once before passing them to a protocol. | ||||||
|  | 	wholePayloadSize = 64 * 1024 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	inactivityTimeout     = 2 * time.Second | ||||||
|  | 	disconnectGracePeriod = 2 * time.Second | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (p *Peer) loop() (reason DiscReason, err error) { | ||||||
|  | 	defer p.activity.Stop() | ||||||
|  | 	defer p.closeProtocols() | ||||||
|  | 	defer close(p.closed) | ||||||
|  | 	defer p.conn.Close() | ||||||
|  |  | ||||||
|  | 	// read loop | ||||||
|  | 	readMsg := make(chan Msg) | ||||||
|  | 	readErr := make(chan error) | ||||||
|  | 	readNext := make(chan bool, 1) | ||||||
|  | 	protoDone := make(chan struct{}, 1) | ||||||
|  | 	go p.readLoop(readMsg, readErr, readNext) | ||||||
|  | 	readNext <- true | ||||||
|  |  | ||||||
|  | 	if p.runBaseProtocol { | ||||||
|  | 		p.startBaseProtocol() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | loop: | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case msg := <-readMsg: | ||||||
|  | 			// a new message has arrived. | ||||||
|  | 			var wait bool | ||||||
|  | 			if wait, err = p.dispatch(msg, protoDone); err != nil { | ||||||
|  | 				p.Errorf("msg dispatch error: %v\n", err) | ||||||
|  | 				reason = discReasonForError(err) | ||||||
|  | 				break loop | ||||||
|  | 			} | ||||||
|  | 			if !wait { | ||||||
|  | 				// Msg has already been read completely, continue with next message. | ||||||
|  | 				readNext <- true | ||||||
|  | 			} | ||||||
|  | 			p.activity.Post(time.Now()) | ||||||
|  | 		case <-protoDone: | ||||||
|  | 			// protocol has consumed the message payload, | ||||||
|  | 			// we can continue reading from the socket. | ||||||
|  | 			readNext <- true | ||||||
|  |  | ||||||
|  | 		case err := <-readErr: | ||||||
|  | 			// read failed. there is no need to run the | ||||||
|  | 			// polite disconnect sequence because the connection | ||||||
|  | 			// is probably dead anyway. | ||||||
|  | 			// TODO: handle write errors as well | ||||||
|  | 			return DiscNetworkError, err | ||||||
|  | 		case err = <-p.protoErr: | ||||||
|  | 			reason = discReasonForError(err) | ||||||
|  | 			break loop | ||||||
|  | 		case reason = <-p.disc: | ||||||
|  | 			break loop | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// wait for read loop to return. | ||||||
|  | 	close(readNext) | ||||||
|  | 	<-readErr | ||||||
|  | 	// tell the remote end to disconnect | ||||||
|  | 	done := make(chan struct{}) | ||||||
|  | 	go func() { | ||||||
|  | 		p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) | ||||||
|  | 		p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) | ||||||
|  | 		io.Copy(ioutil.Discard, p.conn) | ||||||
|  | 		close(done) | ||||||
|  | 	}() | ||||||
|  | 	select { | ||||||
|  | 	case <-done: | ||||||
|  | 	case <-time.After(disconnectGracePeriod): | ||||||
|  | 	} | ||||||
|  | 	return reason, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Peer) Start() { | func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { | ||||||
| 	self.peerErrorHandler.Start() | 	for _ = range unblock { | ||||||
| 	self.messenger.Start() | 		p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) | ||||||
|  | 		if msg, err := readMsg(p.bufconn); err != nil { | ||||||
|  | 			errc <- err | ||||||
|  | 		} else { | ||||||
|  | 			msgc <- msg | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	close(errc) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *Peer) Stop() { | func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { | ||||||
| 	self.peerErrorHandler.Stop() | 	proto, err := p.getProto(msg.Code) | ||||||
| 	self.messenger.Stop() | 	if err != nil { | ||||||
| 	// q := make(chan bool) | 		return false, err | ||||||
| 	// self.quit <- q | 	} | ||||||
| 	// <-q | 	if msg.Size <= wholePayloadSize { | ||||||
|  | 		// optimization: msg is small enough, read all | ||||||
|  | 		// of it and move on to the next message | ||||||
|  | 		buf, err := ioutil.ReadAll(msg.Payload) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return false, err | ||||||
|  | 		} | ||||||
|  | 		msg.Payload = bytes.NewReader(buf) | ||||||
|  | 		proto.in <- msg | ||||||
|  | 	} else { | ||||||
|  | 		wait = true | ||||||
|  | 		pr := &eofSignal{msg.Payload, protoDone} | ||||||
|  | 		msg.Payload = pr | ||||||
|  | 		proto.in <- msg | ||||||
|  | 	} | ||||||
|  | 	return wait, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Peer) Encode() []interface{} { | func (p *Peer) startBaseProtocol() { | ||||||
| 	return []interface{}{p.Host, p.Port, p.Pubkey} | 	p.runlock.Lock() | ||||||
|  | 	defer p.runlock.Unlock() | ||||||
|  | 	p.running[""] = p.startProto(0, Protocol{ | ||||||
|  | 		Length: baseProtocolLength, | ||||||
|  | 		Run:    runBaseProtocol, | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // startProtocols starts matching named subprotocols. | ||||||
|  | func (p *Peer) startSubprotocols(caps []Cap) { | ||||||
|  | 	sort.Sort(capsByName(caps)) | ||||||
|  |  | ||||||
|  | 	p.runlock.Lock() | ||||||
|  | 	defer p.runlock.Unlock() | ||||||
|  | 	offset := baseProtocolLength | ||||||
|  | outer: | ||||||
|  | 	for _, cap := range caps { | ||||||
|  | 		for _, proto := range p.protocols { | ||||||
|  | 			if proto.Name == cap.Name && | ||||||
|  | 				proto.Version == cap.Version && | ||||||
|  | 				p.running[cap.Name] == nil { | ||||||
|  | 				p.running[cap.Name] = p.startProto(offset, proto) | ||||||
|  | 				offset += proto.Length | ||||||
|  | 				continue outer | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *Peer) startProto(offset uint64, impl Protocol) *proto { | ||||||
|  | 	rw := &proto{ | ||||||
|  | 		in:      make(chan Msg), | ||||||
|  | 		offset:  offset, | ||||||
|  | 		maxcode: impl.Length, | ||||||
|  | 		peer:    p, | ||||||
|  | 	} | ||||||
|  | 	p.protoWG.Add(1) | ||||||
|  | 	go func() { | ||||||
|  | 		err := impl.Run(p, rw) | ||||||
|  | 		if err == nil { | ||||||
|  | 			p.Infof("protocol %q returned", impl.Name) | ||||||
|  | 			err = newPeerError(errMisc, "protocol returned") | ||||||
|  | 		} else { | ||||||
|  | 			p.Errorf("protocol %q error: %v\n", impl.Name, err) | ||||||
|  | 		} | ||||||
|  | 		select { | ||||||
|  | 		case p.protoErr <- err: | ||||||
|  | 		case <-p.closed: | ||||||
|  | 		} | ||||||
|  | 		p.protoWG.Done() | ||||||
|  | 	}() | ||||||
|  | 	return rw | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // getProto finds the protocol responsible for handling | ||||||
|  | // the given message code. | ||||||
|  | func (p *Peer) getProto(code uint64) (*proto, error) { | ||||||
|  | 	p.runlock.RLock() | ||||||
|  | 	defer p.runlock.RUnlock() | ||||||
|  | 	for _, proto := range p.running { | ||||||
|  | 		if code >= proto.offset && code < proto.offset+proto.maxcode { | ||||||
|  | 			return proto, nil | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil, newPeerError(errInvalidMsgCode, "%d", code) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *Peer) closeProtocols() { | ||||||
|  | 	p.runlock.RLock() | ||||||
|  | 	for _, p := range p.running { | ||||||
|  | 		close(p.in) | ||||||
|  | 	} | ||||||
|  | 	p.runlock.RUnlock() | ||||||
|  | 	p.protoWG.Wait() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // writeProtoMsg sends the given message on behalf of the given named protocol. | ||||||
|  | func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { | ||||||
|  | 	p.runlock.RLock() | ||||||
|  | 	proto, ok := p.running[protoName] | ||||||
|  | 	p.runlock.RUnlock() | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf("protocol %s not handled by peer", protoName) | ||||||
|  | 	} | ||||||
|  | 	if msg.Code >= proto.maxcode { | ||||||
|  | 		return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) | ||||||
|  | 	} | ||||||
|  | 	msg.Code += proto.offset | ||||||
|  | 	return p.writeMsg(msg, msgWriteTimeout) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // writeMsg writes a message to the connection. | ||||||
|  | func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error { | ||||||
|  | 	p.writeMu.Lock() | ||||||
|  | 	defer p.writeMu.Unlock() | ||||||
|  | 	p.conn.SetWriteDeadline(time.Now().Add(timeout)) | ||||||
|  | 	if err := writeMsg(p.bufconn, msg); err != nil { | ||||||
|  | 		return newPeerError(errWrite, "%v", err) | ||||||
|  | 	} | ||||||
|  | 	return p.bufconn.Flush() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type proto struct { | ||||||
|  | 	name            string | ||||||
|  | 	in              chan Msg | ||||||
|  | 	maxcode, offset uint64 | ||||||
|  | 	peer            *Peer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rw *proto) WriteMsg(msg Msg) error { | ||||||
|  | 	if msg.Code >= rw.maxcode { | ||||||
|  | 		return newPeerError(errInvalidMsgCode, "not handled") | ||||||
|  | 	} | ||||||
|  | 	msg.Code += rw.offset | ||||||
|  | 	return rw.peer.writeMsg(msg, msgWriteTimeout) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error { | ||||||
|  | 	return rw.WriteMsg(NewMsg(code, data)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rw *proto) ReadMsg() (Msg, error) { | ||||||
|  | 	msg, ok := <-rw.in | ||||||
|  | 	if !ok { | ||||||
|  | 		return msg, io.EOF | ||||||
|  | 	} | ||||||
|  | 	msg.Code -= rw.offset | ||||||
|  | 	return msg, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // eofSignal wraps a reader with eof signaling. | ||||||
|  | // the eof channel is closed when the wrapped reader | ||||||
|  | // reaches EOF. | ||||||
|  | type eofSignal struct { | ||||||
|  | 	wrapped io.Reader | ||||||
|  | 	eof     chan<- struct{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r *eofSignal) Read(buf []byte) (int, error) { | ||||||
|  | 	n, err := r.wrapped.Read(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		r.eof <- struct{}{} // tell Peer that msg has been consumed | ||||||
|  | 	} | ||||||
|  | 	return n, err | ||||||
| } | } | ||||||
|   | |||||||
| @@ -4,73 +4,121 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ErrorCode int |  | ||||||
|  |  | ||||||
| const errorChanCapacity = 10 |  | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	PacketTooShort = iota | 	errMagicTokenMismatch = iota | ||||||
| 	PayloadTooShort | 	errRead | ||||||
| 	MagicTokenMismatch | 	errWrite | ||||||
| 	EmptyPayload | 	errMisc | ||||||
| 	ReadError | 	errInvalidMsgCode | ||||||
| 	WriteError | 	errInvalidMsg | ||||||
| 	MiscError | 	errP2PVersionMismatch | ||||||
| 	InvalidMsgCode | 	errPubkeyMissing | ||||||
| 	InvalidMsg | 	errPubkeyInvalid | ||||||
| 	P2PVersionMismatch | 	errPubkeyForbidden | ||||||
| 	PubkeyMissing | 	errProtocolBreach | ||||||
| 	PubkeyInvalid | 	errPingTimeout | ||||||
| 	PubkeyForbidden | 	errInvalidNetworkId | ||||||
| 	ProtocolBreach | 	errInvalidProtocolVersion | ||||||
| 	PortMismatch |  | ||||||
| 	PingTimeout |  | ||||||
| 	InvalidGenesis |  | ||||||
| 	InvalidNetworkId |  | ||||||
| 	InvalidProtocolVersion |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var errorToString = map[ErrorCode]string{ | var errorToString = map[int]string{ | ||||||
| 	PacketTooShort:         "Packet too short", | 	errMagicTokenMismatch:     "Magic token mismatch", | ||||||
| 	PayloadTooShort:        "Payload too short", | 	errRead:                   "Read error", | ||||||
| 	MagicTokenMismatch:     "Magic token mismatch", | 	errWrite:                  "Write error", | ||||||
| 	EmptyPayload:           "Empty payload", | 	errMisc:                   "Misc error", | ||||||
| 	ReadError:              "Read error", | 	errInvalidMsgCode:         "Invalid message code", | ||||||
| 	WriteError:             "Write error", | 	errInvalidMsg:             "Invalid message", | ||||||
| 	MiscError:              "Misc error", | 	errP2PVersionMismatch:     "P2P Version Mismatch", | ||||||
| 	InvalidMsgCode:         "Invalid message code", | 	errPubkeyMissing:          "Public key missing", | ||||||
| 	InvalidMsg:             "Invalid message", | 	errPubkeyInvalid:          "Public key invalid", | ||||||
| 	P2PVersionMismatch:     "P2P Version Mismatch", | 	errPubkeyForbidden:        "Public key forbidden", | ||||||
| 	PubkeyMissing:          "Public key missing", | 	errProtocolBreach:         "Protocol Breach", | ||||||
| 	PubkeyInvalid:          "Public key invalid", | 	errPingTimeout:            "Ping timeout", | ||||||
| 	PubkeyForbidden:        "Public key forbidden", | 	errInvalidNetworkId:       "Invalid network id", | ||||||
| 	ProtocolBreach:         "Protocol Breach", | 	errInvalidProtocolVersion: "Invalid protocol version", | ||||||
| 	PortMismatch:           "Port mismatch", |  | ||||||
| 	PingTimeout:            "Ping timeout", |  | ||||||
| 	InvalidGenesis:         "Invalid genesis block", |  | ||||||
| 	InvalidNetworkId:       "Invalid network id", |  | ||||||
| 	InvalidProtocolVersion: "Invalid protocol version", |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type PeerError struct { | type peerError struct { | ||||||
| 	Code    ErrorCode | 	Code    int | ||||||
| 	message string | 	message string | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewPeerError(code ErrorCode, format string, v ...interface{}) *PeerError { | func newPeerError(code int, format string, v ...interface{}) *peerError { | ||||||
| 	desc, ok := errorToString[code] | 	desc, ok := errorToString[code] | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		panic("invalid error code") | 		panic("invalid error code") | ||||||
| 	} | 	} | ||||||
| 	format = desc + ": " + format | 	err := &peerError{code, desc} | ||||||
| 	message := fmt.Sprintf(format, v...) | 	if format != "" { | ||||||
| 	return &PeerError{code, message} | 		err.message += ": " + fmt.Sprintf(format, v...) | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *PeerError) Error() string { | func (self *peerError) Error() string { | ||||||
| 	return self.message | 	return self.message | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewPeerErrorChannel() chan *PeerError { | type DiscReason byte | ||||||
| 	return make(chan *PeerError, errorChanCapacity) |  | ||||||
|  | const ( | ||||||
|  | 	DiscRequested           DiscReason = 0x00 | ||||||
|  | 	DiscNetworkError                   = 0x01 | ||||||
|  | 	DiscProtocolError                  = 0x02 | ||||||
|  | 	DiscUselessPeer                    = 0x03 | ||||||
|  | 	DiscTooManyPeers                   = 0x04 | ||||||
|  | 	DiscAlreadyConnected               = 0x05 | ||||||
|  | 	DiscIncompatibleVersion            = 0x06 | ||||||
|  | 	DiscInvalidIdentity                = 0x07 | ||||||
|  | 	DiscQuitting                       = 0x08 | ||||||
|  | 	DiscUnexpectedIdentity             = 0x09 | ||||||
|  | 	DiscSelf                           = 0x0a | ||||||
|  | 	DiscReadTimeout                    = 0x0b | ||||||
|  | 	DiscSubprotocolError               = 0x10 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var discReasonToString = [DiscSubprotocolError + 1]string{ | ||||||
|  | 	DiscRequested:           "Disconnect requested", | ||||||
|  | 	DiscNetworkError:        "Network error", | ||||||
|  | 	DiscProtocolError:       "Breach of protocol", | ||||||
|  | 	DiscUselessPeer:         "Useless peer", | ||||||
|  | 	DiscTooManyPeers:        "Too many peers", | ||||||
|  | 	DiscAlreadyConnected:    "Already connected", | ||||||
|  | 	DiscIncompatibleVersion: "Incompatible P2P protocol version", | ||||||
|  | 	DiscInvalidIdentity:     "Invalid node identity", | ||||||
|  | 	DiscQuitting:            "Client quitting", | ||||||
|  | 	DiscUnexpectedIdentity:  "Unexpected identity", | ||||||
|  | 	DiscSelf:                "Connected to self", | ||||||
|  | 	DiscReadTimeout:         "Read timeout", | ||||||
|  | 	DiscSubprotocolError:    "Subprotocol error", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d DiscReason) String() string { | ||||||
|  | 	if len(discReasonToString) < int(d) { | ||||||
|  | 		return fmt.Sprintf("Unknown Reason(%d)", d) | ||||||
|  | 	} | ||||||
|  | 	return discReasonToString[d] | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func discReasonForError(err error) DiscReason { | ||||||
|  | 	peerError, ok := err.(*peerError) | ||||||
|  | 	if !ok { | ||||||
|  | 		return DiscSubprotocolError | ||||||
|  | 	} | ||||||
|  | 	switch peerError.Code { | ||||||
|  | 	case errP2PVersionMismatch: | ||||||
|  | 		return DiscIncompatibleVersion | ||||||
|  | 	case errPubkeyMissing, errPubkeyInvalid: | ||||||
|  | 		return DiscInvalidIdentity | ||||||
|  | 	case errPubkeyForbidden: | ||||||
|  | 		return DiscUselessPeer | ||||||
|  | 	case errInvalidMsgCode, errMagicTokenMismatch, errProtocolBreach: | ||||||
|  | 		return DiscProtocolError | ||||||
|  | 	case errPingTimeout: | ||||||
|  | 		return DiscReadTimeout | ||||||
|  | 	case errRead, errWrite, errMisc: | ||||||
|  | 		return DiscNetworkError | ||||||
|  | 	default: | ||||||
|  | 		return DiscSubprotocolError | ||||||
|  | 	} | ||||||
| } | } | ||||||
|   | |||||||
| @@ -1,101 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"net" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	severityThreshold = 10 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type DisconnectRequest struct { |  | ||||||
| 	addr   net.Addr |  | ||||||
| 	reason DiscReason |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type PeerErrorHandler struct { |  | ||||||
| 	quit           chan chan bool |  | ||||||
| 	address        net.Addr |  | ||||||
| 	peerDisconnect chan DisconnectRequest |  | ||||||
| 	severity       int |  | ||||||
| 	peerErrorChan  chan *PeerError |  | ||||||
| 	blacklist      Blacklist |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewPeerErrorHandler(address net.Addr, peerDisconnect chan DisconnectRequest, peerErrorChan chan *PeerError, blacklist Blacklist) *PeerErrorHandler { |  | ||||||
| 	return &PeerErrorHandler{ |  | ||||||
| 		quit:           make(chan chan bool), |  | ||||||
| 		address:        address, |  | ||||||
| 		peerDisconnect: peerDisconnect, |  | ||||||
| 		peerErrorChan:  peerErrorChan, |  | ||||||
| 		blacklist:      blacklist, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *PeerErrorHandler) Start() { |  | ||||||
| 	go self.listen() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *PeerErrorHandler) Stop() { |  | ||||||
| 	q := make(chan bool) |  | ||||||
| 	self.quit <- q |  | ||||||
| 	<-q |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *PeerErrorHandler) listen() { |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case peerError, ok := <-self.peerErrorChan: |  | ||||||
| 			if ok { |  | ||||||
| 				logger.Debugf("error %v\n", peerError) |  | ||||||
| 				go self.handle(peerError) |  | ||||||
| 			} else { |  | ||||||
| 				return |  | ||||||
| 			} |  | ||||||
| 		case q := <-self.quit: |  | ||||||
| 			q <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *PeerErrorHandler) handle(peerError *PeerError) { |  | ||||||
| 	reason := DiscReason(' ') |  | ||||||
| 	switch peerError.Code { |  | ||||||
| 	case P2PVersionMismatch: |  | ||||||
| 		reason = DiscIncompatibleVersion |  | ||||||
| 	case PubkeyMissing, PubkeyInvalid: |  | ||||||
| 		reason = DiscInvalidIdentity |  | ||||||
| 	case PubkeyForbidden: |  | ||||||
| 		reason = DiscUselessPeer |  | ||||||
| 	case InvalidMsgCode, PacketTooShort, PayloadTooShort, MagicTokenMismatch, EmptyPayload, ProtocolBreach: |  | ||||||
| 		reason = DiscProtocolError |  | ||||||
| 	case PingTimeout: |  | ||||||
| 		reason = DiscReadTimeout |  | ||||||
| 	case WriteError, MiscError: |  | ||||||
| 		reason = DiscNetworkError |  | ||||||
| 	case InvalidGenesis, InvalidNetworkId, InvalidProtocolVersion: |  | ||||||
| 		reason = DiscSubprotocolError |  | ||||||
| 	default: |  | ||||||
| 		self.severity += self.getSeverity(peerError) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if self.severity >= severityThreshold { |  | ||||||
| 		reason = DiscSubprotocolError |  | ||||||
| 	} |  | ||||||
| 	if reason != DiscReason(' ') { |  | ||||||
| 		self.peerDisconnect <- DisconnectRequest{ |  | ||||||
| 			addr:   self.address, |  | ||||||
| 			reason: reason, |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *PeerErrorHandler) getSeverity(peerError *PeerError) int { |  | ||||||
| 	switch peerError.Code { |  | ||||||
| 	case ReadError: |  | ||||||
| 		return 4 //tolerate 3 :) |  | ||||||
| 	default: |  | ||||||
| 		return 1 |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,34 +0,0 @@ | |||||||
| package p2p |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	// "fmt" |  | ||||||
| 	"net" |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestPeerErrorHandler(t *testing.T) { |  | ||||||
| 	address := &net.TCPAddr{IP: net.IP([]byte{1, 2, 3, 4}), Port: 30303} |  | ||||||
| 	peerDisconnect := make(chan DisconnectRequest) |  | ||||||
| 	peerErrorChan := NewPeerErrorChannel() |  | ||||||
| 	peh := NewPeerErrorHandler(address, peerDisconnect, peerErrorChan, NewBlacklist()) |  | ||||||
| 	peh.Start() |  | ||||||
| 	defer peh.Stop() |  | ||||||
| 	for i := 0; i < 11; i++ { |  | ||||||
| 		select { |  | ||||||
| 		case <-peerDisconnect: |  | ||||||
| 			t.Errorf("expected no disconnect request") |  | ||||||
| 		default: |  | ||||||
| 		} |  | ||||||
| 		peerErrorChan <- NewPeerError(MiscError, "") |  | ||||||
| 	} |  | ||||||
| 	time.Sleep(1 * time.Millisecond) |  | ||||||
| 	select { |  | ||||||
| 	case request := <-peerDisconnect: |  | ||||||
| 		if request.addr.String() != address.String() { |  | ||||||
| 			t.Errorf("incorrect address %v != %v", request.addr, address) |  | ||||||
| 		} |  | ||||||
| 	default: |  | ||||||
| 		t.Errorf("expected disconnect request") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
							
								
								
									
										297
									
								
								p2p/peer_test.go
									
									
									
									
									
								
							
							
						
						
									
										297
									
								
								p2p/peer_test.go
									
									
									
									
									
								
							| @@ -1,96 +1,239 @@ | |||||||
| package p2p | package p2p | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bufio" | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"fmt" | 	"encoding/hex" | ||||||
| 	// "net" | 	"io/ioutil" | ||||||
|  | 	"net" | ||||||
|  | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestPeer(t *testing.T) { | var discard = Protocol{ | ||||||
| 	handlers := make(Handlers) | 	Name:   "discard", | ||||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} | 	Length: 1, | ||||||
| 	handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } | 	Run: func(p *Peer, rw MsgReadWriter) error { | ||||||
| 	handlers["ccc"] = func(p *Peer) Protocol { return testProtocol } | 		for { | ||||||
| 	addr := &TestAddr{"test:30"} | 			msg, err := rw.ReadMsg() | ||||||
| 	conn := NewTestNetworkConnection(addr) |  | ||||||
| 	_, server := SetupTestServer(handlers) |  | ||||||
| 	server.Handshake() |  | ||||||
| 	peer := NewPeer(conn, addr, true, server) |  | ||||||
| 	// peer.Messenger().AddProtocols([]string{"aaa", "ccc"}) |  | ||||||
| 	peer.Start() |  | ||||||
| 	defer peer.Stop() |  | ||||||
| 	time.Sleep(2 * time.Millisecond) |  | ||||||
| 	if len(conn.Out) != 1 { |  | ||||||
| 		t.Errorf("handshake not sent") |  | ||||||
| 	} else { |  | ||||||
| 		out := conn.Out[0] |  | ||||||
| 		packet := Packet(0, HandshakeMsg, P2PVersion, []byte(peer.server.identity.String()), []interface{}{peer.server.protocols}, peer.server.port, peer.server.identity.Pubkey()[1:]) |  | ||||||
| 		if bytes.Compare(out, packet) != 0 { |  | ||||||
| 			t.Errorf("incorrect handshake packet %v != %v", out, packet) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	packet := Packet(0, HandshakeMsg, P2PVersion, []byte("peer"), []interface{}{"bbb", "aaa", "ccc"}, 30, []byte("0000000000000000000000000000000000000000000000000000000000000000")) |  | ||||||
| 	conn.In(0, packet) |  | ||||||
| 	time.Sleep(10 * time.Millisecond) |  | ||||||
|  |  | ||||||
| 	pro, _ := peer.Messenger().protocols[0].(*BaseProtocol) |  | ||||||
| 	if pro.state != handshakeReceived { |  | ||||||
| 		t.Errorf("handshake not received") |  | ||||||
| 	} |  | ||||||
| 	if peer.Port != 30 { |  | ||||||
| 		t.Errorf("port incorrectly set") |  | ||||||
| 	} |  | ||||||
| 	if peer.Id != "peer" { |  | ||||||
| 		t.Errorf("id incorrectly set") |  | ||||||
| 	} |  | ||||||
| 	if string(peer.Pubkey) != "0000000000000000000000000000000000000000000000000000000000000000" { |  | ||||||
| 		t.Errorf("pubkey incorrectly set") |  | ||||||
| 	} |  | ||||||
| 	fmt.Println(peer.Caps) |  | ||||||
| 	if len(peer.Caps) != 3 || peer.Caps[0] != "aaa" || peer.Caps[1] != "bbb" || peer.Caps[2] != "ccc" { |  | ||||||
| 		t.Errorf("protocols incorrectly set") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	msg, _ := NewMsg(3) |  | ||||||
| 	err := peer.Write("aaa", msg) |  | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 		t.Errorf("expect no error for known protocol: %v", err) | 				return err | ||||||
| 	} else { | 			} | ||||||
| 		time.Sleep(1 * time.Millisecond) | 			if err = msg.Discard(); err != nil { | ||||||
| 		if len(conn.Out) != 2 { | 				return err | ||||||
| 			t.Errorf("msg not written") |  | ||||||
| 		} else { |  | ||||||
| 			out := conn.Out[1] |  | ||||||
| 			packet := Packet(16, 3) |  | ||||||
| 			if bytes.Compare(out, packet) != 0 { |  | ||||||
| 				t.Errorf("incorrect packet %v != %v", out, packet) |  | ||||||
| 			} |  | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | 	}, | ||||||
|  | } | ||||||
|  |  | ||||||
| 	msg, _ = NewMsg(2) | func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { | ||||||
| 	err = peer.Write("ccc", msg) | 	conn1, conn2 := net.Pipe() | ||||||
|  | 	id := NewSimpleClientIdentity("test", "0", "0", "public key") | ||||||
|  | 	peer := newPeer(conn1, protos, nil) | ||||||
|  | 	peer.ourID = id | ||||||
|  | 	peer.pubkeyHook = func(*peerAddr) error { return nil } | ||||||
|  | 	errc := make(chan error, 1) | ||||||
|  | 	go func() { | ||||||
|  | 		_, err := peer.loop() | ||||||
|  | 		errc <- err | ||||||
|  | 	}() | ||||||
|  | 	return conn2, peer, errc | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerProtoReadMsg(t *testing.T) { | ||||||
|  | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
|  | 	done := make(chan struct{}) | ||||||
|  | 	proto := Protocol{ | ||||||
|  | 		Name:   "a", | ||||||
|  | 		Length: 5, | ||||||
|  | 		Run: func(peer *Peer, rw MsgReadWriter) error { | ||||||
|  | 			msg, err := rw.ReadMsg() | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 		t.Errorf("expect no error for known protocol: %v", err) | 				t.Errorf("read error: %v", err) | ||||||
| 	} else { |  | ||||||
| 		time.Sleep(1 * time.Millisecond) |  | ||||||
| 		if len(conn.Out) != 3 { |  | ||||||
| 			t.Errorf("msg not written") |  | ||||||
| 		} else { |  | ||||||
| 			out := conn.Out[2] |  | ||||||
| 			packet := Packet(21, 2) |  | ||||||
| 			if bytes.Compare(out, packet) != 0 { |  | ||||||
| 				t.Errorf("incorrect packet %v != %v", out, packet) |  | ||||||
| 			} | 			} | ||||||
|  | 			if msg.Code != 2 { | ||||||
|  | 				t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) | ||||||
| 			} | 			} | ||||||
|  | 			data, err := ioutil.ReadAll(msg.Payload) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("payload read error: %v", err) | ||||||
|  | 			} | ||||||
|  | 			expdata, _ := hex.DecodeString("0183303030") | ||||||
|  | 			if !bytes.Equal(expdata, data) { | ||||||
|  | 				t.Errorf("incorrect msg data %x", data) | ||||||
|  | 			} | ||||||
|  | 			close(done) | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	err = peer.Write("bbb", msg) | 	net, peer, errc := testPeer([]Protocol{proto}) | ||||||
| 	time.Sleep(1 * time.Millisecond) | 	defer net.Close() | ||||||
| 	if err == nil { | 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||||
| 		t.Errorf("expect error for unknown protocol") |  | ||||||
|  | 	writeMsg(net, NewMsg(18, 1, "000")) | ||||||
|  | 	select { | ||||||
|  | 	case <-done: | ||||||
|  | 	case err := <-errc: | ||||||
|  | 		t.Errorf("peer returned: %v", err) | ||||||
|  | 	case <-time.After(2 * time.Second): | ||||||
|  | 		t.Errorf("receive timeout") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestPeerProtoReadLargeMsg(t *testing.T) { | ||||||
|  | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
|  | 	msgsize := uint32(10 * 1024 * 1024) | ||||||
|  | 	done := make(chan struct{}) | ||||||
|  | 	proto := Protocol{ | ||||||
|  | 		Name:   "a", | ||||||
|  | 		Length: 5, | ||||||
|  | 		Run: func(peer *Peer, rw MsgReadWriter) error { | ||||||
|  | 			msg, err := rw.ReadMsg() | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Errorf("read error: %v", err) | ||||||
|  | 			} | ||||||
|  | 			if msg.Size != msgsize+4 { | ||||||
|  | 				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) | ||||||
|  | 			} | ||||||
|  | 			msg.Discard() | ||||||
|  | 			close(done) | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	net, peer, errc := testPeer([]Protocol{proto}) | ||||||
|  | 	defer net.Close() | ||||||
|  | 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||||
|  |  | ||||||
|  | 	writeMsg(net, NewMsg(18, make([]byte, msgsize))) | ||||||
|  | 	select { | ||||||
|  | 	case <-done: | ||||||
|  | 	case err := <-errc: | ||||||
|  | 		t.Errorf("peer returned: %v", err) | ||||||
|  | 	case <-time.After(2 * time.Second): | ||||||
|  | 		t.Errorf("receive timeout") | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerProtoEncodeMsg(t *testing.T) { | ||||||
|  | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
|  | 	proto := Protocol{ | ||||||
|  | 		Name:   "a", | ||||||
|  | 		Length: 2, | ||||||
|  | 		Run: func(peer *Peer, rw MsgReadWriter) error { | ||||||
|  | 			if err := rw.EncodeMsg(2); err == nil { | ||||||
|  | 				t.Error("expected error for out-of-range msg code, got nil") | ||||||
|  | 			} | ||||||
|  | 			if err := rw.EncodeMsg(1); err != nil { | ||||||
|  | 				t.Errorf("write error: %v", err) | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	net, peer, _ := testPeer([]Protocol{proto}) | ||||||
|  | 	defer net.Close() | ||||||
|  | 	peer.startSubprotocols([]Cap{proto.cap()}) | ||||||
|  |  | ||||||
|  | 	bufr := bufio.NewReader(net) | ||||||
|  | 	msg, err := readMsg(bufr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("read error: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if msg.Code != 17 { | ||||||
|  | 		t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerWrite(t *testing.T) { | ||||||
|  | 	defer testlog(t).detach() | ||||||
|  |  | ||||||
|  | 	net, peer, peerErr := testPeer([]Protocol{discard}) | ||||||
|  | 	defer net.Close() | ||||||
|  | 	peer.startSubprotocols([]Cap{discard.cap()}) | ||||||
|  |  | ||||||
|  | 	// test write errors | ||||||
|  | 	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { | ||||||
|  | 		t.Errorf("expected error for unknown protocol, got nil") | ||||||
|  | 	} | ||||||
|  | 	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { | ||||||
|  | 		t.Errorf("expected error for out-of-range msg code, got nil") | ||||||
|  | 	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { | ||||||
|  | 		t.Errorf("wrong error for out-of-range msg code, got %#v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// setup for reading the message on the other end | ||||||
|  | 	read := make(chan struct{}) | ||||||
|  | 	go func() { | ||||||
|  | 		bufr := bufio.NewReader(net) | ||||||
|  | 		msg, err := readMsg(bufr) | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Errorf("read error: %v", err) | ||||||
|  | 		} else if msg.Code != 16 { | ||||||
|  | 			t.Errorf("wrong code, got %d, expected %d", msg.Code, 16) | ||||||
|  | 		} | ||||||
|  | 		msg.Discard() | ||||||
|  | 		close(read) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	// test succcessful write | ||||||
|  | 	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { | ||||||
|  | 		t.Errorf("expect no error for known protocol: %v", err) | ||||||
|  | 	} | ||||||
|  | 	select { | ||||||
|  | 	case <-read: | ||||||
|  | 	case err := <-peerErr: | ||||||
|  | 		t.Fatalf("peer stopped: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestPeerActivity(t *testing.T) { | ||||||
|  | 	// shorten inactivityTimeout while this test is running | ||||||
|  | 	oldT := inactivityTimeout | ||||||
|  | 	defer func() { inactivityTimeout = oldT }() | ||||||
|  | 	inactivityTimeout = 20 * time.Millisecond | ||||||
|  |  | ||||||
|  | 	net, peer, peerErr := testPeer([]Protocol{discard}) | ||||||
|  | 	defer net.Close() | ||||||
|  | 	peer.startSubprotocols([]Cap{discard.cap()}) | ||||||
|  |  | ||||||
|  | 	sub := peer.activity.Subscribe(time.Time{}) | ||||||
|  | 	defer sub.Unsubscribe() | ||||||
|  |  | ||||||
|  | 	for i := 0; i < 6; i++ { | ||||||
|  | 		writeMsg(net, NewMsg(16)) | ||||||
|  | 		select { | ||||||
|  | 		case <-sub.Chan(): | ||||||
|  | 		case <-time.After(inactivityTimeout / 2): | ||||||
|  | 			t.Fatal("no event within ", inactivityTimeout/2) | ||||||
|  | 		case err := <-peerErr: | ||||||
|  | 			t.Fatal("peer error", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case <-time.After(inactivityTimeout * 2): | ||||||
|  | 	case <-sub.Chan(): | ||||||
|  | 		t.Fatal("got activity event while connection was inactive") | ||||||
|  | 	case err := <-peerErr: | ||||||
|  | 		t.Fatal("peer error", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNewPeer(t *testing.T) { | ||||||
|  | 	id := NewSimpleClientIdentity("clientid", "version", "customid", "pubkey") | ||||||
|  | 	caps := []Cap{{"foo", 2}, {"bar", 3}} | ||||||
|  | 	p := NewPeer(id, caps) | ||||||
|  | 	if !reflect.DeepEqual(p.Caps(), caps) { | ||||||
|  | 		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) | ||||||
|  | 	} | ||||||
|  | 	if p.Identity() != id { | ||||||
|  | 		t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) | ||||||
|  | 	} | ||||||
|  | 	// Should not hang. | ||||||
|  | 	p.Disconnect(DiscAlreadyConnected) | ||||||
|  | } | ||||||
|   | |||||||
							
								
								
									
										495
									
								
								p2p/protocol.go
									
									
									
									
									
								
							
							
						
						
									
										495
									
								
								p2p/protocol.go
									
									
									
									
									
								
							| @@ -2,277 +2,294 @@ package p2p | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"fmt" |  | ||||||
| 	"net" |  | ||||||
| 	"sort" |  | ||||||
| 	"sync" |  | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/ethereum/go-ethereum/ethutil" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type Protocol interface { | // Protocol represents a P2P subprotocol implementation. | ||||||
| 	Start() | type Protocol struct { | ||||||
| 	Stop() | 	// Name should contain the official protocol name, | ||||||
| 	HandleIn(*Msg, chan *Msg) | 	// often a three-letter word. | ||||||
| 	HandleOut(*Msg) bool | 	Name string | ||||||
| 	Offset() MsgCode |  | ||||||
| 	Name() string | 	// Version should contain the version number of the protocol. | ||||||
|  | 	Version uint | ||||||
|  |  | ||||||
|  | 	// Length should contain the number of message codes used | ||||||
|  | 	// by the protocol. | ||||||
|  | 	Length uint64 | ||||||
|  |  | ||||||
|  | 	// Run is called in a new groutine when the protocol has been | ||||||
|  | 	// negotiated with a peer. It should read and write messages from | ||||||
|  | 	// rw. The Payload for each message must be fully consumed. | ||||||
|  | 	// | ||||||
|  | 	// The peer connection is closed when Start returns. It should return | ||||||
|  | 	// any protocol-level error (such as an I/O error) that is | ||||||
|  | 	// encountered. | ||||||
|  | 	Run func(peer *Peer, rw MsgReadWriter) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p Protocol) cap() Cap { | ||||||
|  | 	return Cap{p.Name, p.Version} | ||||||
| } | } | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	P2PVersion      = 0 | 	baseProtocolVersion    = 2 | ||||||
| 	pingTimeout     = 2 | 	baseProtocolLength     = uint64(16) | ||||||
| 	pingGracePeriod = 2 | 	baseProtocolMaxMsgSize = 10 * 1024 * 1024 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	HandshakeMsg = iota | 	// devp2p message codes | ||||||
| 	DiscMsg | 	handshakeMsg = 0x00 | ||||||
| 	PingMsg | 	discMsg      = 0x01 | ||||||
| 	PongMsg | 	pingMsg      = 0x02 | ||||||
| 	GetPeersMsg | 	pongMsg      = 0x03 | ||||||
| 	PeersMsg | 	getPeersMsg  = 0x04 | ||||||
| 	offset = 16 | 	peersMsg     = 0x05 | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type ProtocolState uint8 | // handshake is the structure of a handshake list. | ||||||
|  | type handshake struct { | ||||||
| const ( | 	Version    uint64 | ||||||
| 	nullState = iota | 	ID         string | ||||||
| 	handshakeReceived | 	Caps       []Cap | ||||||
| ) | 	ListenPort uint64 | ||||||
|  | 	NodeID     []byte | ||||||
| type DiscReason byte |  | ||||||
|  |  | ||||||
| const ( |  | ||||||
| 	// Values are given explicitly instead of by iota because these values are |  | ||||||
| 	// defined by the wire protocol spec; it is easier for humans to ensure |  | ||||||
| 	// correctness when values are explicit. |  | ||||||
| 	DiscRequested           = 0x00 |  | ||||||
| 	DiscNetworkError        = 0x01 |  | ||||||
| 	DiscProtocolError       = 0x02 |  | ||||||
| 	DiscUselessPeer         = 0x03 |  | ||||||
| 	DiscTooManyPeers        = 0x04 |  | ||||||
| 	DiscAlreadyConnected    = 0x05 |  | ||||||
| 	DiscIncompatibleVersion = 0x06 |  | ||||||
| 	DiscInvalidIdentity     = 0x07 |  | ||||||
| 	DiscQuitting            = 0x08 |  | ||||||
| 	DiscUnexpectedIdentity  = 0x09 |  | ||||||
| 	DiscSelf                = 0x0a |  | ||||||
| 	DiscReadTimeout         = 0x0b |  | ||||||
| 	DiscSubprotocolError    = 0x10 |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var discReasonToString = map[DiscReason]string{ |  | ||||||
| 	DiscRequested:           "Disconnect requested", |  | ||||||
| 	DiscNetworkError:        "Network error", |  | ||||||
| 	DiscProtocolError:       "Breach of protocol", |  | ||||||
| 	DiscUselessPeer:         "Useless peer", |  | ||||||
| 	DiscTooManyPeers:        "Too many peers", |  | ||||||
| 	DiscAlreadyConnected:    "Already connected", |  | ||||||
| 	DiscIncompatibleVersion: "Incompatible P2P protocol version", |  | ||||||
| 	DiscInvalidIdentity:     "Invalid node identity", |  | ||||||
| 	DiscQuitting:            "Client quitting", |  | ||||||
| 	DiscUnexpectedIdentity:  "Unexpected identity", |  | ||||||
| 	DiscSelf:                "Connected to self", |  | ||||||
| 	DiscReadTimeout:         "Read timeout", |  | ||||||
| 	DiscSubprotocolError:    "Subprotocol error", |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (d DiscReason) String() string { | func (h *handshake) String() string { | ||||||
| 	if len(discReasonToString) < int(d) { | 	return h.ID | ||||||
| 		return "Unknown" | } | ||||||
| 	} | func (h *handshake) Pubkey() []byte { | ||||||
|  | 	return h.NodeID | ||||||
| 	return discReasonToString[d] |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type BaseProtocol struct { | // Cap is the structure of a peer capability. | ||||||
|  | type Cap struct { | ||||||
|  | 	Name    string | ||||||
|  | 	Version uint | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (cap Cap) RlpData() interface{} { | ||||||
|  | 	return []interface{}{cap.Name, cap.Version} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type capsByName []Cap | ||||||
|  |  | ||||||
|  | func (cs capsByName) Len() int           { return len(cs) } | ||||||
|  | func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } | ||||||
|  | func (cs capsByName) Swap(i, j int)      { cs[i], cs[j] = cs[j], cs[i] } | ||||||
|  |  | ||||||
|  | type baseProtocol struct { | ||||||
|  | 	rw   MsgReadWriter | ||||||
| 	peer *Peer | 	peer *Peer | ||||||
| 	state     ProtocolState |  | ||||||
| 	stateLock sync.RWMutex |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewBaseProtocol(peer *Peer) *BaseProtocol { | func runBaseProtocol(peer *Peer, rw MsgReadWriter) error { | ||||||
| 	self := &BaseProtocol{ | 	bp := &baseProtocol{rw, peer} | ||||||
| 		peer: peer, | 	if err := bp.doHandshake(rw); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	// run main loop | ||||||
|  | 	quit := make(chan error, 1) | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			if err := bp.handle(rw); err != nil { | ||||||
|  | 				quit <- err | ||||||
|  | 				break | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	return bp.loop(quit) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var pingTimeout = 2 * time.Second | ||||||
|  |  | ||||||
|  | func (bp *baseProtocol) loop(quit <-chan error) error { | ||||||
|  | 	ping := time.NewTimer(pingTimeout) | ||||||
|  | 	activity := bp.peer.activity.Subscribe(time.Time{}) | ||||||
|  | 	lastActive := time.Time{} | ||||||
|  | 	defer ping.Stop() | ||||||
|  | 	defer activity.Unsubscribe() | ||||||
|  |  | ||||||
|  | 	getPeersTick := time.NewTicker(10 * time.Second) | ||||||
|  | 	defer getPeersTick.Stop() | ||||||
|  | 	err := bp.rw.EncodeMsg(getPeersMsg) | ||||||
|  |  | ||||||
|  | 	for err == nil { | ||||||
|  | 		select { | ||||||
|  | 		case err = <-quit: | ||||||
|  | 			return err | ||||||
|  | 		case <-getPeersTick.C: | ||||||
|  | 			err = bp.rw.EncodeMsg(getPeersMsg) | ||||||
|  | 		case event := <-activity.Chan(): | ||||||
|  | 			ping.Reset(pingTimeout) | ||||||
|  | 			lastActive = event.(time.Time) | ||||||
|  | 		case t := <-ping.C: | ||||||
|  | 			if lastActive.Add(pingTimeout * 2).Before(t) { | ||||||
|  | 				err = newPeerError(errPingTimeout, "") | ||||||
|  | 			} else if lastActive.Add(pingTimeout).Before(t) { | ||||||
|  | 				err = bp.rw.EncodeMsg(pingMsg) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (bp *baseProtocol) handle(rw MsgReadWriter) error { | ||||||
|  | 	msg, err := rw.ReadMsg() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if msg.Size > baseProtocolMaxMsgSize { | ||||||
|  | 		return newPeerError(errMisc, "message too big") | ||||||
|  | 	} | ||||||
|  | 	// make sure that the payload has been fully consumed | ||||||
|  | 	defer msg.Discard() | ||||||
|  |  | ||||||
|  | 	switch msg.Code { | ||||||
|  | 	case handshakeMsg: | ||||||
|  | 		return newPeerError(errProtocolBreach, "extra handshake received") | ||||||
|  |  | ||||||
|  | 	case discMsg: | ||||||
|  | 		var reason DiscReason | ||||||
|  | 		if err := msg.Decode(&reason); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		bp.peer.Disconnect(reason) | ||||||
|  | 		return nil | ||||||
|  |  | ||||||
|  | 	case pingMsg: | ||||||
|  | 		return bp.rw.EncodeMsg(pongMsg) | ||||||
|  |  | ||||||
|  | 	case pongMsg: | ||||||
|  |  | ||||||
|  | 	case getPeersMsg: | ||||||
|  | 		peers := bp.peerList() | ||||||
|  | 		// this is dangerous. the spec says that we should _delay_ | ||||||
|  | 		// sending the response if no new information is available. | ||||||
|  | 		// this means that would need to send a response later when | ||||||
|  | 		// new peers become available. | ||||||
|  | 		// | ||||||
|  | 		// TODO: add event mechanism to notify baseProtocol for new peers | ||||||
|  | 		if len(peers) > 0 { | ||||||
|  | 			return bp.rw.EncodeMsg(peersMsg, peers) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 	return self | 	case peersMsg: | ||||||
| } | 		var peers []*peerAddr | ||||||
|  | 		if err := msg.Decode(&peers); err != nil { | ||||||
| func (self *BaseProtocol) Start() { | 			return err | ||||||
| 	if self.peer != nil { |  | ||||||
| 		self.peer.Write("", self.peer.Server().Handshake()) |  | ||||||
| 		go self.peer.Messenger().PingPong( |  | ||||||
| 			pingTimeout*time.Second, |  | ||||||
| 			pingGracePeriod*time.Second, |  | ||||||
| 			self.Ping, |  | ||||||
| 			self.Timeout, |  | ||||||
| 		) |  | ||||||
| 		} | 		} | ||||||
| } | 		for _, addr := range peers { | ||||||
|  | 			bp.peer.Debugf("received peer suggestion: %v", addr) | ||||||
| func (self *BaseProtocol) Stop() { | 			bp.peer.newPeerAddr <- addr | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) Ping() { |  | ||||||
| 	msg, _ := NewMsg(PingMsg) |  | ||||||
| 	self.peer.Write("", msg) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) Timeout() { |  | ||||||
| 	self.peerError(PingTimeout, "") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) Name() string { |  | ||||||
| 	return "" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) Offset() MsgCode { |  | ||||||
| 	return offset |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) CheckState(state ProtocolState) bool { |  | ||||||
| 	self.stateLock.RLock() |  | ||||||
| 	self.stateLock.RUnlock() |  | ||||||
| 	if self.state != state { |  | ||||||
| 		return false |  | ||||||
| 	} else { |  | ||||||
| 		return true |  | ||||||
| 		} | 		} | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) HandleIn(msg *Msg, response chan *Msg) { |  | ||||||
| 	if msg.Code() == HandshakeMsg { |  | ||||||
| 		self.handleHandshake(msg) |  | ||||||
| 	} else { |  | ||||||
| 		if !self.CheckState(handshakeReceived) { |  | ||||||
| 			self.peerError(ProtocolBreach, "message code %v not allowed", msg.Code()) |  | ||||||
| 			close(response) |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 		switch msg.Code() { |  | ||||||
| 		case DiscMsg: |  | ||||||
| 			logger.Infof("Disconnect requested from peer %v, reason", DiscReason(msg.Data().Get(0).Uint())) |  | ||||||
| 			self.peer.Server().PeerDisconnect() <- DisconnectRequest{ |  | ||||||
| 				addr:   self.peer.Address, |  | ||||||
| 				reason: DiscRequested, |  | ||||||
| 			} |  | ||||||
| 		case PingMsg: |  | ||||||
| 			out, _ := NewMsg(PongMsg) |  | ||||||
| 			response <- out |  | ||||||
| 		case PongMsg: |  | ||||||
| 		case GetPeersMsg: |  | ||||||
| 			// Peer asked for list of connected peers |  | ||||||
| 			if out, err := self.peer.Server().PeersMessage(); err != nil { |  | ||||||
| 				response <- out |  | ||||||
| 			} |  | ||||||
| 		case PeersMsg: |  | ||||||
| 			self.handlePeers(msg) |  | ||||||
| 	default: | 	default: | ||||||
| 			self.peerError(InvalidMsgCode, "unknown message code %v", msg.Code()) | 		return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code) | ||||||
| 	} | 	} | ||||||
| 	} | 	return nil | ||||||
| 	close(response) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *BaseProtocol) HandleOut(msg *Msg) (allowed bool) { | func (bp *baseProtocol) doHandshake(rw MsgReadWriter) error { | ||||||
| 	// somewhat overly paranoid | 	// send our handshake | ||||||
| 	allowed = msg.Code() == HandshakeMsg || msg.Code() == DiscMsg || msg.Code() < self.Offset() && self.CheckState(handshakeReceived) | 	if err := rw.WriteMsg(bp.handshakeMsg()); err != nil { | ||||||
| 	return | 		return err | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) peerError(errorCode ErrorCode, format string, v ...interface{}) { |  | ||||||
| 	err := NewPeerError(errorCode, format, v...) |  | ||||||
| 	logger.Warnln(err) |  | ||||||
| 	fmt.Println(self.peer, err) |  | ||||||
| 	if self.peer != nil { |  | ||||||
| 		self.peer.PeerErrorChan() <- err |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) handlePeers(msg *Msg) { |  | ||||||
| 	it := msg.Data().NewIterator() |  | ||||||
| 	for it.Next() { |  | ||||||
| 		ip := net.IP(it.Value().Get(0).Bytes()) |  | ||||||
| 		port := it.Value().Get(1).Uint() |  | ||||||
| 		address := &net.TCPAddr{IP: ip, Port: int(port)} |  | ||||||
| 		go self.peer.Server().PeerConnect(address) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *BaseProtocol) handleHandshake(msg *Msg) { |  | ||||||
| 	self.stateLock.Lock() |  | ||||||
| 	defer self.stateLock.Unlock() |  | ||||||
| 	if self.state != nullState { |  | ||||||
| 		self.peerError(ProtocolBreach, "extra handshake") |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c := msg.Data() | 	// read and handle remote handshake | ||||||
|  | 	msg, err := rw.ReadMsg() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if msg.Code != handshakeMsg { | ||||||
|  | 		return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code) | ||||||
|  | 	} | ||||||
|  | 	if msg.Size > baseProtocolMaxMsgSize { | ||||||
|  | 		return newPeerError(errMisc, "message too big") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var hs handshake | ||||||
|  | 	if err := msg.Decode(&hs); err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// validate handshake info | ||||||
|  | 	if hs.Version != baseProtocolVersion { | ||||||
|  | 		return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n", | ||||||
|  | 			baseProtocolVersion, hs.Version) | ||||||
|  | 	} | ||||||
|  | 	if len(hs.NodeID) == 0 { | ||||||
|  | 		return newPeerError(errPubkeyMissing, "") | ||||||
|  | 	} | ||||||
|  | 	if len(hs.NodeID) != 64 { | ||||||
|  | 		return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8) | ||||||
|  | 	} | ||||||
|  | 	if da := bp.peer.dialAddr; da != nil { | ||||||
|  | 		// verify that the peer we wanted to connect to | ||||||
|  | 		// actually holds the target public key. | ||||||
|  | 		if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) { | ||||||
|  | 			return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) | ||||||
|  | 	if err := bp.peer.pubkeyHook(pa); err != nil { | ||||||
|  | 		return newPeerError(errPubkeyForbidden, "%v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// TODO: remove Caps with empty name | ||||||
|  |  | ||||||
|  | 	var addr *peerAddr | ||||||
|  | 	if hs.ListenPort != 0 { | ||||||
|  | 		addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID) | ||||||
|  | 		addr.Port = hs.ListenPort | ||||||
|  | 	} | ||||||
|  | 	bp.peer.setHandshakeInfo(&hs, addr, hs.Caps) | ||||||
|  | 	bp.peer.startSubprotocols(hs.Caps) | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (bp *baseProtocol) handshakeMsg() Msg { | ||||||
| 	var ( | 	var ( | ||||||
| 		p2pVersion = c.Get(0).Uint() | 		port uint64 | ||||||
| 		id         = c.Get(1).Str() | 		caps []interface{} | ||||||
| 		caps       = c.Get(2) |  | ||||||
| 		port       = c.Get(3).Uint() |  | ||||||
| 		pubkey     = c.Get(4).Bytes() |  | ||||||
| 	) | 	) | ||||||
| 	fmt.Printf("handshake received %v, %v, %v, %v, %v ", p2pVersion, id, caps, port, pubkey) | 	if bp.peer.ourListenAddr != nil { | ||||||
|  | 		port = bp.peer.ourListenAddr.Port | ||||||
| 	// Check correctness of p2p protocol version |  | ||||||
| 	if p2pVersion != P2PVersion { |  | ||||||
| 		self.peerError(P2PVersionMismatch, "Require protocol %d, received %d\n", P2PVersion, p2pVersion) |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  | 	for _, proto := range bp.peer.protocols { | ||||||
| 	// Handle the pub key (validation, uniqueness) | 		caps = append(caps, proto.cap()) | ||||||
| 	if len(pubkey) == 0 { |  | ||||||
| 		self.peerError(PubkeyMissing, "not supplied in handshake.") |  | ||||||
| 		return |  | ||||||
| 	} | 	} | ||||||
|  | 	return NewMsg(handshakeMsg, | ||||||
| 	if len(pubkey) != 64 { | 		baseProtocolVersion, | ||||||
| 		self.peerError(PubkeyInvalid, "require 512 bit, got %v", len(pubkey)*8) | 		bp.peer.ourID.String(), | ||||||
| 		return | 		caps, | ||||||
| 	} | 		port, | ||||||
|  | 		bp.peer.ourID.Pubkey()[1:], | ||||||
| 	// Self connect detection | 	) | ||||||
| 	if bytes.Compare(self.peer.Server().ClientIdentity().Pubkey()[1:], pubkey) == 0 { | } | ||||||
| 		self.peerError(PubkeyForbidden, "not allowed to connect to self") |  | ||||||
| 		return | func (bp *baseProtocol) peerList() []ethutil.RlpEncodable { | ||||||
| 	} | 	peers := bp.peer.otherPeers() | ||||||
|  | 	ds := make([]ethutil.RlpEncodable, 0, len(peers)) | ||||||
| 	// register pubkey on server. this also sets the pubkey on the peer (need lock) | 	for _, p := range peers { | ||||||
| 	if err := self.peer.Server().RegisterPubkey(self.peer, pubkey); err != nil { | 		p.infolock.Lock() | ||||||
| 		self.peerError(PubkeyForbidden, err.Error()) | 		addr := p.listenAddr | ||||||
| 		return | 		p.infolock.Unlock() | ||||||
| 	} | 		// filter out this peer and peers that are not listening or | ||||||
|  | 		// have not completed the handshake. | ||||||
| 	// check port | 		// TODO: track previously sent peers and exclude them as well. | ||||||
| 	if self.peer.Inbound { | 		if p == bp.peer || addr == nil { | ||||||
| 		uint16port := uint16(port) | 			continue | ||||||
| 		if self.peer.Port > 0 && self.peer.Port != uint16port { | 		} | ||||||
| 			self.peerError(PortMismatch, "port mismatch: %v != %v", self.peer.Port, port) | 		ds = append(ds, addr) | ||||||
| 			return | 	} | ||||||
| 		} else { | 	ourAddr := bp.peer.ourListenAddr | ||||||
| 			self.peer.Port = uint16port | 	if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() { | ||||||
| 		} | 		ds = append(ds, ourAddr) | ||||||
| 	} | 	} | ||||||
|  | 	return ds | ||||||
| 	capsIt := caps.NewIterator() |  | ||||||
| 	for capsIt.Next() { |  | ||||||
| 		cap := capsIt.Value().Str() |  | ||||||
| 		self.peer.Caps = append(self.peer.Caps, cap) |  | ||||||
| 	} |  | ||||||
| 	sort.Strings(self.peer.Caps) |  | ||||||
| 	self.peer.Messenger().AddProtocols(self.peer.Caps) |  | ||||||
|  |  | ||||||
| 	self.peer.Id = id |  | ||||||
|  |  | ||||||
| 	self.state = handshakeReceived |  | ||||||
|  |  | ||||||
| 	//p.ethereum.PushPeer(p) |  | ||||||
| 	// p.ethereum.reactor.Post("peerList", p.ethereum.Peers()) |  | ||||||
| 	return |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										825
									
								
								p2p/server.go
									
									
									
									
									
								
							
							
						
						
									
										825
									
								
								p2p/server.go
									
									
									
									
									
								
							| @@ -2,21 +2,420 @@ package p2p | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
| 	"sort" |  | ||||||
| 	"strconv" |  | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	logpkg "github.com/ethereum/go-ethereum/logger" | 	"github.com/ethereum/go-ethereum/logger" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| 	outboundAddressPoolSize = 10 | 	outboundAddressPoolSize   = 500 | ||||||
| 	disconnectGracePeriod   = 2 | 	defaultDialTimeout        = 10 * time.Second | ||||||
|  | 	portMappingUpdateInterval = 15 * time.Minute | ||||||
|  | 	portMappingTimeout        = 20 * time.Minute | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var srvlog = logger.NewLogger("P2P Server") | ||||||
|  |  | ||||||
|  | // Server manages all peer connections. | ||||||
|  | // | ||||||
|  | // The fields of Server are used as configuration parameters. | ||||||
|  | // You should set them before starting the Server. Fields may not be | ||||||
|  | // modified while the server is running. | ||||||
|  | type Server struct { | ||||||
|  | 	// This field must be set to a valid client identity. | ||||||
|  | 	Identity ClientIdentity | ||||||
|  |  | ||||||
|  | 	// MaxPeers is the maximum number of peers that can be | ||||||
|  | 	// connected. It must be greater than zero. | ||||||
|  | 	MaxPeers int | ||||||
|  |  | ||||||
|  | 	// Protocols should contain the protocols supported | ||||||
|  | 	// by the server. Matching protocols are launched for | ||||||
|  | 	// each peer. | ||||||
|  | 	Protocols []Protocol | ||||||
|  |  | ||||||
|  | 	// If Blacklist is set to a non-nil value, the given Blacklist | ||||||
|  | 	// is used to verify peer connections. | ||||||
|  | 	Blacklist Blacklist | ||||||
|  |  | ||||||
|  | 	// If ListenAddr is set to a non-nil address, the server | ||||||
|  | 	// will listen for incoming connections. | ||||||
|  | 	// | ||||||
|  | 	// If the port is zero, the operating system will pick a port. The | ||||||
|  | 	// ListenAddr field will be updated with the actual address when | ||||||
|  | 	// the server is started. | ||||||
|  | 	ListenAddr string | ||||||
|  |  | ||||||
|  | 	// If set to a non-nil value, the given NAT port mapper | ||||||
|  | 	// is used to make the listening port available to the | ||||||
|  | 	// Internet. | ||||||
|  | 	NAT NAT | ||||||
|  |  | ||||||
|  | 	// If Dialer is set to a non-nil value, the given Dialer | ||||||
|  | 	// is used to dial outbound peer connections. | ||||||
|  | 	Dialer *net.Dialer | ||||||
|  |  | ||||||
|  | 	// If NoDial is true, the server will not dial any peers. | ||||||
|  | 	NoDial bool | ||||||
|  |  | ||||||
|  | 	// Hook for testing. This is useful because we can inhibit | ||||||
|  | 	// the whole protocol stack. | ||||||
|  | 	newPeerFunc peerFunc | ||||||
|  |  | ||||||
|  | 	lock      sync.RWMutex | ||||||
|  | 	running   bool | ||||||
|  | 	listener  net.Listener | ||||||
|  | 	laddr     *net.TCPAddr // real listen addr | ||||||
|  | 	peers     []*Peer | ||||||
|  | 	peerSlots chan int | ||||||
|  | 	peerCount int | ||||||
|  |  | ||||||
|  | 	quit           chan struct{} | ||||||
|  | 	wg             sync.WaitGroup | ||||||
|  | 	peerConnect    chan *peerAddr | ||||||
|  | 	peerDisconnect chan *Peer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NAT is implemented by NAT traversal methods. | ||||||
|  | type NAT interface { | ||||||
|  | 	GetExternalAddress() (net.IP, error) | ||||||
|  | 	AddPortMapping(protocol string, extport, intport int, name string, lifetime time.Duration) error | ||||||
|  | 	DeletePortMapping(protocol string, extport, intport int) error | ||||||
|  |  | ||||||
|  | 	// Should return name of the method. | ||||||
|  | 	String() string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer | ||||||
|  |  | ||||||
|  | // Peers returns all connected peers. | ||||||
|  | func (srv *Server) Peers() (peers []*Peer) { | ||||||
|  | 	srv.lock.RLock() | ||||||
|  | 	defer srv.lock.RUnlock() | ||||||
|  | 	for _, peer := range srv.peers { | ||||||
|  | 		if peer != nil { | ||||||
|  | 			peers = append(peers, peer) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // PeerCount returns the number of connected peers. | ||||||
|  | func (srv *Server) PeerCount() int { | ||||||
|  | 	srv.lock.RLock() | ||||||
|  | 	defer srv.lock.RUnlock() | ||||||
|  | 	return srv.peerCount | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // SuggestPeer injects an address into the outbound address pool. | ||||||
|  | func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { | ||||||
|  | 	select { | ||||||
|  | 	case srv.peerConnect <- &peerAddr{ip, uint64(port), nodeID}: | ||||||
|  | 	default: // don't block | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Broadcast sends an RLP-encoded message to all connected peers. | ||||||
|  | // This method is deprecated and will be removed later. | ||||||
|  | func (srv *Server) Broadcast(protocol string, code uint64, data ...interface{}) { | ||||||
|  | 	var payload []byte | ||||||
|  | 	if data != nil { | ||||||
|  | 		payload = encodePayload(data...) | ||||||
|  | 	} | ||||||
|  | 	srv.lock.RLock() | ||||||
|  | 	defer srv.lock.RUnlock() | ||||||
|  | 	for _, peer := range srv.peers { | ||||||
|  | 		if peer != nil { | ||||||
|  | 			var msg = Msg{Code: code} | ||||||
|  | 			if data != nil { | ||||||
|  | 				msg.Payload = bytes.NewReader(payload) | ||||||
|  | 				msg.Size = uint32(len(payload)) | ||||||
|  | 			} | ||||||
|  | 			peer.writeProtoMsg(protocol, msg) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Start starts running the server. | ||||||
|  | // Servers can be re-used and started again after stopping. | ||||||
|  | func (srv *Server) Start() (err error) { | ||||||
|  | 	srv.lock.Lock() | ||||||
|  | 	defer srv.lock.Unlock() | ||||||
|  | 	if srv.running { | ||||||
|  | 		return errors.New("server already running") | ||||||
|  | 	} | ||||||
|  | 	srvlog.Infoln("Starting Server") | ||||||
|  |  | ||||||
|  | 	// initialize fields | ||||||
|  | 	if srv.Identity == nil { | ||||||
|  | 		return fmt.Errorf("Server.Identity must be set to a non-nil identity") | ||||||
|  | 	} | ||||||
|  | 	if srv.MaxPeers <= 0 { | ||||||
|  | 		return fmt.Errorf("Server.MaxPeers must be > 0") | ||||||
|  | 	} | ||||||
|  | 	srv.quit = make(chan struct{}) | ||||||
|  | 	srv.peers = make([]*Peer, srv.MaxPeers) | ||||||
|  | 	srv.peerSlots = make(chan int, srv.MaxPeers) | ||||||
|  | 	srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize) | ||||||
|  | 	srv.peerDisconnect = make(chan *Peer) | ||||||
|  | 	if srv.newPeerFunc == nil { | ||||||
|  | 		srv.newPeerFunc = newServerPeer | ||||||
|  | 	} | ||||||
|  | 	if srv.Blacklist == nil { | ||||||
|  | 		srv.Blacklist = NewBlacklist() | ||||||
|  | 	} | ||||||
|  | 	if srv.Dialer == nil { | ||||||
|  | 		srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if srv.ListenAddr != "" { | ||||||
|  | 		if err := srv.startListening(); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if !srv.NoDial { | ||||||
|  | 		srv.wg.Add(1) | ||||||
|  | 		go srv.dialLoop() | ||||||
|  | 	} | ||||||
|  | 	if srv.NoDial && srv.ListenAddr == "" { | ||||||
|  | 		srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// make all slots available | ||||||
|  | 	for i := range srv.peers { | ||||||
|  | 		srv.peerSlots <- i | ||||||
|  | 	} | ||||||
|  | 	// note: discLoop is not part of WaitGroup | ||||||
|  | 	go srv.discLoop() | ||||||
|  | 	srv.running = true | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) startListening() error { | ||||||
|  | 	listener, err := net.Listen("tcp", srv.ListenAddr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	srv.ListenAddr = listener.Addr().String() | ||||||
|  | 	srv.laddr = listener.Addr().(*net.TCPAddr) | ||||||
|  | 	srv.listener = listener | ||||||
|  | 	srv.wg.Add(1) | ||||||
|  | 	go srv.listenLoop() | ||||||
|  | 	if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { | ||||||
|  | 		srv.wg.Add(1) | ||||||
|  | 		go srv.natLoop(srv.laddr.Port) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Stop terminates the server and all active peer connections. | ||||||
|  | // It blocks until all active connections have been closed. | ||||||
|  | func (srv *Server) Stop() { | ||||||
|  | 	srv.lock.Lock() | ||||||
|  | 	if !srv.running { | ||||||
|  | 		srv.lock.Unlock() | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	srv.running = false | ||||||
|  | 	srv.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	srvlog.Infoln("Stopping server") | ||||||
|  | 	if srv.listener != nil { | ||||||
|  | 		// this unblocks listener Accept | ||||||
|  | 		srv.listener.Close() | ||||||
|  | 	} | ||||||
|  | 	close(srv.quit) | ||||||
|  | 	for _, peer := range srv.Peers() { | ||||||
|  | 		peer.Disconnect(DiscQuitting) | ||||||
|  | 	} | ||||||
|  | 	srv.wg.Wait() | ||||||
|  |  | ||||||
|  | 	// wait till they actually disconnect | ||||||
|  | 	// this is checked by claiming all peerSlots. | ||||||
|  | 	// slots become available as the peers disconnect. | ||||||
|  | 	for i := 0; i < cap(srv.peerSlots); i++ { | ||||||
|  | 		<-srv.peerSlots | ||||||
|  | 	} | ||||||
|  | 	// terminate discLoop | ||||||
|  | 	close(srv.peerDisconnect) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) discLoop() { | ||||||
|  | 	for peer := range srv.peerDisconnect { | ||||||
|  | 		// peer has just disconnected. free up its slot. | ||||||
|  | 		srvlog.Infof("%v is gone", peer) | ||||||
|  | 		srv.peerSlots <- peer.slot | ||||||
|  | 		srv.lock.Lock() | ||||||
|  | 		srv.peers[peer.slot] = nil | ||||||
|  | 		srv.lock.Unlock() | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // main loop for adding connections via listening | ||||||
|  | func (srv *Server) listenLoop() { | ||||||
|  | 	defer srv.wg.Done() | ||||||
|  |  | ||||||
|  | 	srvlog.Infoln("Listening on", srv.listener.Addr()) | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case slot := <-srv.peerSlots: | ||||||
|  | 			conn, err := srv.listener.Accept() | ||||||
|  | 			if err != nil { | ||||||
|  | 				srv.peerSlots <- slot | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 			srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot) | ||||||
|  | 			srv.addPeer(conn, nil, slot) | ||||||
|  | 		case <-srv.quit: | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) natLoop(port int) { | ||||||
|  | 	defer srv.wg.Done() | ||||||
|  | 	for { | ||||||
|  | 		srv.updatePortMapping(port) | ||||||
|  | 		select { | ||||||
|  | 		case <-time.After(portMappingUpdateInterval): | ||||||
|  | 			// one more round | ||||||
|  | 		case <-srv.quit: | ||||||
|  | 			srv.removePortMapping(port) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) updatePortMapping(port int) { | ||||||
|  | 	srvlog.Infoln("Attempting to map port", port, "with", srv.NAT) | ||||||
|  | 	err := srv.NAT.AddPortMapping("tcp", port, port, "ethereum p2p", portMappingTimeout) | ||||||
|  | 	if err != nil { | ||||||
|  | 		srvlog.Errorln("Port mapping error:", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	extip, err := srv.NAT.GetExternalAddress() | ||||||
|  | 	if err != nil { | ||||||
|  | 		srvlog.Errorln("Error getting external IP:", err) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	srv.lock.Lock() | ||||||
|  | 	extaddr := *(srv.listener.Addr().(*net.TCPAddr)) | ||||||
|  | 	extaddr.IP = extip | ||||||
|  | 	srvlog.Infoln("Mapped port, external addr is", &extaddr) | ||||||
|  | 	srv.laddr = &extaddr | ||||||
|  | 	srv.lock.Unlock() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) removePortMapping(port int) { | ||||||
|  | 	srvlog.Infoln("Removing port mapping for", port, "with", srv.NAT) | ||||||
|  | 	srv.NAT.DeletePortMapping("tcp", port, port) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) dialLoop() { | ||||||
|  | 	defer srv.wg.Done() | ||||||
|  | 	var ( | ||||||
|  | 		suggest chan *peerAddr | ||||||
|  | 		slot    *int | ||||||
|  | 		slots   = srv.peerSlots | ||||||
|  | 	) | ||||||
|  | 	for { | ||||||
|  | 		select { | ||||||
|  | 		case i := <-slots: | ||||||
|  | 			// we need a peer in slot i, slot reserved | ||||||
|  | 			slot = &i | ||||||
|  | 			// now we can watch for candidate peers in the next loop | ||||||
|  | 			suggest = srv.peerConnect | ||||||
|  | 			// do not consume more until candidate peer is found | ||||||
|  | 			slots = nil | ||||||
|  |  | ||||||
|  | 		case desc := <-suggest: | ||||||
|  | 			// candidate peer found, will dial out asyncronously | ||||||
|  | 			// if connection fails slot will be released | ||||||
|  | 			go srv.dialPeer(desc, *slot) | ||||||
|  | 			// we can watch if more peers needed in the next loop | ||||||
|  | 			slots = srv.peerSlots | ||||||
|  | 			// until then we dont care about candidate peers | ||||||
|  | 			suggest = nil | ||||||
|  |  | ||||||
|  | 		case <-srv.quit: | ||||||
|  | 			// give back the currently reserved slot | ||||||
|  | 			if slot != nil { | ||||||
|  | 				srv.peerSlots <- *slot | ||||||
|  | 			} | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // connect to peer via dial out | ||||||
|  | func (srv *Server) dialPeer(desc *peerAddr, slot int) { | ||||||
|  | 	srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) | ||||||
|  | 	conn, err := srv.Dialer.Dial(desc.Network(), desc.String()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		srvlog.Errorf("Dial error: %v", err) | ||||||
|  | 		srv.peerSlots <- slot | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	go srv.addPeer(conn, desc, slot) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // creates the new peer object and inserts it into its slot | ||||||
|  | func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { | ||||||
|  | 	srv.lock.Lock() | ||||||
|  | 	defer srv.lock.Unlock() | ||||||
|  | 	if !srv.running { | ||||||
|  | 		conn.Close() | ||||||
|  | 		srv.peerSlots <- slot // release slot | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  | 	peer := srv.newPeerFunc(srv, conn, desc) | ||||||
|  | 	peer.slot = slot | ||||||
|  | 	srv.peers[slot] = peer | ||||||
|  | 	srv.peerCount++ | ||||||
|  | 	go func() { peer.loop(); srv.peerDisconnect <- peer }() | ||||||
|  | 	return peer | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot | ||||||
|  | func (srv *Server) removePeer(peer *Peer) { | ||||||
|  | 	srv.lock.Lock() | ||||||
|  | 	defer srv.lock.Unlock() | ||||||
|  | 	srvlog.Debugf("Removing peer %v %v (slot %v)\n", peer, peer.slot) | ||||||
|  | 	if srv.peers[peer.slot] != peer { | ||||||
|  | 		srvlog.Warnln("Invalid peer to remove:", peer) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	// remove from list and index | ||||||
|  | 	srv.peerCount-- | ||||||
|  | 	srv.peers[peer.slot] = nil | ||||||
|  | 	// release slot to signal need for a new peer, last! | ||||||
|  | 	srv.peerSlots <- peer.slot | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (srv *Server) verifyPeer(addr *peerAddr) error { | ||||||
|  | 	if srv.Blacklist.Exists(addr.Pubkey) { | ||||||
|  | 		return errors.New("blacklisted") | ||||||
|  | 	} | ||||||
|  | 	if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) { | ||||||
|  | 		return newPeerError(errPubkeyForbidden, "not allowed to connect to srv") | ||||||
|  | 	} | ||||||
|  | 	srv.lock.RLock() | ||||||
|  | 	defer srv.lock.RUnlock() | ||||||
|  | 	for _, peer := range srv.peers { | ||||||
|  | 		if peer != nil { | ||||||
|  | 			id := peer.Identity() | ||||||
|  | 			if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) { | ||||||
|  | 				return errors.New("already connected") | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| type Blacklist interface { | type Blacklist interface { | ||||||
| 	Get([]byte) (bool, error) | 	Get([]byte) (bool, error) | ||||||
| 	Put([]byte) error | 	Put([]byte) error | ||||||
| @@ -66,419 +465,3 @@ func (self *BlacklistMap) Delete(pubkey []byte) error { | |||||||
| 	delete(self.blacklist, string(pubkey)) | 	delete(self.blacklist, string(pubkey)) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| type Server struct { |  | ||||||
| 	network   Network |  | ||||||
| 	listening bool //needed? |  | ||||||
| 	dialing   bool //needed? |  | ||||||
| 	closed    bool |  | ||||||
| 	identity  ClientIdentity |  | ||||||
| 	addr      net.Addr |  | ||||||
| 	port      uint16 |  | ||||||
| 	protocols []string |  | ||||||
|  |  | ||||||
| 	quit      chan chan bool |  | ||||||
| 	peersLock sync.RWMutex |  | ||||||
|  |  | ||||||
| 	maxPeers   int |  | ||||||
| 	peers      []*Peer |  | ||||||
| 	peerSlots  chan int |  | ||||||
| 	peersTable map[string]int |  | ||||||
| 	peersMsg   *Msg |  | ||||||
| 	peerCount  int |  | ||||||
|  |  | ||||||
| 	peerConnect    chan net.Addr |  | ||||||
| 	peerDisconnect chan DisconnectRequest |  | ||||||
| 	blacklist      Blacklist |  | ||||||
| 	handlers       Handlers |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var logger = logpkg.NewLogger("P2P") |  | ||||||
|  |  | ||||||
| func New(network Network, addr net.Addr, identity ClientIdentity, handlers Handlers, maxPeers int, blacklist Blacklist) *Server { |  | ||||||
| 	// get alphabetical list of protocol names from handlers map |  | ||||||
| 	protocols := []string{} |  | ||||||
| 	for protocol := range handlers { |  | ||||||
| 		protocols = append(protocols, protocol) |  | ||||||
| 	} |  | ||||||
| 	sort.Strings(protocols) |  | ||||||
|  |  | ||||||
| 	_, port, _ := net.SplitHostPort(addr.String()) |  | ||||||
| 	intport, _ := strconv.Atoi(port) |  | ||||||
|  |  | ||||||
| 	self := &Server{ |  | ||||||
| 		// NewSimpleClientIdentity(clientIdentifier, version, customIdentifier) |  | ||||||
| 		network:   network, |  | ||||||
| 		identity:  identity, |  | ||||||
| 		addr:      addr, |  | ||||||
| 		port:      uint16(intport), |  | ||||||
| 		protocols: protocols, |  | ||||||
|  |  | ||||||
| 		quit: make(chan chan bool), |  | ||||||
|  |  | ||||||
| 		maxPeers:   maxPeers, |  | ||||||
| 		peers:      make([]*Peer, maxPeers), |  | ||||||
| 		peerSlots:  make(chan int, maxPeers), |  | ||||||
| 		peersTable: make(map[string]int), |  | ||||||
|  |  | ||||||
| 		peerConnect:    make(chan net.Addr, outboundAddressPoolSize), |  | ||||||
| 		peerDisconnect: make(chan DisconnectRequest), |  | ||||||
| 		blacklist:      blacklist, |  | ||||||
|  |  | ||||||
| 		handlers: handlers, |  | ||||||
| 	} |  | ||||||
| 	for i := 0; i < maxPeers; i++ { |  | ||||||
| 		self.peerSlots <- i // fill up with indexes |  | ||||||
| 	} |  | ||||||
| 	return self |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) NewAddr(host string, port int) (addr net.Addr, err error) { |  | ||||||
| 	addr, err = self.network.NewAddr(host, port) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) ParseAddr(address string) (addr net.Addr, err error) { |  | ||||||
| 	addr, err = self.network.ParseAddr(address) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) ClientIdentity() ClientIdentity { |  | ||||||
| 	return self.identity |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) PeersMessage() (msg *Msg, err error) { |  | ||||||
| 	// TODO: memoize and reset when peers change |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	defer self.peersLock.RUnlock() |  | ||||||
| 	msg = self.peersMsg |  | ||||||
| 	if msg == nil { |  | ||||||
| 		var peerData []interface{} |  | ||||||
| 		for _, i := range self.peersTable { |  | ||||||
| 			peer := self.peers[i] |  | ||||||
| 			peerData = append(peerData, peer.Encode()) |  | ||||||
| 		} |  | ||||||
| 		if len(peerData) == 0 { |  | ||||||
| 			err = fmt.Errorf("no peers") |  | ||||||
| 		} else { |  | ||||||
| 			msg, err = NewMsg(PeersMsg, peerData...) |  | ||||||
| 			self.peersMsg = msg //memoize |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) Peers() (peers []*Peer) { |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	defer self.peersLock.RUnlock() |  | ||||||
| 	for _, peer := range self.peers { |  | ||||||
| 		if peer != nil { |  | ||||||
| 			peers = append(peers, peer) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) PeerCount() int { |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	defer self.peersLock.RUnlock() |  | ||||||
| 	return self.peerCount |  | ||||||
| } |  | ||||||
|  |  | ||||||
| var getPeersMsg, _ = NewMsg(GetPeersMsg) |  | ||||||
|  |  | ||||||
| func (self *Server) PeerConnect(addr net.Addr) { |  | ||||||
| 	// TODO: should buffer, filter and uniq |  | ||||||
| 	// send GetPeersMsg if not blocking |  | ||||||
| 	select { |  | ||||||
| 	case self.peerConnect <- addr: // not enough peers |  | ||||||
| 		self.Broadcast("", getPeersMsg) |  | ||||||
| 	default: // we dont care |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) PeerDisconnect() chan DisconnectRequest { |  | ||||||
| 	return self.peerDisconnect |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) Blacklist() Blacklist { |  | ||||||
| 	return self.blacklist |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) Handlers() Handlers { |  | ||||||
| 	return self.handlers |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) Broadcast(protocol string, msg *Msg) { |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	defer self.peersLock.RUnlock() |  | ||||||
| 	for _, peer := range self.peers { |  | ||||||
| 		if peer != nil { |  | ||||||
| 			peer.Write(protocol, msg) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Start the server |  | ||||||
| func (self *Server) Start(listen bool, dial bool) { |  | ||||||
| 	self.network.Start() |  | ||||||
| 	if listen { |  | ||||||
| 		listener, err := self.network.Listener(self.addr) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Warnf("Error initializing listener: %v", err) |  | ||||||
| 			logger.Warnf("Connection listening disabled") |  | ||||||
| 			self.listening = false |  | ||||||
| 		} else { |  | ||||||
| 			self.listening = true |  | ||||||
| 			logger.Infoln("Listen on %v: ready and accepting connections", listener.Addr()) |  | ||||||
| 			go self.inboundPeerHandler(listener) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if dial { |  | ||||||
| 		dialer, err := self.network.Dialer(self.addr) |  | ||||||
| 		if err != nil { |  | ||||||
| 			logger.Warnf("Error initializing dialer: %v", err) |  | ||||||
| 			logger.Warnf("Connection dialout disabled") |  | ||||||
| 			self.dialing = false |  | ||||||
| 		} else { |  | ||||||
| 			self.dialing = true |  | ||||||
| 			logger.Infoln("Dial peers watching outbound address pool") |  | ||||||
| 			go self.outboundPeerHandler(dialer) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	logger.Infoln("server started") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) Stop() { |  | ||||||
| 	logger.Infoln("server stopping...") |  | ||||||
| 	// // quit one loop if dialing |  | ||||||
| 	if self.dialing { |  | ||||||
| 		logger.Infoln("stop dialout...") |  | ||||||
| 		dialq := make(chan bool) |  | ||||||
| 		self.quit <- dialq |  | ||||||
| 		<-dialq |  | ||||||
| 		fmt.Println("quit another") |  | ||||||
| 	} |  | ||||||
| 	// quit the other loop if listening |  | ||||||
| 	if self.listening { |  | ||||||
| 		logger.Infoln("stop listening...") |  | ||||||
| 		listenq := make(chan bool) |  | ||||||
| 		self.quit <- listenq |  | ||||||
| 		<-listenq |  | ||||||
| 		fmt.Println("quit one") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	fmt.Println("quit waited") |  | ||||||
|  |  | ||||||
| 	logger.Infoln("stopping peers...") |  | ||||||
| 	peers := []net.Addr{} |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	self.closed = true |  | ||||||
| 	for _, peer := range self.peers { |  | ||||||
| 		if peer != nil { |  | ||||||
| 			peers = append(peers, peer.Address) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	self.peersLock.RUnlock() |  | ||||||
| 	for _, address := range peers { |  | ||||||
| 		go self.removePeer(DisconnectRequest{ |  | ||||||
| 			addr:   address, |  | ||||||
| 			reason: DiscQuitting, |  | ||||||
| 		}) |  | ||||||
| 	} |  | ||||||
| 	// wait till they actually disconnect |  | ||||||
| 	// this is checked by draining the peerSlots (slots are released back if a peer is removed) |  | ||||||
| 	i := 0 |  | ||||||
| 	fmt.Println("draining peers") |  | ||||||
|  |  | ||||||
| FOR: |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case slot := <-self.peerSlots: |  | ||||||
| 			i++ |  | ||||||
| 			fmt.Printf("%v: found slot %v", i, slot) |  | ||||||
| 			if i == self.maxPeers { |  | ||||||
| 				break FOR |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	logger.Infoln("server stopped") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // main loop for adding connections via listening |  | ||||||
| func (self *Server) inboundPeerHandler(listener net.Listener) { |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case slot := <-self.peerSlots: |  | ||||||
| 			go self.connectInboundPeer(listener, slot) |  | ||||||
| 		case errc := <-self.quit: |  | ||||||
| 			listener.Close() |  | ||||||
| 			fmt.Println("quit listenloop") |  | ||||||
| 			errc <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // main loop for adding outbound peers based on peerConnect address pool |  | ||||||
| // this same loop handles peer disconnect requests as well |  | ||||||
| func (self *Server) outboundPeerHandler(dialer Dialer) { |  | ||||||
| 	// addressChan initially set to nil (only watches peerConnect if we need more peers) |  | ||||||
| 	var addressChan chan net.Addr |  | ||||||
| 	slots := self.peerSlots |  | ||||||
| 	var slot *int |  | ||||||
| 	for { |  | ||||||
| 		select { |  | ||||||
| 		case i := <-slots: |  | ||||||
| 			// we need a peer in slot i, slot reserved |  | ||||||
| 			slot = &i |  | ||||||
| 			// now we can watch for candidate peers in the next loop |  | ||||||
| 			addressChan = self.peerConnect |  | ||||||
| 			// do not consume more until candidate peer is found |  | ||||||
| 			slots = nil |  | ||||||
| 		case address := <-addressChan: |  | ||||||
| 			// candidate peer found, will dial out asyncronously |  | ||||||
| 			// if connection fails slot will be released |  | ||||||
| 			go self.connectOutboundPeer(dialer, address, *slot) |  | ||||||
| 			// we can watch if more peers needed in the next loop |  | ||||||
| 			slots = self.peerSlots |  | ||||||
| 			// until then we dont care about candidate peers |  | ||||||
| 			addressChan = nil |  | ||||||
| 		case request := <-self.peerDisconnect: |  | ||||||
| 			go self.removePeer(request) |  | ||||||
| 		case errc := <-self.quit: |  | ||||||
| 			if addressChan != nil && slot != nil { |  | ||||||
| 				self.peerSlots <- *slot |  | ||||||
| 			} |  | ||||||
| 			fmt.Println("quit dialloop") |  | ||||||
| 			errc <- true |  | ||||||
| 			return |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // check if peer address already connected |  | ||||||
| func (self *Server) connected(address net.Addr) (err error) { |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	defer self.peersLock.RUnlock() |  | ||||||
| 	// fmt.Printf("address: %v\n", address) |  | ||||||
| 	slot, found := self.peersTable[address.String()] |  | ||||||
| 	if found { |  | ||||||
| 		err = fmt.Errorf("already connected as peer %v (%v)", slot, address) |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // connect to peer via listener.Accept() |  | ||||||
| func (self *Server) connectInboundPeer(listener net.Listener, slot int) { |  | ||||||
| 	var address net.Addr |  | ||||||
| 	conn, err := listener.Accept() |  | ||||||
| 	if err == nil { |  | ||||||
| 		address = conn.RemoteAddr() |  | ||||||
| 		err = self.connected(address) |  | ||||||
| 		if err != nil { |  | ||||||
| 			conn.Close() |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Debugln(err) |  | ||||||
| 		self.peerSlots <- slot |  | ||||||
| 	} else { |  | ||||||
| 		fmt.Printf("adding %v\n", address) |  | ||||||
| 		go self.addPeer(conn, address, true, slot) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // connect to peer via dial out |  | ||||||
| func (self *Server) connectOutboundPeer(dialer Dialer, address net.Addr, slot int) { |  | ||||||
| 	var conn net.Conn |  | ||||||
| 	err := self.connected(address) |  | ||||||
| 	if err == nil { |  | ||||||
| 		conn, err = dialer.Dial(address.Network(), address.String()) |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		logger.Debugln(err) |  | ||||||
| 		self.peerSlots <- slot |  | ||||||
| 	} else { |  | ||||||
| 		go self.addPeer(conn, address, false, slot) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // creates the new peer object and inserts it into its slot |  | ||||||
| func (self *Server) addPeer(conn net.Conn, address net.Addr, inbound bool, slot int) { |  | ||||||
| 	self.peersLock.Lock() |  | ||||||
| 	defer self.peersLock.Unlock() |  | ||||||
| 	if self.closed { |  | ||||||
| 		fmt.Println("oopsy, not no longer need peer") |  | ||||||
| 		conn.Close()           //oopsy our bad |  | ||||||
| 		self.peerSlots <- slot // release slot |  | ||||||
| 	} else { |  | ||||||
| 		peer := NewPeer(conn, address, inbound, self) |  | ||||||
| 		self.peers[slot] = peer |  | ||||||
| 		self.peersTable[address.String()] = slot |  | ||||||
| 		self.peerCount++ |  | ||||||
| 		// reset peersmsg |  | ||||||
| 		self.peersMsg = nil |  | ||||||
| 		fmt.Printf("added peer %v %v (slot %v)\n", address, peer, slot) |  | ||||||
| 		peer.Start() |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot |  | ||||||
| func (self *Server) removePeer(request DisconnectRequest) { |  | ||||||
| 	self.peersLock.Lock() |  | ||||||
|  |  | ||||||
| 	address := request.addr |  | ||||||
| 	slot := self.peersTable[address.String()] |  | ||||||
| 	peer := self.peers[slot] |  | ||||||
| 	fmt.Printf("removing peer %v %v (slot %v)\n", address, peer, slot) |  | ||||||
| 	if peer == nil { |  | ||||||
| 		logger.Debugf("already removed peer on %v", address) |  | ||||||
| 		self.peersLock.Unlock() |  | ||||||
| 		return |  | ||||||
| 	} |  | ||||||
| 	// remove from list and index |  | ||||||
| 	self.peerCount-- |  | ||||||
| 	self.peers[slot] = nil |  | ||||||
| 	delete(self.peersTable, address.String()) |  | ||||||
| 	// reset peersmsg |  | ||||||
| 	self.peersMsg = nil |  | ||||||
| 	fmt.Printf("removed peer %v (slot %v)\n", peer, slot) |  | ||||||
| 	self.peersLock.Unlock() |  | ||||||
|  |  | ||||||
| 	// sending disconnect message |  | ||||||
| 	disconnectMsg, _ := NewMsg(DiscMsg, request.reason) |  | ||||||
| 	peer.Write("", disconnectMsg) |  | ||||||
| 	// be nice and wait |  | ||||||
| 	time.Sleep(disconnectGracePeriod * time.Second) |  | ||||||
| 	// switch off peer and close connections etc. |  | ||||||
| 	fmt.Println("stopping peer") |  | ||||||
| 	peer.Stop() |  | ||||||
| 	fmt.Println("stopped peer") |  | ||||||
| 	// release slot to signal need for a new peer, last! |  | ||||||
| 	self.peerSlots <- slot |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // fix handshake message to push to peers |  | ||||||
| func (self *Server) Handshake() *Msg { |  | ||||||
| 	fmt.Println(self.identity.Pubkey()[1:]) |  | ||||||
| 	msg, _ := NewMsg(HandshakeMsg, P2PVersion, []byte(self.identity.String()), []interface{}{self.protocols}, self.port, self.identity.Pubkey()[1:]) |  | ||||||
| 	return msg |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *Server) RegisterPubkey(candidate *Peer, pubkey []byte) error { |  | ||||||
| 	// Check for blacklisting |  | ||||||
| 	if self.blacklist.Exists(pubkey) { |  | ||||||
| 		return fmt.Errorf("blacklisted") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	self.peersLock.RLock() |  | ||||||
| 	defer self.peersLock.RUnlock() |  | ||||||
| 	for _, peer := range self.peers { |  | ||||||
| 		if peer != nil && peer != candidate && bytes.Compare(peer.Pubkey, pubkey) == 0 { |  | ||||||
| 			return fmt.Errorf("already connected") |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	candidate.Pubkey = pubkey |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -2,207 +2,160 @@ package p2p | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"fmt" | 	"io" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"sync" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| type TestNetwork struct { | func startTestServer(t *testing.T, pf peerFunc) *Server { | ||||||
| 	connections map[string]*TestNetworkConnection | 	server := &Server{ | ||||||
| 	dialer      Dialer | 		Identity:    NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey"), | ||||||
| 	maxinbound  int | 		MaxPeers:    10, | ||||||
|  | 		ListenAddr:  "127.0.0.1:0", | ||||||
|  | 		newPeerFunc: pf, | ||||||
|  | 	} | ||||||
|  | 	if err := server.Start(); err != nil { | ||||||
|  | 		t.Fatalf("Could not start server: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return server | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewTestNetwork(maxinbound int) *TestNetwork { | func TestServerListen(t *testing.T) { | ||||||
| 	connections := make(map[string]*TestNetworkConnection) | 	defer testlog(t).detach() | ||||||
| 	return &TestNetwork{ |  | ||||||
| 		connections: connections, | 	// start the test server | ||||||
| 		dialer:      &TestDialer{connections}, | 	connected := make(chan *Peer) | ||||||
| 		maxinbound:  maxinbound, | 	srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { | ||||||
|  | 		if conn == nil { | ||||||
|  | 			t.Error("peer func called with nil conn") | ||||||
|  | 		} | ||||||
|  | 		if dialAddr != nil { | ||||||
|  | 			t.Error("peer func called with non-nil dialAddr") | ||||||
|  | 		} | ||||||
|  | 		peer := newPeer(conn, nil, dialAddr) | ||||||
|  | 		connected <- peer | ||||||
|  | 		return peer | ||||||
|  | 	}) | ||||||
|  | 	defer close(connected) | ||||||
|  | 	defer srv.Stop() | ||||||
|  |  | ||||||
|  | 	// dial the test server | ||||||
|  | 	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("could not dial: %v", err) | ||||||
|  | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case peer := <-connected: | ||||||
|  | 		if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { | ||||||
|  | 			t.Errorf("peer started with wrong conn: got %v, want %v", | ||||||
|  | 				peer.conn.LocalAddr(), conn.RemoteAddr()) | ||||||
|  | 		} | ||||||
|  | 	case <-time.After(1 * time.Second): | ||||||
|  | 		t.Error("server did not accept within one second") | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (self *TestNetwork) Dialer(addr net.Addr) (Dialer, error) { | func TestServerDial(t *testing.T) { | ||||||
| 	return self.dialer, nil | 	defer testlog(t).detach() | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetwork) Listener(addr net.Addr) (net.Listener, error) { | 	// run a fake TCP server to handle the connection. | ||||||
| 	return &TestListener{ | 	listener, err := net.Listen("tcp", "127.0.0.1:0") | ||||||
| 		connections: self.connections, | 	if err != nil { | ||||||
| 		addr:        addr, | 		t.Fatalf("could not setup listener: %v") | ||||||
| 		max:         self.maxinbound, |  | ||||||
| 	}, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetwork) Start() error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetwork) NewAddr(string, int) (addr net.Addr, err error) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestNetwork) ParseAddr(string) (addr net.Addr, err error) { |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TestAddr struct { |  | ||||||
| 	name string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestAddr) String() string { |  | ||||||
| 	return self.name |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (*TestAddr) Network() string { |  | ||||||
| 	return "test" |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TestDialer struct { |  | ||||||
| 	connections map[string]*TestNetworkConnection |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestDialer) Dial(network string, addr string) (conn net.Conn, err error) { |  | ||||||
| 	address := &TestAddr{addr} |  | ||||||
| 	tconn := NewTestNetworkConnection(address) |  | ||||||
| 	self.connections[addr] = tconn |  | ||||||
| 	conn = net.Conn(tconn) |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type TestListener struct { |  | ||||||
| 	connections map[string]*TestNetworkConnection |  | ||||||
| 	addr        net.Addr |  | ||||||
| 	max         int |  | ||||||
| 	i           int |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (self *TestListener) Accept() (conn net.Conn, err error) { |  | ||||||
| 	self.i++ |  | ||||||
| 	if self.i > self.max { |  | ||||||
| 		err = fmt.Errorf("no more") |  | ||||||
| 	} else { |  | ||||||
| 		addr := &TestAddr{fmt.Sprintf("inboundpeer-%d", self.i)} |  | ||||||
| 		tconn := NewTestNetworkConnection(addr) |  | ||||||
| 		key := tconn.RemoteAddr().String() |  | ||||||
| 		self.connections[key] = tconn |  | ||||||
| 		conn = net.Conn(tconn) |  | ||||||
| 		fmt.Printf("accepted connection from: %v \n", addr) |  | ||||||
| 	} | 	} | ||||||
| 	return | 	defer listener.Close() | ||||||
| } | 	accepted := make(chan net.Conn) | ||||||
|  | 	go func() { | ||||||
| func (self *TestListener) Close() error { | 		conn, err := listener.Accept() | ||||||
| 	return nil | 		if err != nil { | ||||||
| } | 			t.Error("acccept error:", err) | ||||||
|  |  | ||||||
| func (self *TestListener) Addr() net.Addr { |  | ||||||
| 	return self.addr |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func SetupTestServer(handlers Handlers) (network *TestNetwork, server *Server) { |  | ||||||
| 	network = NewTestNetwork(1) |  | ||||||
| 	addr := &TestAddr{"test:30303"} |  | ||||||
| 	identity := NewSimpleClientIdentity("clientIdentifier", "version", "customIdentifier", "pubkey") |  | ||||||
| 	maxPeers := 2 |  | ||||||
| 	if handlers == nil { |  | ||||||
| 		handlers = make(Handlers) |  | ||||||
| 		} | 		} | ||||||
| 	blackist := NewBlacklist() | 		conn.Close() | ||||||
| 	server = New(network, addr, identity, handlers, maxPeers, blackist) | 		accepted <- conn | ||||||
| 	fmt.Println(server.identity.Pubkey()) | 	}() | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestServerListener(t *testing.T) { | 	// start the test server | ||||||
| 	network, server := SetupTestServer(nil) | 	connected := make(chan *Peer) | ||||||
| 	server.Start(true, false) | 	srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { | ||||||
| 	time.Sleep(10 * time.Millisecond) | 		if conn == nil { | ||||||
| 	server.Stop() | 			t.Error("peer func called with nil conn") | ||||||
| 	peer1, ok := network.connections["inboundpeer-1"] |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Error("not found inbound peer 1") |  | ||||||
| 	} else { |  | ||||||
| 		fmt.Printf("out: %v\n", peer1.Out) |  | ||||||
| 		if len(peer1.Out) != 2 { |  | ||||||
| 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) |  | ||||||
| 		} | 		} | ||||||
|  | 		peer := newPeer(conn, nil, dialAddr) | ||||||
|  | 		connected <- peer | ||||||
|  | 		return peer | ||||||
|  | 	}) | ||||||
|  | 	defer close(connected) | ||||||
|  | 	defer srv.Stop() | ||||||
|  |  | ||||||
|  | 	// tell the server to connect. | ||||||
|  | 	connAddr := newPeerAddr(listener.Addr(), nil) | ||||||
|  | 	srv.peerConnect <- connAddr | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case conn := <-accepted: | ||||||
|  | 		select { | ||||||
|  | 		case peer := <-connected: | ||||||
|  | 			if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { | ||||||
|  | 				t.Errorf("peer started with wrong conn: got %v, want %v", | ||||||
|  | 					peer.conn.RemoteAddr(), conn.LocalAddr()) | ||||||
|  | 			} | ||||||
|  | 			if peer.dialAddr != connAddr { | ||||||
|  | 				t.Errorf("peer started with wrong dialAddr: got %v, want %v", | ||||||
|  | 					peer.dialAddr, connAddr) | ||||||
|  | 			} | ||||||
|  | 		case <-time.After(1 * time.Second): | ||||||
|  | 			t.Error("server did not launch peer within one second") | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| } | 	case <-time.After(1 * time.Second): | ||||||
|  | 		t.Error("server did not connect within one second") | ||||||
| func TestServerDialer(t *testing.T) { |  | ||||||
| 	network, server := SetupTestServer(nil) |  | ||||||
| 	server.Start(false, true) |  | ||||||
| 	server.peerConnect <- &TestAddr{"outboundpeer-1"} |  | ||||||
| 	time.Sleep(10 * time.Millisecond) |  | ||||||
| 	server.Stop() |  | ||||||
| 	peer1, ok := network.connections["outboundpeer-1"] |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Error("not found outbound peer 1") |  | ||||||
| 	} else { |  | ||||||
| 		fmt.Printf("out: %v\n", peer1.Out) |  | ||||||
| 		if len(peer1.Out) != 2 { |  | ||||||
| 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestServerBroadcast(t *testing.T) { | func TestServerBroadcast(t *testing.T) { | ||||||
| 	handlers := make(Handlers) | 	defer testlog(t).detach() | ||||||
| 	testProtocol := &TestProtocol{Msgs: []*Msg{}} | 	var connected sync.WaitGroup | ||||||
| 	handlers["aaa"] = func(p *Peer) Protocol { return testProtocol } | 	srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { | ||||||
| 	network, server := SetupTestServer(handlers) | 		peer := newPeer(c, []Protocol{discard}, dialAddr) | ||||||
| 	server.Start(true, true) | 		peer.startSubprotocols([]Cap{discard.cap()}) | ||||||
| 	server.peerConnect <- &TestAddr{"outboundpeer-1"} | 		connected.Done() | ||||||
| 	time.Sleep(10 * time.Millisecond) | 		return peer | ||||||
| 	msg, _ := NewMsg(0) | 	}) | ||||||
| 	server.Broadcast("", msg) | 	defer srv.Stop() | ||||||
| 	packet := Packet(0, 0) |  | ||||||
| 	time.Sleep(10 * time.Millisecond) |  | ||||||
| 	server.Stop() |  | ||||||
| 	peer1, ok := network.connections["outboundpeer-1"] |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Error("not found outbound peer 1") |  | ||||||
| 	} else { |  | ||||||
| 		fmt.Printf("out: %v\n", peer1.Out) |  | ||||||
| 		if len(peer1.Out) != 3 { |  | ||||||
| 			t.Errorf("not enough messages sent to peer 1: %v ", len(peer1.Out)) |  | ||||||
| 		} else { |  | ||||||
| 			if bytes.Compare(peer1.Out[1], packet) != 0 { |  | ||||||
| 				t.Errorf("incorrect broadcast packet %v != %v", peer1.Out[1], packet) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	peer2, ok := network.connections["inboundpeer-1"] |  | ||||||
| 	if !ok { |  | ||||||
| 		t.Error("not found inbound peer 2") |  | ||||||
| 	} else { |  | ||||||
| 		fmt.Printf("out: %v\n", peer2.Out) |  | ||||||
| 		if len(peer1.Out) != 3 { |  | ||||||
| 			t.Errorf("not enough messages sent to peer 2: %v ", len(peer2.Out)) |  | ||||||
| 		} else { |  | ||||||
| 			if bytes.Compare(peer2.Out[1], packet) != 0 { |  | ||||||
| 				t.Errorf("incorrect broadcast packet %v != %v", peer2.Out[1], packet) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestServerPeersMessage(t *testing.T) { | 	// dial a bunch of conns | ||||||
| 	handlers := make(Handlers) | 	var conns = make([]net.Conn, 8) | ||||||
| 	_, server := SetupTestServer(handlers) | 	connected.Add(len(conns)) | ||||||
| 	server.Start(true, true) | 	deadline := time.Now().Add(3 * time.Second) | ||||||
| 	defer server.Stop() | 	dialer := &net.Dialer{Deadline: deadline} | ||||||
| 	server.peerConnect <- &TestAddr{"outboundpeer-1"} | 	for i := range conns { | ||||||
| 	time.Sleep(10 * time.Millisecond) | 		conn, err := dialer.Dial("tcp", srv.ListenAddr) | ||||||
| 	peersMsg, err := server.PeersMessage() |  | ||||||
| 	fmt.Println(peersMsg) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 		t.Errorf("expect no error, got %v", err) | 			t.Fatalf("conn %d: dial error: %v", i, err) | ||||||
|  | 		} | ||||||
|  | 		defer conn.Close() | ||||||
|  | 		conn.SetDeadline(deadline) | ||||||
|  | 		conns[i] = conn | ||||||
|  | 	} | ||||||
|  | 	connected.Wait() | ||||||
|  |  | ||||||
|  | 	// broadcast one message | ||||||
|  | 	srv.Broadcast("discard", 0, "foo") | ||||||
|  | 	goldbuf := new(bytes.Buffer) | ||||||
|  | 	writeMsg(goldbuf, NewMsg(16, "foo")) | ||||||
|  | 	golden := goldbuf.Bytes() | ||||||
|  |  | ||||||
|  | 	// check that the message has been written everywhere | ||||||
|  | 	for i, conn := range conns { | ||||||
|  | 		buf := make([]byte, len(golden)) | ||||||
|  | 		if _, err := io.ReadFull(conn, buf); err != nil { | ||||||
|  | 			t.Errorf("conn %d: read error: %v", i, err) | ||||||
|  | 		} else if !bytes.Equal(buf, golden) { | ||||||
|  | 			t.Errorf("conn %d: msg mismatch\ngot:  %x\nwant: %x", i, buf, golden) | ||||||
| 		} | 		} | ||||||
| 	if c := server.PeerCount(); c != 2 { |  | ||||||
| 		t.Errorf("expect 2 peers, got %v", c) |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										28
									
								
								p2p/testlog_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								p2p/testlog_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | package p2p | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/ethereum/go-ethereum/logger" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type testLogger struct{ t *testing.T } | ||||||
|  |  | ||||||
|  | func testlog(t *testing.T) testLogger { | ||||||
|  | 	logger.Reset() | ||||||
|  | 	l := testLogger{t} | ||||||
|  | 	logger.AddLogSystem(l) | ||||||
|  | 	return l | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } | ||||||
|  | func (testLogger) SetLogLevel(logger.LogLevel)  {} | ||||||
|  |  | ||||||
|  | func (l testLogger) LogPrint(level logger.LogLevel, msg string) { | ||||||
|  | 	l.t.Logf("%s", msg) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (testLogger) detach() { | ||||||
|  | 	logger.Flush() | ||||||
|  | 	logger.Reset() | ||||||
|  | } | ||||||
							
								
								
									
										40
									
								
								p2p/testpoc7.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								p2p/testpoc7.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | // +build none | ||||||
|  |  | ||||||
|  | package main | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"log" | ||||||
|  | 	"net" | ||||||
|  | 	"os" | ||||||
|  |  | ||||||
|  | 	"github.com/ethereum/go-ethereum/logger" | ||||||
|  | 	"github.com/ethereum/go-ethereum/p2p" | ||||||
|  | 	"github.com/obscuren/secp256k1-go" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func main() { | ||||||
|  | 	logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel)) | ||||||
|  |  | ||||||
|  | 	pub, _ := secp256k1.GenerateKeyPair() | ||||||
|  | 	srv := p2p.Server{ | ||||||
|  | 		MaxPeers:   10, | ||||||
|  | 		Identity:   p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)), | ||||||
|  | 		ListenAddr: ":30303", | ||||||
|  | 		NAT:        p2p.PMP(net.ParseIP("10.0.0.1")), | ||||||
|  | 	} | ||||||
|  | 	if err := srv.Start(); err != nil { | ||||||
|  | 		fmt.Println("could not start server:", err) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// add seed peers | ||||||
|  | 	seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303") | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Println("couldn't resolve:", err) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  | 	srv.SuggestPeer(seed.IP, seed.Port, nil) | ||||||
|  |  | ||||||
|  | 	select {} | ||||||
|  | } | ||||||
| @@ -1,6 +1,7 @@ | |||||||
| package rlp | package rlp | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"bufio" | ||||||
| 	"encoding/binary" | 	"encoding/binary" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -24,8 +25,9 @@ type Decoder interface { | |||||||
| 	DecodeRLP(*Stream) error | 	DecodeRLP(*Stream) error | ||||||
| } | } | ||||||
|  |  | ||||||
| // Decode parses RLP-encoded data from r and stores the result | // Decode parses RLP-encoded data from r and stores the result in the | ||||||
| // in the value pointed to by val. Val must be a non-nil pointer. | // value pointed to by val. Val must be a non-nil pointer. If r does | ||||||
|  | // not implement ByteReader, Decode will do its own buffering. | ||||||
| // | // | ||||||
| // Decode uses the following type-dependent decoding rules: | // Decode uses the following type-dependent decoding rules: | ||||||
| // | // | ||||||
| @@ -66,10 +68,19 @@ type Decoder interface { | |||||||
| // | // | ||||||
| // Non-empty interface types are not supported, nor are bool, float32, | // Non-empty interface types are not supported, nor are bool, float32, | ||||||
| // float64, maps, channel types and functions. | // float64, maps, channel types and functions. | ||||||
| func Decode(r ByteReader, val interface{}) error { | func Decode(r io.Reader, val interface{}) error { | ||||||
| 	return NewStream(r).Decode(val) | 	return NewStream(r).Decode(val) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type decodeError struct { | ||||||
|  | 	msg string | ||||||
|  | 	typ reflect.Type | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (err decodeError) Error() string { | ||||||
|  | 	return fmt.Sprintf("rlp: %s for %v", err.msg, err.typ) | ||||||
|  | } | ||||||
|  |  | ||||||
| func makeNumDecoder(typ reflect.Type) decoder { | func makeNumDecoder(typ reflect.Type) decoder { | ||||||
| 	kind := typ.Kind() | 	kind := typ.Kind() | ||||||
| 	switch { | 	switch { | ||||||
| @@ -83,8 +94,11 @@ func makeNumDecoder(typ reflect.Type) decoder { | |||||||
| } | } | ||||||
|  |  | ||||||
| func decodeInt(s *Stream, val reflect.Value) error { | func decodeInt(s *Stream, val reflect.Value) error { | ||||||
| 	num, err := s.uint(val.Type().Bits()) | 	typ := val.Type() | ||||||
| 	if err != nil { | 	num, err := s.uint(typ.Bits()) | ||||||
|  | 	if err == errUintOverflow { | ||||||
|  | 		return decodeError{"input string too long", typ} | ||||||
|  | 	} else if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	val.SetInt(int64(num)) | 	val.SetInt(int64(num)) | ||||||
| @@ -92,8 +106,11 @@ func decodeInt(s *Stream, val reflect.Value) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func decodeUint(s *Stream, val reflect.Value) error { | func decodeUint(s *Stream, val reflect.Value) error { | ||||||
| 	num, err := s.uint(val.Type().Bits()) | 	typ := val.Type() | ||||||
| 	if err != nil { | 	num, err := s.uint(typ.Bits()) | ||||||
|  | 	if err == errUintOverflow { | ||||||
|  | 		return decodeError{"input string too big", typ} | ||||||
|  | 	} else if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	val.SetUint(num) | 	val.SetUint(num) | ||||||
| @@ -175,7 +192,7 @@ func decodeList(s *Stream, val reflect.Value, elemdec decoder, maxelem int) erro | |||||||
| 	i := 0 | 	i := 0 | ||||||
| 	for { | 	for { | ||||||
| 		if i > maxelem { | 		if i > maxelem { | ||||||
| 			return fmt.Errorf("rlp: input List has more than %d elements", maxelem) | 			return decodeError{"input list has too many elements", val.Type()} | ||||||
| 		} | 		} | ||||||
| 		if val.Kind() == reflect.Slice { | 		if val.Kind() == reflect.Slice { | ||||||
| 			// grow slice if necessary | 			// grow slice if necessary | ||||||
| @@ -226,8 +243,6 @@ func decodeByteSlice(s *Stream, val reflect.Value) error { | |||||||
| 	return err | 	return err | ||||||
| } | } | ||||||
|  |  | ||||||
| var errStringDoesntFitArray = errors.New("rlp: string value doesn't fit into target array") |  | ||||||
|  |  | ||||||
| func decodeByteArray(s *Stream, val reflect.Value) error { | func decodeByteArray(s *Stream, val reflect.Value) error { | ||||||
| 	kind, size, err := s.Kind() | 	kind, size, err := s.Kind() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -236,14 +251,14 @@ func decodeByteArray(s *Stream, val reflect.Value) error { | |||||||
| 	switch kind { | 	switch kind { | ||||||
| 	case Byte: | 	case Byte: | ||||||
| 		if val.Len() == 0 { | 		if val.Len() == 0 { | ||||||
| 			return errStringDoesntFitArray | 			return decodeError{"input string too big", val.Type()} | ||||||
| 		} | 		} | ||||||
| 		bv, _ := s.Uint() | 		bv, _ := s.Uint() | ||||||
| 		val.Index(0).SetUint(bv) | 		val.Index(0).SetUint(bv) | ||||||
| 		zero(val, 1) | 		zero(val, 1) | ||||||
| 	case String: | 	case String: | ||||||
| 		if uint64(val.Len()) < size { | 		if uint64(val.Len()) < size { | ||||||
| 			return errStringDoesntFitArray | 			return decodeError{"input string too big", val.Type()} | ||||||
| 		} | 		} | ||||||
| 		slice := val.Slice(0, int(size)).Interface().([]byte) | 		slice := val.Slice(0, int(size)).Interface().([]byte) | ||||||
| 		if err := s.readFull(slice); err != nil { | 		if err := s.readFull(slice); err != nil { | ||||||
| @@ -293,7 +308,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 		if err = s.ListEnd(); err == errNotAtEOL { | 		if err = s.ListEnd(); err == errNotAtEOL { | ||||||
| 			err = errors.New("rlp: input List has too many elements") | 			err = decodeError{"input list has too many elements", typ} | ||||||
| 		} | 		} | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -432,8 +447,23 @@ type Stream struct { | |||||||
|  |  | ||||||
| type listpos struct{ pos, size uint64 } | type listpos struct{ pos, size uint64 } | ||||||
|  |  | ||||||
| func NewStream(r ByteReader) *Stream { | // NewStream creates a new stream reading from r. | ||||||
| 	return &Stream{r: r, uintbuf: make([]byte, 8), kind: -1} | // If r does not implement ByteReader, the Stream will | ||||||
|  | // introduce its own buffering. | ||||||
|  | func NewStream(r io.Reader) *Stream { | ||||||
|  | 	s := new(Stream) | ||||||
|  | 	s.Reset(r) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // NewListStream creates a new stream that pretends to be positioned | ||||||
|  | // at an encoded list of the given length. | ||||||
|  | func NewListStream(r io.Reader, len uint64) *Stream { | ||||||
|  | 	s := new(Stream) | ||||||
|  | 	s.Reset(r) | ||||||
|  | 	s.kind = List | ||||||
|  | 	s.size = len | ||||||
|  | 	return s | ||||||
| } | } | ||||||
|  |  | ||||||
| // Bytes reads an RLP string and returns its contents as a byte slice. | // Bytes reads an RLP string and returns its contents as a byte slice. | ||||||
| @@ -459,6 +489,8 @@ func (s *Stream) Bytes() ([]byte, error) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | var errUintOverflow = errors.New("rlp: uint overflow") | ||||||
|  |  | ||||||
| // Uint reads an RLP string of up to 8 bytes and returns its contents | // Uint reads an RLP string of up to 8 bytes and returns its contents | ||||||
| // as an unsigned integer. If the input does not contain an RLP string, the | // as an unsigned integer. If the input does not contain an RLP string, the | ||||||
| // returned error will be ErrExpectedString. | // returned error will be ErrExpectedString. | ||||||
| @@ -477,7 +509,7 @@ func (s *Stream) uint(maxbits int) (uint64, error) { | |||||||
| 		return uint64(s.byteval), nil | 		return uint64(s.byteval), nil | ||||||
| 	case String: | 	case String: | ||||||
| 		if size > uint64(maxbits/8) { | 		if size > uint64(maxbits/8) { | ||||||
| 			return 0, fmt.Errorf("rlp: string is larger than %d bits", maxbits) | 			return 0, errUintOverflow | ||||||
| 		} | 		} | ||||||
| 		return s.readUint(byte(size)) | 		return s.readUint(byte(size)) | ||||||
| 	default: | 	default: | ||||||
| @@ -543,6 +575,23 @@ func (s *Stream) Decode(val interface{}) error { | |||||||
| 	return info.decoder(s, rval.Elem()) | 	return info.decoder(s, rval.Elem()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Reset discards any information about the current decoding context | ||||||
|  | // and starts reading from r. If r does not also implement ByteReader, | ||||||
|  | // Stream will do its own buffering. | ||||||
|  | func (s *Stream) Reset(r io.Reader) { | ||||||
|  | 	bufr, ok := r.(ByteReader) | ||||||
|  | 	if !ok { | ||||||
|  | 		bufr = bufio.NewReader(r) | ||||||
|  | 	} | ||||||
|  | 	s.r = bufr | ||||||
|  | 	s.stack = s.stack[:0] | ||||||
|  | 	s.size = 0 | ||||||
|  | 	s.kind = -1 | ||||||
|  | 	if s.uintbuf == nil { | ||||||
|  | 		s.uintbuf = make([]byte, 8) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // Kind returns the kind and size of the next value in the | // Kind returns the kind and size of the next value in the | ||||||
| // input stream. | // input stream. | ||||||
| // | // | ||||||
|   | |||||||
| @@ -3,7 +3,6 @@ package rlp | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"encoding/hex" | 	"encoding/hex" | ||||||
| 	"errors" |  | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| 	"math/big" | 	"math/big" | ||||||
| @@ -54,6 +53,24 @@ func TestStreamKind(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestNewListStream(t *testing.T) { | ||||||
|  | 	ls := NewListStream(bytes.NewReader(unhex("0101010101")), 3) | ||||||
|  | 	if k, size, err := ls.Kind(); k != List || size != 3 || err != nil { | ||||||
|  | 		t.Errorf("Kind() returned (%v, %d, %v), expected (List, 3, nil)", k, size, err) | ||||||
|  | 	} | ||||||
|  | 	if size, err := ls.List(); size != 3 || err != nil { | ||||||
|  | 		t.Errorf("List() returned (%d, %v), expected (3, nil)", size, err) | ||||||
|  | 	} | ||||||
|  | 	for i := 0; i < 3; i++ { | ||||||
|  | 		if val, err := ls.Uint(); val != 1 || err != nil { | ||||||
|  | 			t.Errorf("Uint() returned (%d, %v), expected (1, nil)", val, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if err := ls.ListEnd(); err != nil { | ||||||
|  | 		t.Errorf("ListEnd() returned %v, expected (3, nil)", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestStreamErrors(t *testing.T) { | func TestStreamErrors(t *testing.T) { | ||||||
| 	type calls []string | 	type calls []string | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| @@ -69,7 +86,7 @@ func TestStreamErrors(t *testing.T) { | |||||||
| 		{"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, | 		{"81", calls{"Bytes"}, io.ErrUnexpectedEOF}, | ||||||
| 		{"81", calls{"Uint"}, io.ErrUnexpectedEOF}, | 		{"81", calls{"Uint"}, io.ErrUnexpectedEOF}, | ||||||
| 		{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, | 		{"BFFFFFFFFFFFFFFF", calls{"Bytes"}, io.ErrUnexpectedEOF}, | ||||||
| 		{"89000000000000000001", calls{"Uint"}, errors.New("rlp: string is larger than 64 bits")}, | 		{"89000000000000000001", calls{"Uint"}, errUintOverflow}, | ||||||
| 		{"00", calls{"List"}, ErrExpectedList}, | 		{"00", calls{"List"}, ErrExpectedList}, | ||||||
| 		{"80", calls{"List"}, ErrExpectedList}, | 		{"80", calls{"List"}, ErrExpectedList}, | ||||||
| 		{"C0", calls{"List", "Uint"}, EOL}, | 		{"C0", calls{"List", "Uint"}, EOL}, | ||||||
| @@ -163,7 +180,7 @@ type decodeTest struct { | |||||||
| 	input string | 	input string | ||||||
| 	ptr   interface{} | 	ptr   interface{} | ||||||
| 	value interface{} | 	value interface{} | ||||||
| 	error error | 	error string | ||||||
| } | } | ||||||
|  |  | ||||||
| type simplestruct struct { | type simplestruct struct { | ||||||
| @@ -196,8 +213,8 @@ var decodeTests = []decodeTest{ | |||||||
| 	{input: "820505", ptr: new(uint32), value: uint32(0x0505)}, | 	{input: "820505", ptr: new(uint32), value: uint32(0x0505)}, | ||||||
| 	{input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, | 	{input: "83050505", ptr: new(uint32), value: uint32(0x050505)}, | ||||||
| 	{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, | 	{input: "8405050505", ptr: new(uint32), value: uint32(0x05050505)}, | ||||||
| 	{input: "850505050505", ptr: new(uint32), error: errors.New("rlp: string is larger than 32 bits")}, | 	{input: "850505050505", ptr: new(uint32), error: "rlp: input string too big for uint32"}, | ||||||
| 	{input: "C0", ptr: new(uint32), error: ErrExpectedString}, | 	{input: "C0", ptr: new(uint32), error: ErrExpectedString.Error()}, | ||||||
|  |  | ||||||
| 	// slices | 	// slices | ||||||
| 	{input: "C0", ptr: new([]int), value: []int{}}, | 	{input: "C0", ptr: new([]int), value: []int{}}, | ||||||
| @@ -206,7 +223,7 @@ var decodeTests = []decodeTest{ | |||||||
| 	// arrays | 	// arrays | ||||||
| 	{input: "C0", ptr: new([5]int), value: [5]int{}}, | 	{input: "C0", ptr: new([5]int), value: [5]int{}}, | ||||||
| 	{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, | 	{input: "C50102030405", ptr: new([5]int), value: [5]int{1, 2, 3, 4, 5}}, | ||||||
| 	{input: "C6010203040506", ptr: new([5]int), error: errors.New("rlp: input List has more than 5 elements")}, | 	{input: "C6010203040506", ptr: new([5]int), error: "rlp: input list has too many elements for [5]int"}, | ||||||
|  |  | ||||||
| 	// byte slices | 	// byte slices | ||||||
| 	{input: "01", ptr: new([]byte), value: []byte{1}}, | 	{input: "01", ptr: new([]byte), value: []byte{1}}, | ||||||
| @@ -214,7 +231,7 @@ var decodeTests = []decodeTest{ | |||||||
| 	{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, | 	{input: "8D6162636465666768696A6B6C6D", ptr: new([]byte), value: []byte("abcdefghijklm")}, | ||||||
| 	{input: "C0", ptr: new([]byte), value: []byte{}}, | 	{input: "C0", ptr: new([]byte), value: []byte{}}, | ||||||
| 	{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, | 	{input: "C3010203", ptr: new([]byte), value: []byte{1, 2, 3}}, | ||||||
| 	{input: "C3820102", ptr: new([]byte), error: errors.New("rlp: string is larger than 8 bits")}, | 	{input: "C3820102", ptr: new([]byte), error: "rlp: input string too big for uint8"}, | ||||||
|  |  | ||||||
| 	// byte arrays | 	// byte arrays | ||||||
| 	{input: "01", ptr: new([5]byte), value: [5]byte{1}}, | 	{input: "01", ptr: new([5]byte), value: [5]byte{1}}, | ||||||
| @@ -222,9 +239,9 @@ var decodeTests = []decodeTest{ | |||||||
| 	{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, | 	{input: "850102030405", ptr: new([5]byte), value: [5]byte{1, 2, 3, 4, 5}}, | ||||||
| 	{input: "C0", ptr: new([5]byte), value: [5]byte{}}, | 	{input: "C0", ptr: new([5]byte), value: [5]byte{}}, | ||||||
| 	{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, | 	{input: "C3010203", ptr: new([5]byte), value: [5]byte{1, 2, 3, 0, 0}}, | ||||||
| 	{input: "C3820102", ptr: new([5]byte), error: errors.New("rlp: string is larger than 8 bits")}, | 	{input: "C3820102", ptr: new([5]byte), error: "rlp: input string too big for uint8"}, | ||||||
| 	{input: "86010203040506", ptr: new([5]byte), error: errStringDoesntFitArray}, | 	{input: "86010203040506", ptr: new([5]byte), error: "rlp: input string too big for [5]uint8"}, | ||||||
| 	{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF}, | 	{input: "850101", ptr: new([5]byte), error: io.ErrUnexpectedEOF.Error()}, | ||||||
|  |  | ||||||
| 	// byte array reuse (should be zeroed) | 	// byte array reuse (should be zeroed) | ||||||
| 	{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, | 	{input: "850102030405", ptr: &sharedByteArray, value: [5]byte{1, 2, 3, 4, 5}}, | ||||||
| @@ -237,25 +254,25 @@ var decodeTests = []decodeTest{ | |||||||
| 	// zero sized byte arrays | 	// zero sized byte arrays | ||||||
| 	{input: "80", ptr: new([0]byte), value: [0]byte{}}, | 	{input: "80", ptr: new([0]byte), value: [0]byte{}}, | ||||||
| 	{input: "C0", ptr: new([0]byte), value: [0]byte{}}, | 	{input: "C0", ptr: new([0]byte), value: [0]byte{}}, | ||||||
| 	{input: "01", ptr: new([0]byte), error: errStringDoesntFitArray}, | 	{input: "01", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, | ||||||
| 	{input: "8101", ptr: new([0]byte), error: errStringDoesntFitArray}, | 	{input: "8101", ptr: new([0]byte), error: "rlp: input string too big for [0]uint8"}, | ||||||
|  |  | ||||||
| 	// strings | 	// strings | ||||||
| 	{input: "00", ptr: new(string), value: "\000"}, | 	{input: "00", ptr: new(string), value: "\000"}, | ||||||
| 	{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, | 	{input: "8D6162636465666768696A6B6C6D", ptr: new(string), value: "abcdefghijklm"}, | ||||||
| 	{input: "C0", ptr: new(string), error: ErrExpectedString}, | 	{input: "C0", ptr: new(string), error: ErrExpectedString.Error()}, | ||||||
|  |  | ||||||
| 	// big ints | 	// big ints | ||||||
| 	{input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, | 	{input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, | ||||||
| 	{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, | 	{input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, | ||||||
| 	{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works | 	{input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works | ||||||
| 	{input: "C0", ptr: new(*big.Int), error: ErrExpectedString}, | 	{input: "C0", ptr: new(*big.Int), error: ErrExpectedString.Error()}, | ||||||
|  |  | ||||||
| 	// structs | 	// structs | ||||||
| 	{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, | 	{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}}, | ||||||
| 	{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, | 	{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}}, | ||||||
| 	{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, | 	{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}}, | ||||||
| 	{input: "C3010101", ptr: new(simplestruct), error: errors.New("rlp: input List has too many elements")}, | 	{input: "C3010101", ptr: new(simplestruct), error: "rlp: input list has too many elements for rlp.simplestruct"}, | ||||||
| 	{ | 	{ | ||||||
| 		input: "C501C302C103", | 		input: "C501C302C103", | ||||||
| 		ptr:   new(recstruct), | 		ptr:   new(recstruct), | ||||||
| @@ -286,20 +303,20 @@ var decodeTests = []decodeTest{ | |||||||
|  |  | ||||||
| func intp(i int) *int { return &i } | func intp(i int) *int { return &i } | ||||||
|  |  | ||||||
| func TestDecode(t *testing.T) { | func runTests(t *testing.T, decode func([]byte, interface{}) error) { | ||||||
| 	for i, test := range decodeTests { | 	for i, test := range decodeTests { | ||||||
| 		input, err := hex.DecodeString(test.input) | 		input, err := hex.DecodeString(test.input) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Errorf("test %d: invalid hex input %q", i, test.input) | 			t.Errorf("test %d: invalid hex input %q", i, test.input) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		err = Decode(bytes.NewReader(input), test.ptr) | 		err = decode(input, test.ptr) | ||||||
| 		if err != nil && test.error == nil { | 		if err != nil && test.error == "" { | ||||||
| 			t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", | 			t.Errorf("test %d: unexpected Decode error: %v\ndecoding into %T\ninput %q", | ||||||
| 				i, err, test.ptr, test.input) | 				i, err, test.ptr, test.input) | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		if test.error != nil && fmt.Sprint(err) != fmt.Sprint(test.error) { | 		if test.error != "" && fmt.Sprint(err) != test.error { | ||||||
| 			t.Errorf("test %d: Decode error mismatch\ngot  %v\nwant %v\ndecoding into %T\ninput %q", | 			t.Errorf("test %d: Decode error mismatch\ngot  %v\nwant %v\ndecoding into %T\ninput %q", | ||||||
| 				i, err, test.error, test.ptr, test.input) | 				i, err, test.error, test.ptr, test.input) | ||||||
| 			continue | 			continue | ||||||
| @@ -312,6 +329,40 @@ func TestDecode(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestDecodeWithByteReader(t *testing.T) { | ||||||
|  | 	runTests(t, func(input []byte, into interface{}) error { | ||||||
|  | 		return Decode(bytes.NewReader(input), into) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // dumbReader reads from a byte slice but does not | ||||||
|  | // implement ReadByte. | ||||||
|  | type dumbReader []byte | ||||||
|  |  | ||||||
|  | func (r *dumbReader) Read(buf []byte) (n int, err error) { | ||||||
|  | 	if len(*r) == 0 { | ||||||
|  | 		return 0, io.EOF | ||||||
|  | 	} | ||||||
|  | 	n = copy(buf, *r) | ||||||
|  | 	*r = (*r)[n:] | ||||||
|  | 	return n, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecodeWithNonByteReader(t *testing.T) { | ||||||
|  | 	runTests(t, func(input []byte, into interface{}) error { | ||||||
|  | 		r := dumbReader(input) | ||||||
|  | 		return Decode(&r, into) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestDecodeStreamReset(t *testing.T) { | ||||||
|  | 	s := NewStream(nil) | ||||||
|  | 	runTests(t, func(input []byte, into interface{}) error { | ||||||
|  | 		s.Reset(bytes.NewReader(input)) | ||||||
|  | 		return s.Decode(into) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
| type testDecoder struct{ called bool } | type testDecoder struct{ called bool } | ||||||
|  |  | ||||||
| func (t *testDecoder) DecodeRLP(s *Stream) error { | func (t *testDecoder) DecodeRLP(s *Stream) error { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user