This mostly changes how information is passed around. Instead of using many function parameters and return values, put the entire state in a struct and pass that. This also adds back derivation of ecdhe-shared-secret. I deleted it by accident in a previous refactoring.
		
			
				
	
	
		
			172 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			172 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package p2p
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"crypto/rand"
 | 
						|
	"fmt"
 | 
						|
	"net"
 | 
						|
	"reflect"
 | 
						|
	"testing"
 | 
						|
	"time"
 | 
						|
 | 
						|
	"github.com/ethereum/go-ethereum/crypto"
 | 
						|
	"github.com/ethereum/go-ethereum/crypto/ecies"
 | 
						|
	"github.com/ethereum/go-ethereum/p2p/discover"
 | 
						|
)
 | 
						|
 | 
						|
func TestSharedSecret(t *testing.T) {
 | 
						|
	prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
 | 
						|
	pub0 := &prv0.PublicKey
 | 
						|
	prv1, _ := crypto.GenerateKey()
 | 
						|
	pub1 := &prv1.PublicKey
 | 
						|
 | 
						|
	ss0, err := ecies.ImportECDSA(prv0).GenerateShared(ecies.ImportECDSAPublic(pub1), sskLen, sskLen)
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	ss1, err := ecies.ImportECDSA(prv1).GenerateShared(ecies.ImportECDSAPublic(pub0), sskLen, sskLen)
 | 
						|
	if err != nil {
 | 
						|
		return
 | 
						|
	}
 | 
						|
	t.Logf("Secret:\n%v %x\n%v %x", len(ss0), ss0, len(ss0), ss1)
 | 
						|
	if !bytes.Equal(ss0, ss1) {
 | 
						|
		t.Errorf("dont match :(")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestEncHandshake(t *testing.T) {
 | 
						|
	for i := 0; i < 20; i++ {
 | 
						|
		start := time.Now()
 | 
						|
		if err := testEncHandshake(nil); err != nil {
 | 
						|
			t.Fatalf("i=%d %v", i, err)
 | 
						|
		}
 | 
						|
		t.Logf("(without token) %d %v\n", i+1, time.Since(start))
 | 
						|
	}
 | 
						|
 | 
						|
	for i := 0; i < 20; i++ {
 | 
						|
		tok := make([]byte, shaLen)
 | 
						|
		rand.Reader.Read(tok)
 | 
						|
		start := time.Now()
 | 
						|
		if err := testEncHandshake(tok); err != nil {
 | 
						|
			t.Fatalf("i=%d %v", i, err)
 | 
						|
		}
 | 
						|
		t.Logf("(with token) %d %v\n", i+1, time.Since(start))
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func testEncHandshake(token []byte) error {
 | 
						|
	type result struct {
 | 
						|
		side string
 | 
						|
		s    secrets
 | 
						|
		err  error
 | 
						|
	}
 | 
						|
	var (
 | 
						|
		prv0, _  = crypto.GenerateKey()
 | 
						|
		prv1, _  = crypto.GenerateKey()
 | 
						|
		rw0, rw1 = net.Pipe()
 | 
						|
		output   = make(chan result)
 | 
						|
	)
 | 
						|
 | 
						|
	go func() {
 | 
						|
		r := result{side: "initiator"}
 | 
						|
		defer func() { output <- r }()
 | 
						|
 | 
						|
		pub1s := discover.PubkeyID(&prv1.PublicKey)
 | 
						|
		r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
 | 
						|
		if r.err != nil {
 | 
						|
			return
 | 
						|
		}
 | 
						|
		id1 := discover.PubkeyID(&prv1.PublicKey)
 | 
						|
		if r.s.RemoteID != id1 {
 | 
						|
			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
	go func() {
 | 
						|
		r := result{side: "receiver"}
 | 
						|
		defer func() { output <- r }()
 | 
						|
 | 
						|
		r.s, r.err = receiverEncHandshake(rw1, prv1, token)
 | 
						|
		if r.err != nil {
 | 
						|
			return
 | 
						|
		}
 | 
						|
		id0 := discover.PubkeyID(&prv0.PublicKey)
 | 
						|
		if r.s.RemoteID != id0 {
 | 
						|
			r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	// wait for results from both sides
 | 
						|
	r1, r2 := <-output, <-output
 | 
						|
 | 
						|
	if r1.err != nil {
 | 
						|
		return fmt.Errorf("%s side error: %v", r1.side, r1.err)
 | 
						|
	}
 | 
						|
	if r2.err != nil {
 | 
						|
		return fmt.Errorf("%s side error: %v", r2.side, r2.err)
 | 
						|
	}
 | 
						|
 | 
						|
	// don't compare remote node IDs
 | 
						|
	r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
 | 
						|
	// flip MACs on one of them so they compare equal
 | 
						|
	r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
 | 
						|
	if !reflect.DeepEqual(r1.s, r2.s) {
 | 
						|
		return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
func TestSetupConn(t *testing.T) {
 | 
						|
	prv0, _ := crypto.GenerateKey()
 | 
						|
	prv1, _ := crypto.GenerateKey()
 | 
						|
	node0 := &discover.Node{
 | 
						|
		ID:      discover.PubkeyID(&prv0.PublicKey),
 | 
						|
		IP:      net.IP{1, 2, 3, 4},
 | 
						|
		TCPPort: 33,
 | 
						|
	}
 | 
						|
	node1 := &discover.Node{
 | 
						|
		ID:      discover.PubkeyID(&prv1.PublicKey),
 | 
						|
		IP:      net.IP{5, 6, 7, 8},
 | 
						|
		TCPPort: 44,
 | 
						|
	}
 | 
						|
	hs0 := &protoHandshake{
 | 
						|
		Version: baseProtocolVersion,
 | 
						|
		ID:      node0.ID,
 | 
						|
		Caps:    []Cap{{"a", 0}, {"b", 2}},
 | 
						|
	}
 | 
						|
	hs1 := &protoHandshake{
 | 
						|
		Version: baseProtocolVersion,
 | 
						|
		ID:      node1.ID,
 | 
						|
		Caps:    []Cap{{"c", 1}, {"d", 3}},
 | 
						|
	}
 | 
						|
	fd0, fd1 := net.Pipe()
 | 
						|
 | 
						|
	done := make(chan struct{})
 | 
						|
	go func() {
 | 
						|
		defer close(done)
 | 
						|
		conn0, err := setupConn(fd0, prv0, hs0, node1)
 | 
						|
		if err != nil {
 | 
						|
			t.Errorf("outbound side error: %v", err)
 | 
						|
			return
 | 
						|
		}
 | 
						|
		if conn0.ID != node1.ID {
 | 
						|
			t.Errorf("outbound conn id mismatch: got %v, want %v", conn0.ID, node1.ID)
 | 
						|
		}
 | 
						|
		if !reflect.DeepEqual(conn0.Caps, hs1.Caps) {
 | 
						|
			t.Errorf("outbound caps mismatch: got %v, want %v", conn0.Caps, hs1.Caps)
 | 
						|
		}
 | 
						|
	}()
 | 
						|
 | 
						|
	conn1, err := setupConn(fd1, prv1, hs1, nil)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("inbound side error: %v", err)
 | 
						|
	}
 | 
						|
	if conn1.ID != node0.ID {
 | 
						|
		t.Errorf("inbound conn id mismatch: got %v, want %v", conn1.ID, node0.ID)
 | 
						|
	}
 | 
						|
	if !reflect.DeepEqual(conn1.Caps, hs0.Caps) {
 | 
						|
		t.Errorf("inbound caps mismatch: got %v, want %v", conn1.Caps, hs0.Caps)
 | 
						|
	}
 | 
						|
 | 
						|
	<-done
 | 
						|
}
 |