common: added Hash unmarshal json length validation
This commit is contained in:
		| @@ -19,10 +19,12 @@ package common | |||||||
| import ( | import ( | ||||||
| 	"encoding/hex" | 	"encoding/hex" | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"math/big" | 	"math/big" | ||||||
| 	"math/rand" | 	"math/rand" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -30,6 +32,8 @@ const ( | |||||||
| 	AddressLength = 20 | 	AddressLength = 20 | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | var hashJsonLengthErr = errors.New("common: unmarshalJSON failed: hash must be exactly 32 bytes") | ||||||
|  |  | ||||||
| type ( | type ( | ||||||
| 	Hash    [HashLength]byte | 	Hash    [HashLength]byte | ||||||
| 	Address [AddressLength]byte | 	Address [AddressLength]byte | ||||||
| @@ -58,6 +62,15 @@ func (h *Hash) UnmarshalJSON(input []byte) error { | |||||||
| 	if length >= 2 && input[0] == '"' && input[length-1] == '"' { | 	if length >= 2 && input[0] == '"' && input[length-1] == '"' { | ||||||
| 		input = input[1 : length-1] | 		input = input[1 : length-1] | ||||||
| 	} | 	} | ||||||
|  | 	// strip "0x" for length check | ||||||
|  | 	if len(input) > 1 && strings.ToLower(string(input[:2])) == "0x" { | ||||||
|  | 		input = input[2:] | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// validate the length of the input hash | ||||||
|  | 	if len(input) != HashLength*2 { | ||||||
|  | 		return hashJsonLengthErr | ||||||
|  | 	} | ||||||
| 	h.SetBytes(FromHex(string(input))) | 	h.SetBytes(FromHex(string(input))) | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -29,3 +29,25 @@ func TestBytesConversion(t *testing.T) { | |||||||
| 		t.Errorf("expected %x got %x", exp, hash) | 		t.Errorf("expected %x got %x", exp, hash) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestHashJsonValidation(t *testing.T) { | ||||||
|  | 	var h Hash | ||||||
|  | 	var tests = []struct { | ||||||
|  | 		Prefix string | ||||||
|  | 		Size   int | ||||||
|  | 		Error  error | ||||||
|  | 	}{ | ||||||
|  | 		{"", 2, hashJsonLengthErr}, | ||||||
|  | 		{"", 62, hashJsonLengthErr}, | ||||||
|  | 		{"", 66, hashJsonLengthErr}, | ||||||
|  | 		{"", 65, hashJsonLengthErr}, | ||||||
|  | 		{"0X", 64, nil}, | ||||||
|  | 		{"0x", 64, nil}, | ||||||
|  | 		{"0x", 62, hashJsonLengthErr}, | ||||||
|  | 	} | ||||||
|  | 	for i, test := range tests { | ||||||
|  | 		if err := h.UnmarshalJSON(append([]byte(test.Prefix), make([]byte, test.Size)...)); err != test.Error { | ||||||
|  | 			t.Error(i, "expected", test.Error, "got", err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user