crypto/ecies: improve concatKDF (#20836)

This removes a bunch of weird code around the counter overflow check in
concatKDF and makes it actually work for different hash output sizes.

The overflow check worked as follows: concatKDF applies the hash function N
times, where N is roundup(kdLen, hashsize) / hashsize. N should not
overflow 32 bits because that would lead to a repetition in the KDF output.

A couple issues with the overflow check:

- It used the hash.BlockSize, which is wrong because the
  block size is about the input of the hash function. Luckily, all standard
  hash functions have a block size that's greater than the output size, so
  concatKDF didn't crash, it just generated too much key material.
- The check used big.Int to compare against 2^32-1.
- The calculation could still overflow before reaching the check.

The new code in concatKDF doesn't check for overflow. Instead, there is a
new check on ECIESParams which ensures that params.KeyLen is < 512. This
removes any possibility of overflow.

There are a couple of miscellaneous improvements bundled in with this
change:

- The key buffer is pre-allocated instead of appending the hash output
  to an initially empty slice.
- The code that uses concatKDF to derive keys is now shared between Encrypt
  and Decrypt.
- There was a redundant invocation of IsOnCurve in Decrypt. This is now removed
  because elliptic.Unmarshal already checks whether the input is a valid curve
  point since Go 1.5.

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
Luke Champine
2020-04-03 05:57:24 -04:00
committed by GitHub
parent f7b29ec942
commit 462ddce5b2
3 changed files with 94 additions and 110 deletions

View File

@@ -35,6 +35,7 @@ import (
"crypto/elliptic"
"crypto/hmac"
"crypto/subtle"
"encoding/binary"
"fmt"
"hash"
"io"
@@ -44,7 +45,6 @@ import (
var (
ErrImport = fmt.Errorf("ecies: failed to import key")
ErrInvalidCurve = fmt.Errorf("ecies: invalid elliptic curve")
ErrInvalidParams = fmt.Errorf("ecies: invalid ECIES parameters")
ErrInvalidPublicKey = fmt.Errorf("ecies: invalid public key")
ErrSharedKeyIsPointAtInfinity = fmt.Errorf("ecies: shared key is point at infinity")
ErrSharedKeyTooBig = fmt.Errorf("ecies: shared key params are too big")
@@ -138,57 +138,39 @@ func (prv *PrivateKey) GenerateShared(pub *PublicKey, skLen, macLen int) (sk []b
}
var (
ErrKeyDataTooLong = fmt.Errorf("ecies: can't supply requested key data")
ErrSharedTooLong = fmt.Errorf("ecies: shared secret is too long")
ErrInvalidMessage = fmt.Errorf("ecies: invalid message")
)
var (
big2To32 = new(big.Int).Exp(big.NewInt(2), big.NewInt(32), nil)
big2To32M1 = new(big.Int).Sub(big2To32, big.NewInt(1))
)
func incCounter(ctr []byte) {
if ctr[3]++; ctr[3] != 0 {
return
}
if ctr[2]++; ctr[2] != 0 {
return
}
if ctr[1]++; ctr[1] != 0 {
return
}
if ctr[0]++; ctr[0] != 0 {
return
}
}
// NIST SP 800-56 Concatenation Key Derivation Function (see section 5.8.1).
func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) (k []byte, err error) {
if s1 == nil {
s1 = make([]byte, 0)
}
reps := ((kdLen + 7) * 8) / (hash.BlockSize() * 8)
if big.NewInt(int64(reps)).Cmp(big2To32M1) > 0 {
fmt.Println(big2To32M1)
return nil, ErrKeyDataTooLong
}
counter := []byte{0, 0, 0, 1}
k = make([]byte, 0)
for i := 0; i <= reps; i++ {
hash.Write(counter)
func concatKDF(hash hash.Hash, z, s1 []byte, kdLen int) []byte {
counterBytes := make([]byte, 4)
k := make([]byte, 0, roundup(kdLen, hash.Size()))
for counter := uint32(1); len(k) < kdLen; counter++ {
binary.BigEndian.PutUint32(counterBytes, counter)
hash.Reset()
hash.Write(counterBytes)
hash.Write(z)
hash.Write(s1)
k = append(k, hash.Sum(nil)...)
hash.Reset()
incCounter(counter)
k = hash.Sum(k)
}
return k[:kdLen]
}
k = k[:kdLen]
return
// roundup rounds size up to the next multiple of blocksize.
func roundup(size, blocksize int) int {
return size + blocksize - (size % blocksize)
}
// deriveKeys creates the encryption and MAC keys using concatKDF.
func deriveKeys(hash hash.Hash, z, s1 []byte, keyLen int) (Ke, Km []byte) {
K := concatKDF(hash, z, s1, 2*keyLen)
Ke = K[:keyLen]
Km = K[keyLen:]
hash.Reset()
hash.Write(Km)
Km = hash.Sum(Km[:0])
return Ke, Km
}
// messageTag computes the MAC of a message (called the tag) as per
@@ -209,7 +191,6 @@ func generateIV(params *ECIESParams, rand io.Reader) (iv []byte, err error) {
}
// symEncrypt carries out CTR encryption using the block cipher specified in the
// parameters.
func symEncrypt(rand io.Reader, params *ECIESParams, key, m []byte) (ct []byte, err error) {
c, err := params.Cipher(key)
if err != nil {
@@ -249,36 +230,27 @@ func symDecrypt(params *ECIESParams, key, ct []byte) (m []byte, err error) {
// ciphertext. s1 is fed into key derivation, s2 is fed into the MAC. If the
// shared information parameters aren't being used, they should be nil.
func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err error) {
params := pub.Params
if params == nil {
if params = ParamsFromCurve(pub.Curve); params == nil {
err = ErrUnsupportedECIESParameters
return
}
params, err := pubkeyParams(pub)
if err != nil {
return nil, err
}
R, err := GenerateKey(rand, pub.Curve, params)
if err != nil {
return
return nil, err
}
z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen)
if err != nil {
return nil, err
}
hash := params.Hash()
z, err := R.GenerateShared(pub, params.KeyLen, params.KeyLen)
if err != nil {
return
}
K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
if err != nil {
return
}
Ke := K[:params.KeyLen]
Km := K[params.KeyLen:]
hash.Write(Km)
Km = hash.Sum(nil)
hash.Reset()
Ke, Km := deriveKeys(hash, z, s1, params.KeyLen)
em, err := symEncrypt(rand, params, Ke, m)
if err != nil || len(em) <= params.BlockSize {
return
return nil, err
}
d := messageTag(params.Hash, Km, em, s2)
@@ -288,7 +260,7 @@ func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err e
copy(ct, Rb)
copy(ct[len(Rb):], em)
copy(ct[len(Rb)+len(em):], d)
return
return ct, nil
}
// Decrypt decrypts an ECIES ciphertext.
@@ -296,13 +268,11 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) {
if len(c) == 0 {
return nil, ErrInvalidMessage
}
params := prv.PublicKey.Params
if params == nil {
if params = ParamsFromCurve(prv.PublicKey.Curve); params == nil {
err = ErrUnsupportedECIESParameters
return
}
params, err := pubkeyParams(&prv.PublicKey)
if err != nil {
return nil, err
}
hash := params.Hash()
var (
@@ -316,12 +286,10 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) {
case 2, 3, 4:
rLen = (prv.PublicKey.Curve.Params().BitSize + 7) / 4
if len(c) < (rLen + hLen + 1) {
err = ErrInvalidMessage
return
return nil, ErrInvalidMessage
}
default:
err = ErrInvalidPublicKey
return
return nil, ErrInvalidPublicKey
}
mStart = rLen
@@ -331,36 +299,19 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) {
R.Curve = prv.PublicKey.Curve
R.X, R.Y = elliptic.Unmarshal(R.Curve, c[:rLen])
if R.X == nil {
err = ErrInvalidPublicKey
return
}
if !R.Curve.IsOnCurve(R.X, R.Y) {
err = ErrInvalidCurve
return
return nil, ErrInvalidPublicKey
}
z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen)
if err != nil {
return
return nil, err
}
K, err := concatKDF(hash, z, s1, params.KeyLen+params.KeyLen)
if err != nil {
return
}
Ke := K[:params.KeyLen]
Km := K[params.KeyLen:]
hash.Write(Km)
Km = hash.Sum(nil)
hash.Reset()
Ke, Km := deriveKeys(hash, z, s1, params.KeyLen)
d := messageTag(params.Hash, Km, c[mStart:mEnd], s2)
if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 {
err = ErrInvalidMessage
return
return nil, ErrInvalidMessage
}
m, err = symDecrypt(params, Ke, c[mStart:mEnd])
return
return symDecrypt(params, Ke, c[mStart:mEnd])
}