p2p: make encryption handshake code easier to follow
This mostly changes how information is passed around. Instead of using many function parameters and return values, put the entire state in a struct and pass that. This also adds back derivation of ecdhe-shared-secret. I deleted it by accident in a previous refactoring.
This commit is contained in:
@ -2,51 +2,18 @@ package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/crypto/ecies"
|
||||
"github.com/ethereum/go-ethereum/p2p/discover"
|
||||
)
|
||||
|
||||
func TestPublicKeyEncoding(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
pub0 := &prv0.PublicKey
|
||||
pub0s := crypto.FromECDSAPub(pub0)
|
||||
pub1, err := importPublicKey(pub0s)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
eciesPub1 := ecies.ImportECDSAPublic(pub1)
|
||||
if eciesPub1 == nil {
|
||||
t.Errorf("invalid ecdsa public key")
|
||||
}
|
||||
pub1s, err := exportPublicKey(pub1)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
if len(pub1s) != 64 {
|
||||
t.Errorf("wrong length expect 64, got", len(pub1s))
|
||||
}
|
||||
pub2, err := importPublicKey(pub1s)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
pub2s, err := exportPublicKey(pub2)
|
||||
if err != nil {
|
||||
t.Errorf("%v", err)
|
||||
}
|
||||
if !bytes.Equal(pub1s, pub2s) {
|
||||
t.Errorf("exports dont match")
|
||||
}
|
||||
pub2sEC := crypto.FromECDSAPub(pub2)
|
||||
if !bytes.Equal(pub0s, pub2sEC) {
|
||||
t.Errorf("exports dont match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSharedSecret(t *testing.T) {
|
||||
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
|
||||
pub0 := &prv0.PublicKey
|
||||
@ -68,46 +35,84 @@ func TestSharedSecret(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEncHandshake(t *testing.T) {
|
||||
defer testlog(t).detach()
|
||||
for i := 0; i < 20; i++ {
|
||||
start := time.Now()
|
||||
if err := testEncHandshake(nil); err != nil {
|
||||
t.Fatalf("i=%d %v", i, err)
|
||||
}
|
||||
t.Logf("(without token) %d %v\n", i+1, time.Since(start))
|
||||
}
|
||||
|
||||
prv0, _ := crypto.GenerateKey()
|
||||
prv1, _ := crypto.GenerateKey()
|
||||
rw0, rw1 := net.Pipe()
|
||||
secrets := make(chan secrets)
|
||||
for i := 0; i < 20; i++ {
|
||||
tok := make([]byte, shaLen)
|
||||
rand.Reader.Read(tok)
|
||||
start := time.Now()
|
||||
if err := testEncHandshake(tok); err != nil {
|
||||
t.Fatalf("i=%d %v", i, err)
|
||||
}
|
||||
t.Logf("(with token) %d %v\n", i+1, time.Since(start))
|
||||
}
|
||||
}
|
||||
|
||||
func testEncHandshake(token []byte) error {
|
||||
type result struct {
|
||||
side string
|
||||
s secrets
|
||||
err error
|
||||
}
|
||||
var (
|
||||
prv0, _ = crypto.GenerateKey()
|
||||
prv1, _ = crypto.GenerateKey()
|
||||
rw0, rw1 = net.Pipe()
|
||||
output = make(chan result)
|
||||
)
|
||||
|
||||
go func() {
|
||||
pub1s, _ := exportPublicKey(&prv1.PublicKey)
|
||||
s, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
|
||||
if err != nil {
|
||||
t.Errorf("outbound side error: %v", err)
|
||||
r := result{side: "initiator"}
|
||||
defer func() { output <- r }()
|
||||
|
||||
pub1s := discover.PubkeyID(&prv1.PublicKey)
|
||||
r.s, r.err = initiatorEncHandshake(rw0, prv0, pub1s, token)
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
id1 := discover.PubkeyID(&prv1.PublicKey)
|
||||
if s.RemoteID != id1 {
|
||||
t.Errorf("outbound side remote ID mismatch")
|
||||
if r.s.RemoteID != id1 {
|
||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id1)
|
||||
}
|
||||
secrets <- s
|
||||
}()
|
||||
go func() {
|
||||
s, err := inboundEncHandshake(rw1, prv1, nil)
|
||||
if err != nil {
|
||||
t.Errorf("inbound side error: %v", err)
|
||||
r := result{side: "receiver"}
|
||||
defer func() { output <- r }()
|
||||
|
||||
r.s, r.err = receiverEncHandshake(rw1, prv1, token)
|
||||
if r.err != nil {
|
||||
return
|
||||
}
|
||||
id0 := discover.PubkeyID(&prv0.PublicKey)
|
||||
if s.RemoteID != id0 {
|
||||
t.Errorf("inbound side remote ID mismatch")
|
||||
if r.s.RemoteID != id0 {
|
||||
r.err = fmt.Errorf("remote ID mismatch: got %v, want: %v", r.s.RemoteID, id0)
|
||||
}
|
||||
secrets <- s
|
||||
}()
|
||||
|
||||
// get computed secrets from both sides
|
||||
t1, t2 := <-secrets, <-secrets
|
||||
// don't compare remote node IDs
|
||||
t1.RemoteID, t2.RemoteID = discover.NodeID{}, discover.NodeID{}
|
||||
// flip MACs on one of them so they compare equal
|
||||
t1.EgressMAC, t1.IngressMAC = t1.IngressMAC, t1.EgressMAC
|
||||
if !reflect.DeepEqual(t1, t2) {
|
||||
t.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", t1, t2)
|
||||
// wait for results from both sides
|
||||
r1, r2 := <-output, <-output
|
||||
|
||||
if r1.err != nil {
|
||||
return fmt.Errorf("%s side error: %v", r1.side, r1.err)
|
||||
}
|
||||
if r2.err != nil {
|
||||
return fmt.Errorf("%s side error: %v", r2.side, r2.err)
|
||||
}
|
||||
|
||||
// don't compare remote node IDs
|
||||
r1.s.RemoteID, r2.s.RemoteID = discover.NodeID{}, discover.NodeID{}
|
||||
// flip MACs on one of them so they compare equal
|
||||
r1.s.EgressMAC, r1.s.IngressMAC = r1.s.IngressMAC, r1.s.EgressMAC
|
||||
if !reflect.DeepEqual(r1.s, r2.s) {
|
||||
return fmt.Errorf("secrets mismatch:\n t1: %#v\n t2: %#v", r1.s, r2.s)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSetupConn(t *testing.T) {
|
||||
|
Reference in New Issue
Block a user