Add versioned dependencies from godep

This commit is contained in:
Taylor Gerring
2015-02-16 14:28:33 +01:00
parent 202362d925
commit 702218008e
667 changed files with 435414 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
Copyright (c) 2009 Google Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,84 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"fmt"
"os"
)
// A Domain represents a Version 2 domain
type Domain byte
// Domain constants for DCE Security (Version 2) UUIDs.
const (
Person = Domain(0)
Group = Domain(1)
Org = Domain(2)
)
// NewDCESecurity returns a DCE Security (Version 2) UUID.
//
// The domain should be one of Person, Group or Org.
// On a POSIX system the id should be the users UID for the Person
// domain and the users GID for the Group. The meaning of id for
// the domain Org or on non-POSIX systems is site defined.
//
// For a given domain/id pair the same token may be returned for up to
// 7 minutes and 10 seconds.
func NewDCESecurity(domain Domain, id uint32) UUID {
uuid := NewUUID()
if uuid != nil {
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
uuid[9] = byte(domain)
binary.BigEndian.PutUint32(uuid[0:], id)
}
return uuid
}
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
// domain with the id returned by os.Getuid.
//
// NewDCEPerson(Person, uint32(os.Getuid()))
func NewDCEPerson() UUID {
return NewDCESecurity(Person, uint32(os.Getuid()))
}
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
// domain with the id returned by os.Getgid.
//
// NewDCEGroup(Group, uint32(os.Getgid()))
func NewDCEGroup() UUID {
return NewDCESecurity(Group, uint32(os.Getgid()))
}
// Domain returns the domain for a Version 2 UUID or false.
func (uuid UUID) Domain() (Domain, bool) {
if v, _ := uuid.Version(); v != 2 {
return 0, false
}
return Domain(uuid[9]), true
}
// Id returns the id for a Version 2 UUID or false.
func (uuid UUID) Id() (uint32, bool) {
if v, _ := uuid.Version(); v != 2 {
return 0, false
}
return binary.BigEndian.Uint32(uuid[0:4]), true
}
func (d Domain) String() string {
switch d {
case Person:
return "Person"
case Group:
return "Group"
case Org:
return "Org"
}
return fmt.Sprintf("Domain%d", int(d))
}

View File

@@ -0,0 +1,8 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The uuid package generates and inspects UUIDs.
//
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security Services.
package uuid

View File

@@ -0,0 +1,53 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"crypto/md5"
"crypto/sha1"
"hash"
)
// Well known Name Space IDs and UUIDs
var (
NameSpace_DNS = Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
NameSpace_URL = Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
NameSpace_OID = Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8")
NameSpace_X500 = Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8")
NIL = Parse("00000000-0000-0000-0000-000000000000")
)
// NewHash returns a new UUID dervied from the hash of space concatenated with
// data generated by h. The hash should be at least 16 byte in length. The
// first 16 bytes of the hash are used to form the UUID. The version of the
// UUID will be the lower 4 bits of version. NewHash is used to implement
// NewMD5 and NewSHA1.
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
h.Reset()
h.Write(space)
h.Write([]byte(data))
s := h.Sum(nil)
uuid := make([]byte, 16)
copy(uuid, s)
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
return uuid
}
// NewMD5 returns a new MD5 (Version 3) UUID based on the
// supplied name space and data.
//
// NewHash(md5.New(), space, data, 3)
func NewMD5(space UUID, data []byte) UUID {
return NewHash(md5.New(), space, data, 3)
}
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
// supplied name space and data.
//
// NewHash(sha1.New(), space, data, 5)
func NewSHA1(space UUID, data []byte) UUID {
return NewHash(sha1.New(), space, data, 5)
}

View File

@@ -0,0 +1,101 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "net"
var (
interfaces []net.Interface // cached list of interfaces
ifname string // name of interface being used
nodeID []byte // hardware for version 1 UUIDs
)
// NodeInterface returns the name of the interface from which the NodeID was
// derived. The interface "user" is returned if the NodeID was set by
// SetNodeID.
func NodeInterface() string {
return ifname
}
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
// If name is "" then the first usable interface found will be used or a random
// Node ID will be generated. If a named interface cannot be found then false
// is returned.
//
// SetNodeInterface never fails when name is "".
func SetNodeInterface(name string) bool {
if interfaces == nil {
var err error
interfaces, err = net.Interfaces()
if err != nil && name != "" {
return false
}
}
for _, ifs := range interfaces {
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
if setNodeID(ifs.HardwareAddr) {
ifname = ifs.Name
return true
}
}
}
// We found no interfaces with a valid hardware address. If name
// does not specify a specific interface generate a random Node ID
// (section 4.1.6)
if name == "" {
if nodeID == nil {
nodeID = make([]byte, 6)
}
randomBits(nodeID)
return true
}
return false
}
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
// if not already set.
func NodeID() []byte {
if nodeID == nil {
SetNodeInterface("")
}
nid := make([]byte, 6)
copy(nid, nodeID)
return nid
}
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
// of id are used. If id is less than 6 bytes then false is returned and the
// Node ID is not set.
func SetNodeID(id []byte) bool {
if setNodeID(id) {
ifname = "user"
return true
}
return false
}
func setNodeID(id []byte) bool {
if len(id) < 6 {
return false
}
if nodeID == nil {
nodeID = make([]byte, 6)
}
copy(nodeID, id)
return true
}
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) NodeID() []byte {
if len(uuid) != 16 {
return nil
}
node := make([]byte, 6)
copy(node, uuid[10:])
return node
}

View File

@@ -0,0 +1,132 @@
// Copyright 2014 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"sync"
"time"
)
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
// 1582.
type Time int64
const (
lillian = 2299160 // Julian day of 15 Oct 1582
unix = 2440587 // Julian day of 1 Jan 1970
epoch = unix - lillian // Days between epochs
g1582 = epoch * 86400 // seconds between epochs
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
)
var (
mu sync.Mutex
lasttime uint64 // last time we returned
clock_seq uint16 // clock sequence for this run
timeNow = time.Now // for testing
)
// UnixTime converts t the number of seconds and nanoseconds using the Unix
// epoch of 1 Jan 1970.
func (t Time) UnixTime() (sec, nsec int64) {
sec = int64(t - g1582ns100)
nsec = (sec % 10000000) * 100
sec /= 10000000
return sec, nsec
}
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
// adjusts the clock sequence as needed. An error is returned if the current
// time cannot be determined.
func GetTime() (Time, error) {
defer mu.Unlock()
mu.Lock()
return getTime()
}
func getTime() (Time, error) {
t := timeNow()
// If we don't have a clock sequence already, set one.
if clock_seq == 0 {
setClockSequence(-1)
}
now := uint64(t.UnixNano()/100) + g1582ns100
// If time has gone backwards with this clock sequence then we
// increment the clock sequence
if now <= lasttime {
clock_seq = ((clock_seq + 1) & 0x3fff) | 0x8000
}
lasttime = now
return Time(now), nil
}
// ClockSequence returns the current clock sequence, generating one if not
// already set. The clock sequence is only used for Version 1 UUIDs.
//
// The uuid package does not use global static storage for the clock sequence or
// the last time a UUID was generated. Unless SetClockSequence a new random
// clock sequence is generated the first time a clock sequence is requested by
// ClockSequence, GetTime, or NewUUID. (section 4.2.1.1) sequence is generated
// for
func ClockSequence() int {
defer mu.Unlock()
mu.Lock()
return clockSequence()
}
func clockSequence() int {
if clock_seq == 0 {
setClockSequence(-1)
}
return int(clock_seq & 0x3fff)
}
// SetClockSeq sets the clock sequence to the lower 14 bits of seq. Setting to
// -1 causes a new sequence to be generated.
func SetClockSequence(seq int) {
defer mu.Unlock()
mu.Lock()
setClockSequence(seq)
}
func setClockSequence(seq int) {
if seq == -1 {
var b [2]byte
randomBits(b[:]) // clock sequence
seq = int(b[0])<<8 | int(b[1])
}
old_seq := clock_seq
clock_seq = uint16(seq&0x3fff) | 0x8000 // Set our variant
if old_seq != clock_seq {
lasttime = 0
}
}
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
// uuid. It returns false if uuid is not valid. The time is only well defined
// for version 1 and 2 UUIDs.
func (uuid UUID) Time() (Time, bool) {
if len(uuid) != 16 {
return 0, false
}
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
return Time(time), true
}
// ClockSequence returns the clock sequence encoded in uuid. It returns false
// if uuid is not valid. The clock sequence is only well defined for version 1
// and 2 UUIDs.
func (uuid UUID) ClockSequence() (int, bool) {
if len(uuid) != 16 {
return 0, false
}
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff, true
}

View File

@@ -0,0 +1,43 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"io"
)
// randomBits completely fills slice b with random data.
func randomBits(b []byte) {
if _, err := io.ReadFull(rander, b); err != nil {
panic(err.Error()) // rand should never fail
}
}
// xvalues returns the value of a byte as a hexadecimal digit or 255.
var xvalues = []byte{
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
}
// xtob converts the the first two hex bytes of x into a byte.
func xtob(x string) (byte, bool) {
b1 := xvalues[x[0]]
b2 := xvalues[x[1]]
return (b1 << 4) | b2, b1 != 255 && b2 != 255
}

View File

@@ -0,0 +1,163 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"strings"
)
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
// 4122.
type UUID []byte
// A Version represents a UUIDs version.
type Version byte
// A Variant represents a UUIDs variant.
type Variant byte
// Constants returned by Variant.
const (
Invalid = Variant(iota) // Invalid UUID
RFC4122 // The variant specified in RFC4122
Reserved // Reserved, NCS backward compatibility.
Microsoft // Reserved, Microsoft Corporation backward compatibility.
Future // Reserved for future definition.
)
var rander = rand.Reader // random function
// New returns a new random (version 4) UUID as a string. It is a convenience
// function for NewRandom().String().
func New() string {
return NewRandom().String()
}
// Parse decodes s into a UUID or returns nil. Both the UUID form of
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded.
func Parse(s string) UUID {
if len(s) == 36+9 {
if strings.ToLower(s[:9]) != "urn:uuid:" {
return nil
}
s = s[9:]
} else if len(s) != 36 {
return nil
}
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return nil
}
uuid := make([]byte, 16)
for i, x := range []int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
if v, ok := xtob(s[x:]); !ok {
return nil
} else {
uuid[i] = v
}
}
return uuid
}
// Equal returns true if uuid1 and uuid2 are equal.
func Equal(uuid1, uuid2 UUID) bool {
return bytes.Equal(uuid1, uuid2)
}
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
// , or "" if uuid is invalid.
func (uuid UUID) String() string {
if uuid == nil || len(uuid) != 16 {
return ""
}
b := []byte(uuid)
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// URN returns the RFC 2141 URN form of uuid,
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
func (uuid UUID) URN() string {
if uuid == nil || len(uuid) != 16 {
return ""
}
b := []byte(uuid)
return fmt.Sprintf("urn:uuid:%08x-%04x-%04x-%04x-%012x",
b[:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// Variant returns the variant encoded in uuid. It returns Invalid if
// uuid is invalid.
func (uuid UUID) Variant() Variant {
if len(uuid) != 16 {
return Invalid
}
switch {
case (uuid[8] & 0xc0) == 0x80:
return RFC4122
case (uuid[8] & 0xe0) == 0xc0:
return Microsoft
case (uuid[8] & 0xe0) == 0xe0:
return Future
default:
return Reserved
}
panic("unreachable")
}
// Version returns the verison of uuid. It returns false if uuid is not
// valid.
func (uuid UUID) Version() (Version, bool) {
if len(uuid) != 16 {
return 0, false
}
return Version(uuid[6] >> 4), true
}
func (v Version) String() string {
if v > 15 {
return fmt.Sprintf("BAD_VERSION_%d", v)
}
return fmt.Sprintf("VERSION_%d", v)
}
func (v Variant) String() string {
switch v {
case RFC4122:
return "RFC4122"
case Reserved:
return "Reserved"
case Microsoft:
return "Microsoft"
case Future:
return "Future"
case Invalid:
return "Invalid"
}
return fmt.Sprintf("BadVariant%d", int(v))
}
// SetRand sets the random number generator to r, which implents io.Reader.
// If r.Read returns an error when the package requests random data then
// a panic will be issued.
//
// Calling SetRand with nil sets the random number generator to the default
// generator.
func SetRand(r io.Reader) {
if r == nil {
rander = rand.Reader
return
}
rander = r
}

View File

@@ -0,0 +1,390 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
"time"
)
type test struct {
in string
version Version
variant Variant
isuuid bool
}
var tests = []test{
{"f47ac10b-58cc-0372-8567-0e02b2c3d479", 0, RFC4122, true},
{"f47ac10b-58cc-1372-8567-0e02b2c3d479", 1, RFC4122, true},
{"f47ac10b-58cc-2372-8567-0e02b2c3d479", 2, RFC4122, true},
{"f47ac10b-58cc-3372-8567-0e02b2c3d479", 3, RFC4122, true},
{"f47ac10b-58cc-4372-8567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-5372-8567-0e02b2c3d479", 5, RFC4122, true},
{"f47ac10b-58cc-6372-8567-0e02b2c3d479", 6, RFC4122, true},
{"f47ac10b-58cc-7372-8567-0e02b2c3d479", 7, RFC4122, true},
{"f47ac10b-58cc-8372-8567-0e02b2c3d479", 8, RFC4122, true},
{"f47ac10b-58cc-9372-8567-0e02b2c3d479", 9, RFC4122, true},
{"f47ac10b-58cc-a372-8567-0e02b2c3d479", 10, RFC4122, true},
{"f47ac10b-58cc-b372-8567-0e02b2c3d479", 11, RFC4122, true},
{"f47ac10b-58cc-c372-8567-0e02b2c3d479", 12, RFC4122, true},
{"f47ac10b-58cc-d372-8567-0e02b2c3d479", 13, RFC4122, true},
{"f47ac10b-58cc-e372-8567-0e02b2c3d479", 14, RFC4122, true},
{"f47ac10b-58cc-f372-8567-0e02b2c3d479", 15, RFC4122, true},
{"urn:uuid:f47ac10b-58cc-4372-0567-0e02b2c3d479", 4, Reserved, true},
{"URN:UUID:f47ac10b-58cc-4372-0567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-0567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-1567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-2567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-3567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-4567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-5567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-6567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-7567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-8567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-9567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-a567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-b567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-c567-0e02b2c3d479", 4, Microsoft, true},
{"f47ac10b-58cc-4372-d567-0e02b2c3d479", 4, Microsoft, true},
{"f47ac10b-58cc-4372-e567-0e02b2c3d479", 4, Future, true},
{"f47ac10b-58cc-4372-f567-0e02b2c3d479", 4, Future, true},
{"f47ac10b158cc-5372-a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc25372-a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-53723a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-5372-a56740e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-5372-a567-0e02-2c3d479", 0, Invalid, false},
{"g47ac10b-58cc-4372-a567-0e02b2c3d479", 0, Invalid, false},
}
var constants = []struct {
c interface{}
name string
}{
{Person, "Person"},
{Group, "Group"},
{Org, "Org"},
{Invalid, "Invalid"},
{RFC4122, "RFC4122"},
{Reserved, "Reserved"},
{Microsoft, "Microsoft"},
{Future, "Future"},
{Domain(17), "Domain17"},
{Variant(42), "BadVariant42"},
}
func testTest(t *testing.T, in string, tt test) {
uuid := Parse(in)
if ok := (uuid != nil); ok != tt.isuuid {
t.Errorf("Parse(%s) got %v expected %v\b", in, ok, tt.isuuid)
}
if uuid == nil {
return
}
if v := uuid.Variant(); v != tt.variant {
t.Errorf("Variant(%s) got %d expected %d\b", in, v, tt.variant)
}
if v, _ := uuid.Version(); v != tt.version {
t.Errorf("Version(%s) got %d expected %d\b", in, v, tt.version)
}
}
func TestUUID(t *testing.T) {
for _, tt := range tests {
testTest(t, tt.in, tt)
testTest(t, strings.ToUpper(tt.in), tt)
}
}
func TestConstants(t *testing.T) {
for x, tt := range constants {
v, ok := tt.c.(fmt.Stringer)
if !ok {
t.Errorf("%x: %v: not a stringer", x, v)
} else if s := v.String(); s != tt.name {
v, _ := tt.c.(int)
t.Errorf("%x: Constant %T:%d gives %q, expected %q\n", x, tt.c, v, s, tt.name)
}
}
}
func TestRandomUUID(t *testing.T) {
m := make(map[string]bool)
for x := 1; x < 32; x++ {
uuid := NewRandom()
s := uuid.String()
if m[s] {
t.Errorf("NewRandom returned duplicated UUID %s\n", s)
}
m[s] = true
if v, _ := uuid.Version(); v != 4 {
t.Errorf("Random UUID of version %s\n", v)
}
if uuid.Variant() != RFC4122 {
t.Errorf("Random UUID is variant %d\n", uuid.Variant())
}
}
}
func TestNew(t *testing.T) {
m := make(map[string]bool)
for x := 1; x < 32; x++ {
s := New()
if m[s] {
t.Errorf("New returned duplicated UUID %s\n", s)
}
m[s] = true
uuid := Parse(s)
if uuid == nil {
t.Errorf("New returned %q which does not decode\n", s)
continue
}
if v, _ := uuid.Version(); v != 4 {
t.Errorf("Random UUID of version %s\n", v)
}
if uuid.Variant() != RFC4122 {
t.Errorf("Random UUID is variant %d\n", uuid.Variant())
}
}
}
func clockSeq(t *testing.T, uuid UUID) int {
seq, ok := uuid.ClockSequence()
if !ok {
t.Fatalf("%s: invalid clock sequence\n", uuid)
}
return seq
}
func TestClockSeq(t *testing.T) {
// Fake time.Now for this test to return a monotonically advancing time; restore it at end.
defer func(orig func() time.Time) { timeNow = orig }(timeNow)
monTime := time.Now()
timeNow = func() time.Time {
monTime = monTime.Add(1 * time.Second)
return monTime
}
SetClockSequence(-1)
uuid1 := NewUUID()
uuid2 := NewUUID()
if clockSeq(t, uuid1) != clockSeq(t, uuid2) {
t.Errorf("clock sequence %d != %d\n", clockSeq(t, uuid1), clockSeq(t, uuid2))
}
SetClockSequence(-1)
uuid2 = NewUUID()
// Just on the very off chance we generated the same sequence
// two times we try again.
if clockSeq(t, uuid1) == clockSeq(t, uuid2) {
SetClockSequence(-1)
uuid2 = NewUUID()
}
if clockSeq(t, uuid1) == clockSeq(t, uuid2) {
t.Errorf("Duplicate clock sequence %d\n", clockSeq(t, uuid1))
}
SetClockSequence(0x1234)
uuid1 = NewUUID()
if seq := clockSeq(t, uuid1); seq != 0x1234 {
t.Errorf("%s: expected seq 0x1234 got 0x%04x\n", uuid1, seq)
}
}
func TestCoding(t *testing.T) {
text := "7d444840-9dc0-11d1-b245-5ffdce74fad2"
urn := "urn:uuid:7d444840-9dc0-11d1-b245-5ffdce74fad2"
data := UUID{
0x7d, 0x44, 0x48, 0x40,
0x9d, 0xc0,
0x11, 0xd1,
0xb2, 0x45,
0x5f, 0xfd, 0xce, 0x74, 0xfa, 0xd2,
}
if v := data.String(); v != text {
t.Errorf("%x: encoded to %s, expected %s\n", data, v, text)
}
if v := data.URN(); v != urn {
t.Errorf("%x: urn is %s, expected %s\n", data, v, urn)
}
uuid := Parse(text)
if !Equal(uuid, data) {
t.Errorf("%s: decoded to %s, expected %s\n", text, uuid, data)
}
}
func TestVersion1(t *testing.T) {
uuid1 := NewUUID()
uuid2 := NewUUID()
if Equal(uuid1, uuid2) {
t.Errorf("%s:duplicate uuid\n", uuid1)
}
if v, _ := uuid1.Version(); v != 1 {
t.Errorf("%s: version %s expected 1\n", uuid1, v)
}
if v, _ := uuid2.Version(); v != 1 {
t.Errorf("%s: version %s expected 1\n", uuid2, v)
}
n1 := uuid1.NodeID()
n2 := uuid2.NodeID()
if !bytes.Equal(n1, n2) {
t.Errorf("Different nodes %x != %x\n", n1, n2)
}
t1, ok := uuid1.Time()
if !ok {
t.Errorf("%s: invalid time\n", uuid1)
}
t2, ok := uuid2.Time()
if !ok {
t.Errorf("%s: invalid time\n", uuid2)
}
q1, ok := uuid1.ClockSequence()
if !ok {
t.Errorf("%s: invalid clock sequence\n", uuid1)
}
q2, ok := uuid2.ClockSequence()
if !ok {
t.Errorf("%s: invalid clock sequence", uuid2)
}
switch {
case t1 == t2 && q1 == q2:
t.Errorf("time stopped\n")
case t1 > t2 && q1 == q2:
t.Errorf("time reversed\n")
case t1 < t2 && q1 != q2:
t.Errorf("clock sequence chaned unexpectedly\n")
}
}
func TestNodeAndTime(t *testing.T) {
// Time is February 5, 1998 12:30:23.136364800 AM GMT
uuid := Parse("7d444840-9dc0-11d1-b245-5ffdce74fad2")
node := []byte{0x5f, 0xfd, 0xce, 0x74, 0xfa, 0xd2}
ts, ok := uuid.Time()
if ok {
c := time.Unix(ts.UnixTime())
want := time.Date(1998, 2, 5, 0, 30, 23, 136364800, time.UTC)
if !c.Equal(want) {
t.Errorf("Got time %v, want %v", c, want)
}
} else {
t.Errorf("%s: bad time\n", uuid)
}
if !bytes.Equal(node, uuid.NodeID()) {
t.Errorf("Expected node %v got %v\n", node, uuid.NodeID())
}
}
func TestMD5(t *testing.T) {
uuid := NewMD5(NameSpace_DNS, []byte("python.org")).String()
want := "6fa459ea-ee8a-3ca4-894e-db77e160355e"
if uuid != want {
t.Errorf("MD5: got %q expected %q\n", uuid, want)
}
}
func TestSHA1(t *testing.T) {
uuid := NewSHA1(NameSpace_DNS, []byte("python.org")).String()
want := "886313e1-3b8a-5372-9b90-0c9aee199e5d"
if uuid != want {
t.Errorf("SHA1: got %q expected %q\n", uuid, want)
}
}
func TestNodeID(t *testing.T) {
nid := []byte{1, 2, 3, 4, 5, 6}
SetNodeInterface("")
s := NodeInterface()
if s == "" || s == "user" {
t.Errorf("NodeInterface %q after SetInteface\n", s)
}
node1 := NodeID()
if node1 == nil {
t.Errorf("NodeID nil after SetNodeInterface\n", s)
}
SetNodeID(nid)
s = NodeInterface()
if s != "user" {
t.Errorf("Expected NodeInterface %q got %q\n", "user", s)
}
node2 := NodeID()
if node2 == nil {
t.Errorf("NodeID nil after SetNodeID\n", s)
}
if bytes.Equal(node1, node2) {
t.Errorf("NodeID not changed after SetNodeID\n", s)
} else if !bytes.Equal(nid, node2) {
t.Errorf("NodeID is %x, expected %x\n", node2, nid)
}
}
func testDCE(t *testing.T, name string, uuid UUID, domain Domain, id uint32) {
if uuid == nil {
t.Errorf("%s failed\n", name)
return
}
if v, _ := uuid.Version(); v != 2 {
t.Errorf("%s: %s: expected version 2, got %s\n", name, uuid, v)
return
}
if v, ok := uuid.Domain(); !ok || v != domain {
if !ok {
t.Errorf("%s: %d: Domain failed\n", name, uuid)
} else {
t.Errorf("%s: %s: expected domain %d, got %d\n", name, uuid, domain, v)
}
}
if v, ok := uuid.Id(); !ok || v != id {
if !ok {
t.Errorf("%s: %d: Id failed\n", name, uuid)
} else {
t.Errorf("%s: %s: expected id %d, got %d\n", name, uuid, id, v)
}
}
}
func TestDCE(t *testing.T) {
testDCE(t, "NewDCESecurity", NewDCESecurity(42, 12345678), 42, 12345678)
testDCE(t, "NewDCEPerson", NewDCEPerson(), Person, uint32(os.Getuid()))
testDCE(t, "NewDCEGroup", NewDCEGroup(), Group, uint32(os.Getgid()))
}
type badRand struct{}
func (r badRand) Read(buf []byte) (int, error) {
for i, _ := range buf {
buf[i] = byte(i)
}
return len(buf), nil
}
func TestBadRand(t *testing.T) {
SetRand(badRand{})
uuid1 := New()
uuid2 := New()
if uuid1 != uuid2 {
t.Errorf("execpted duplicates, got %q and %q\n", uuid1, uuid2)
}
SetRand(nil)
uuid1 = New()
uuid2 = New()
if uuid1 == uuid2 {
t.Errorf("unexecpted duplicates, got %q\n", uuid1)
}
}

View File

@@ -0,0 +1,41 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
)
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
// sequence, and the current time. If the NodeID has not been set by SetNodeID
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
// be set NewUUID returns nil. If clock sequence has not been set by
// SetClockSequence then it will be set automatically. If GetTime fails to
// return the current NewUUID returns nil.
func NewUUID() UUID {
if nodeID == nil {
SetNodeInterface("")
}
now, err := GetTime()
if err != nil {
return nil
}
uuid := make([]byte, 16)
time_low := uint32(now & 0xffffffff)
time_mid := uint16((now >> 32) & 0xffff)
time_hi := uint16((now >> 48) & 0x0fff)
time_hi |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uuid[0:], time_low)
binary.BigEndian.PutUint16(uuid[4:], time_mid)
binary.BigEndian.PutUint16(uuid[6:], time_hi)
binary.BigEndian.PutUint16(uuid[8:], clock_seq)
copy(uuid[10:], nodeID)
return uuid
}

View File

@@ -0,0 +1,25 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
// Random returns a Random (Version 4) UUID or panics.
//
// The strength of the UUIDs is based on the strength of the crypto/rand
// package.
//
// A note about uniqueness derived from from the UUID Wikipedia entry:
//
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
// hit by a meteorite is estimated to be one chance in 17 billion, that
// means the probability is about 0.00000000006 (6 × 1011),
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
// year and having one duplicate.
func NewRandom() UUID {
uuid := make([]byte, 16)
randomBits([]byte(uuid))
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid
}

View File

@@ -0,0 +1,77 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC
2898 / PKCS #5 v2.0.
A key derivation function is useful when encrypting data based on a password
or any other not-fully-random data. It uses a pseudorandom function to derive
a secure encryption key based on the password.
While v2.0 of the standard defines only one pseudorandom function to use,
HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved
Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To
choose, you can pass the `New` functions from the different SHA packages to
pbkdf2.Key.
*/
package pbkdf2
import (
"crypto/hmac"
"hash"
)
// Key derives a key from the password, salt and iteration count, returning a
// []byte of length keylen that can be used as cryptographic key. The key is
// derived based on the method described as PBKDF2 with the HMAC variant using
// the supplied hash function.
//
// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you
// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by
// doing:
//
// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New)
//
// Remember to get a good random salt. At least 8 bytes is recommended by the
// RFC.
//
// Using a higher iteration count will increase the cost of an exhaustive
// search but will also make derivation proportionally slower.
func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte {
prf := hmac.New(h, password)
hashLen := prf.Size()
numBlocks := (keyLen + hashLen - 1) / hashLen
var buf [4]byte
dk := make([]byte, 0, numBlocks*hashLen)
U := make([]byte, hashLen)
for block := 1; block <= numBlocks; block++ {
// N.B.: || means concatenation, ^ means XOR
// for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter
// U_1 = PRF(password, salt || uint(i))
prf.Reset()
prf.Write(salt)
buf[0] = byte(block >> 24)
buf[1] = byte(block >> 16)
buf[2] = byte(block >> 8)
buf[3] = byte(block)
prf.Write(buf[:4])
dk = prf.Sum(dk)
T := dk[len(dk)-hashLen:]
copy(U, T)
// U_n = PRF(password, U_(n-1))
for n := 2; n <= iter; n++ {
prf.Reset()
prf.Write(U)
U = U[:0]
U = prf.Sum(U)
for x := range U {
T[x] ^= U[x]
}
}
}
return dk[:keyLen]
}

View File

@@ -0,0 +1,157 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pbkdf2
import (
"bytes"
"crypto/sha1"
"crypto/sha256"
"hash"
"testing"
)
type testVector struct {
password string
salt string
iter int
output []byte
}
// Test vectors from RFC 6070, http://tools.ietf.org/html/rfc6070
var sha1TestVectors = []testVector{
{
"password",
"salt",
1,
[]byte{
0x0c, 0x60, 0xc8, 0x0f, 0x96, 0x1f, 0x0e, 0x71,
0xf3, 0xa9, 0xb5, 0x24, 0xaf, 0x60, 0x12, 0x06,
0x2f, 0xe0, 0x37, 0xa6,
},
},
{
"password",
"salt",
2,
[]byte{
0xea, 0x6c, 0x01, 0x4d, 0xc7, 0x2d, 0x6f, 0x8c,
0xcd, 0x1e, 0xd9, 0x2a, 0xce, 0x1d, 0x41, 0xf0,
0xd8, 0xde, 0x89, 0x57,
},
},
{
"password",
"salt",
4096,
[]byte{
0x4b, 0x00, 0x79, 0x01, 0xb7, 0x65, 0x48, 0x9a,
0xbe, 0xad, 0x49, 0xd9, 0x26, 0xf7, 0x21, 0xd0,
0x65, 0xa4, 0x29, 0xc1,
},
},
// // This one takes too long
// {
// "password",
// "salt",
// 16777216,
// []byte{
// 0xee, 0xfe, 0x3d, 0x61, 0xcd, 0x4d, 0xa4, 0xe4,
// 0xe9, 0x94, 0x5b, 0x3d, 0x6b, 0xa2, 0x15, 0x8c,
// 0x26, 0x34, 0xe9, 0x84,
// },
// },
{
"passwordPASSWORDpassword",
"saltSALTsaltSALTsaltSALTsaltSALTsalt",
4096,
[]byte{
0x3d, 0x2e, 0xec, 0x4f, 0xe4, 0x1c, 0x84, 0x9b,
0x80, 0xc8, 0xd8, 0x36, 0x62, 0xc0, 0xe4, 0x4a,
0x8b, 0x29, 0x1a, 0x96, 0x4c, 0xf2, 0xf0, 0x70,
0x38,
},
},
{
"pass\000word",
"sa\000lt",
4096,
[]byte{
0x56, 0xfa, 0x6a, 0xa7, 0x55, 0x48, 0x09, 0x9d,
0xcc, 0x37, 0xd7, 0xf0, 0x34, 0x25, 0xe0, 0xc3,
},
},
}
// Test vectors from
// http://stackoverflow.com/questions/5130513/pbkdf2-hmac-sha2-test-vectors
var sha256TestVectors = []testVector{
{
"password",
"salt",
1,
[]byte{
0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c,
0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4, 0xf8, 0x37,
0xa8, 0x65, 0x48, 0xc9,
},
},
{
"password",
"salt",
2,
[]byte{
0xae, 0x4d, 0x0c, 0x95, 0xaf, 0x6b, 0x46, 0xd3,
0x2d, 0x0a, 0xdf, 0xf9, 0x28, 0xf0, 0x6d, 0xd0,
0x2a, 0x30, 0x3f, 0x8e,
},
},
{
"password",
"salt",
4096,
[]byte{
0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41,
0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c, 0x4c, 0x8d,
0x96, 0x28, 0x93, 0xa0,
},
},
{
"passwordPASSWORDpassword",
"saltSALTsaltSALTsaltSALTsaltSALTsalt",
4096,
[]byte{
0x34, 0x8c, 0x89, 0xdb, 0xcb, 0xd3, 0x2b, 0x2f,
0x32, 0xd8, 0x14, 0xb8, 0x11, 0x6e, 0x84, 0xcf,
0x2b, 0x17, 0x34, 0x7e, 0xbc, 0x18, 0x00, 0x18,
0x1c,
},
},
{
"pass\000word",
"sa\000lt",
4096,
[]byte{
0x89, 0xb6, 0x9d, 0x05, 0x16, 0xf8, 0x29, 0x89,
0x3c, 0x69, 0x62, 0x26, 0x65, 0x0a, 0x86, 0x87,
},
},
}
func testHash(t *testing.T, h func() hash.Hash, hashName string, vectors []testVector) {
for i, v := range vectors {
o := Key([]byte(v.password), []byte(v.salt), v.iter, len(v.output), h)
if !bytes.Equal(o, v.output) {
t.Errorf("%s %d: expected %x, got %x", hashName, i, v.output, o)
}
}
}
func TestWithHMACSHA1(t *testing.T) {
testHash(t, sha1.New, "SHA1", sha1TestVectors)
}
func TestWithHMACSHA256(t *testing.T) {
testHash(t, sha256.New, "SHA256", sha256TestVectors)
}

View File

@@ -0,0 +1,120 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ripemd160 implements the RIPEMD-160 hash algorithm.
package ripemd160
// RIPEMD-160 is designed by by Hans Dobbertin, Antoon Bosselaers, and Bart
// Preneel with specifications available at:
// http://homes.esat.kuleuven.be/~cosicart/pdf/AB-9601/AB-9601.pdf.
import (
"crypto"
"hash"
)
func init() {
crypto.RegisterHash(crypto.RIPEMD160, New)
}
// The size of the checksum in bytes.
const Size = 20
// The block size of the hash algorithm in bytes.
const BlockSize = 64
const (
_s0 = 0x67452301
_s1 = 0xefcdab89
_s2 = 0x98badcfe
_s3 = 0x10325476
_s4 = 0xc3d2e1f0
)
// digest represents the partial evaluation of a checksum.
type digest struct {
s [5]uint32 // running context
x [BlockSize]byte // temporary buffer
nx int // index into x
tc uint64 // total count of bytes processed
}
func (d *digest) Reset() {
d.s[0], d.s[1], d.s[2], d.s[3], d.s[4] = _s0, _s1, _s2, _s3, _s4
d.nx = 0
d.tc = 0
}
// New returns a new hash.Hash computing the checksum.
func New() hash.Hash {
result := new(digest)
result.Reset()
return result
}
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
nn = len(p)
d.tc += uint64(nn)
if d.nx > 0 {
n := len(p)
if n > BlockSize-d.nx {
n = BlockSize - d.nx
}
for i := 0; i < n; i++ {
d.x[d.nx+i] = p[i]
}
d.nx += n
if d.nx == BlockSize {
_Block(d, d.x[0:])
d.nx = 0
}
p = p[n:]
}
n := _Block(d, p)
p = p[n:]
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := *d0
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
tc := d.tc
var tmp [64]byte
tmp[0] = 0x80
if tc%64 < 56 {
d.Write(tmp[0 : 56-tc%64])
} else {
d.Write(tmp[0 : 64+56-tc%64])
}
// Length in bits.
tc <<= 3
for i := uint(0); i < 8; i++ {
tmp[i] = byte(tc >> (8 * i))
}
d.Write(tmp[0:8])
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
for i, s := range d.s {
digest[i*4] = byte(s)
digest[i*4+1] = byte(s >> 8)
digest[i*4+2] = byte(s >> 16)
digest[i*4+3] = byte(s >> 24)
}
return append(in, digest[:]...)
}

View File

@@ -0,0 +1,64 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ripemd160
// Test vectors are from:
// http://homes.esat.kuleuven.be/~bosselae/ripemd160.html
import (
"fmt"
"io"
"testing"
)
type mdTest struct {
out string
in string
}
var vectors = [...]mdTest{
{"9c1185a5c5e9fc54612808977ee8f548b2258d31", ""},
{"0bdc9d2d256b3ee9daae347be6f4dc835a467ffe", "a"},
{"8eb208f7e05d987a9b044a8e98c6b087f15a0bfc", "abc"},
{"5d0689ef49d2fae572b881b123a85ffa21595f36", "message digest"},
{"f71c27109c692c1b56bbdceb5b9d2865b3708dbc", "abcdefghijklmnopqrstuvwxyz"},
{"12a053384a9c0c88e405a06c27dcf49ada62eb2b", "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"},
{"b0e20b6e3116640286ed3a87a5713079b21f5189", "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"},
{"9b752e45573d4b39f4dbd3323cab82bf63326bfb", "12345678901234567890123456789012345678901234567890123456789012345678901234567890"},
}
func TestVectors(t *testing.T) {
for i := 0; i < len(vectors); i++ {
tv := vectors[i]
md := New()
for j := 0; j < 3; j++ {
if j < 2 {
io.WriteString(md, tv.in)
} else {
io.WriteString(md, tv.in[0:len(tv.in)/2])
md.Sum(nil)
io.WriteString(md, tv.in[len(tv.in)/2:])
}
s := fmt.Sprintf("%x", md.Sum(nil))
if s != tv.out {
t.Fatalf("RIPEMD-160[%d](%s) = %s, expected %s", j, tv.in, s, tv.out)
}
md.Reset()
}
}
}
func TestMillionA(t *testing.T) {
md := New()
for i := 0; i < 100000; i++ {
io.WriteString(md, "aaaaaaaaaa")
}
out := "52783243c1697bdbe16d37f97f68f08325dc1528"
s := fmt.Sprintf("%x", md.Sum(nil))
if s != out {
t.Fatalf("RIPEMD-160 (1 million 'a') = %s, expected %s", s, out)
}
md.Reset()
}

View File

@@ -0,0 +1,161 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// RIPEMD-160 block step.
// In its own file so that a faster assembly or C version
// can be substituted easily.
package ripemd160
// work buffer indices and roll amounts for one line
var _n = [80]uint{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
7, 4, 13, 1, 10, 6, 15, 3, 12, 0, 9, 5, 2, 14, 11, 8,
3, 10, 14, 4, 9, 15, 8, 1, 2, 7, 0, 6, 13, 11, 5, 12,
1, 9, 11, 10, 0, 8, 12, 4, 13, 3, 7, 15, 14, 5, 6, 2,
4, 0, 5, 9, 7, 12, 2, 10, 14, 1, 3, 8, 11, 6, 15, 13,
}
var _r = [80]uint{
11, 14, 15, 12, 5, 8, 7, 9, 11, 13, 14, 15, 6, 7, 9, 8,
7, 6, 8, 13, 11, 9, 7, 15, 7, 12, 15, 9, 11, 7, 13, 12,
11, 13, 6, 7, 14, 9, 13, 15, 14, 8, 13, 6, 5, 12, 7, 5,
11, 12, 14, 15, 14, 15, 9, 8, 9, 14, 5, 6, 8, 6, 5, 12,
9, 15, 5, 11, 6, 8, 13, 12, 5, 12, 13, 14, 11, 8, 5, 6,
}
// same for the other parallel one
var n_ = [80]uint{
5, 14, 7, 0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12,
6, 11, 3, 7, 0, 13, 5, 10, 14, 15, 8, 12, 4, 9, 1, 2,
15, 5, 1, 3, 7, 14, 6, 9, 11, 8, 12, 2, 10, 0, 4, 13,
8, 6, 4, 1, 3, 11, 15, 0, 5, 12, 2, 13, 9, 7, 10, 14,
12, 15, 10, 4, 1, 5, 8, 7, 6, 2, 13, 14, 0, 3, 9, 11,
}
var r_ = [80]uint{
8, 9, 9, 11, 13, 15, 15, 5, 7, 7, 8, 11, 14, 14, 12, 6,
9, 13, 15, 7, 12, 8, 9, 11, 7, 7, 12, 7, 6, 15, 13, 11,
9, 7, 15, 11, 8, 6, 6, 14, 12, 13, 5, 14, 13, 13, 7, 5,
15, 5, 8, 11, 14, 14, 6, 14, 6, 9, 12, 9, 12, 5, 15, 8,
8, 5, 12, 9, 12, 5, 14, 6, 8, 13, 6, 5, 15, 13, 11, 11,
}
func _Block(md *digest, p []byte) int {
n := 0
var x [16]uint32
var alpha, beta uint32
for len(p) >= BlockSize {
a, b, c, d, e := md.s[0], md.s[1], md.s[2], md.s[3], md.s[4]
aa, bb, cc, dd, ee := a, b, c, d, e
j := 0
for i := 0; i < 16; i++ {
x[i] = uint32(p[j]) | uint32(p[j+1])<<8 | uint32(p[j+2])<<16 | uint32(p[j+3])<<24
j += 4
}
// round 1
i := 0
for i < 16 {
alpha = a + (b ^ c ^ d) + x[_n[i]]
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb ^ (cc | ^dd)) + x[n_[i]] + 0x50a28be6
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 2
for i < 32 {
alpha = a + (b&c | ^b&d) + x[_n[i]] + 0x5a827999
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb&dd | cc&^dd) + x[n_[i]] + 0x5c4dd124
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 3
for i < 48 {
alpha = a + (b | ^c ^ d) + x[_n[i]] + 0x6ed9eba1
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb | ^cc ^ dd) + x[n_[i]] + 0x6d703ef3
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 4
for i < 64 {
alpha = a + (b&d | c&^d) + x[_n[i]] + 0x8f1bbcdc
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb&cc | ^bb&dd) + x[n_[i]] + 0x7a6d76e9
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 5
for i < 80 {
alpha = a + (b ^ (c | ^d)) + x[_n[i]] + 0xa953fd4e
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb ^ cc ^ dd) + x[n_[i]]
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// combine results
dd += c + md.s[1]
md.s[1] = md.s[2] + d + ee
md.s[2] = md.s[3] + e + aa
md.s[3] = md.s[4] + a + bb
md.s[4] = md.s[0] + b + cc
md.s[0] = dd
p = p[BlockSize:]
n += BlockSize
}
return n
}

View File

@@ -0,0 +1,243 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package scrypt implements the scrypt key derivation function as defined in
// Colin Percival's paper "Stronger Key Derivation via Sequential Memory-Hard
// Functions" (http://www.tarsnap.com/scrypt/scrypt.pdf).
package scrypt
import (
"crypto/sha256"
"errors"
"golang.org/x/crypto/pbkdf2"
)
const maxInt = int(^uint(0) >> 1)
// blockCopy copies n numbers from src into dst.
func blockCopy(dst, src []uint32, n int) {
copy(dst, src[:n])
}
// blockXOR XORs numbers from dst with n numbers from src.
func blockXOR(dst, src []uint32, n int) {
for i, v := range src[:n] {
dst[i] ^= v
}
}
// salsaXOR applies Salsa20/8 to the XOR of 16 numbers from tmp and in,
// and puts the result into both both tmp and out.
func salsaXOR(tmp *[16]uint32, in, out []uint32) {
w0 := tmp[0] ^ in[0]
w1 := tmp[1] ^ in[1]
w2 := tmp[2] ^ in[2]
w3 := tmp[3] ^ in[3]
w4 := tmp[4] ^ in[4]
w5 := tmp[5] ^ in[5]
w6 := tmp[6] ^ in[6]
w7 := tmp[7] ^ in[7]
w8 := tmp[8] ^ in[8]
w9 := tmp[9] ^ in[9]
w10 := tmp[10] ^ in[10]
w11 := tmp[11] ^ in[11]
w12 := tmp[12] ^ in[12]
w13 := tmp[13] ^ in[13]
w14 := tmp[14] ^ in[14]
w15 := tmp[15] ^ in[15]
x0, x1, x2, x3, x4, x5, x6, x7, x8 := w0, w1, w2, w3, w4, w5, w6, w7, w8
x9, x10, x11, x12, x13, x14, x15 := w9, w10, w11, w12, w13, w14, w15
for i := 0; i < 8; i += 2 {
u := x0 + x12
x4 ^= u<<7 | u>>(32-7)
u = x4 + x0
x8 ^= u<<9 | u>>(32-9)
u = x8 + x4
x12 ^= u<<13 | u>>(32-13)
u = x12 + x8
x0 ^= u<<18 | u>>(32-18)
u = x5 + x1
x9 ^= u<<7 | u>>(32-7)
u = x9 + x5
x13 ^= u<<9 | u>>(32-9)
u = x13 + x9
x1 ^= u<<13 | u>>(32-13)
u = x1 + x13
x5 ^= u<<18 | u>>(32-18)
u = x10 + x6
x14 ^= u<<7 | u>>(32-7)
u = x14 + x10
x2 ^= u<<9 | u>>(32-9)
u = x2 + x14
x6 ^= u<<13 | u>>(32-13)
u = x6 + x2
x10 ^= u<<18 | u>>(32-18)
u = x15 + x11
x3 ^= u<<7 | u>>(32-7)
u = x3 + x15
x7 ^= u<<9 | u>>(32-9)
u = x7 + x3
x11 ^= u<<13 | u>>(32-13)
u = x11 + x7
x15 ^= u<<18 | u>>(32-18)
u = x0 + x3
x1 ^= u<<7 | u>>(32-7)
u = x1 + x0
x2 ^= u<<9 | u>>(32-9)
u = x2 + x1
x3 ^= u<<13 | u>>(32-13)
u = x3 + x2
x0 ^= u<<18 | u>>(32-18)
u = x5 + x4
x6 ^= u<<7 | u>>(32-7)
u = x6 + x5
x7 ^= u<<9 | u>>(32-9)
u = x7 + x6
x4 ^= u<<13 | u>>(32-13)
u = x4 + x7
x5 ^= u<<18 | u>>(32-18)
u = x10 + x9
x11 ^= u<<7 | u>>(32-7)
u = x11 + x10
x8 ^= u<<9 | u>>(32-9)
u = x8 + x11
x9 ^= u<<13 | u>>(32-13)
u = x9 + x8
x10 ^= u<<18 | u>>(32-18)
u = x15 + x14
x12 ^= u<<7 | u>>(32-7)
u = x12 + x15
x13 ^= u<<9 | u>>(32-9)
u = x13 + x12
x14 ^= u<<13 | u>>(32-13)
u = x14 + x13
x15 ^= u<<18 | u>>(32-18)
}
x0 += w0
x1 += w1
x2 += w2
x3 += w3
x4 += w4
x5 += w5
x6 += w6
x7 += w7
x8 += w8
x9 += w9
x10 += w10
x11 += w11
x12 += w12
x13 += w13
x14 += w14
x15 += w15
out[0], tmp[0] = x0, x0
out[1], tmp[1] = x1, x1
out[2], tmp[2] = x2, x2
out[3], tmp[3] = x3, x3
out[4], tmp[4] = x4, x4
out[5], tmp[5] = x5, x5
out[6], tmp[6] = x6, x6
out[7], tmp[7] = x7, x7
out[8], tmp[8] = x8, x8
out[9], tmp[9] = x9, x9
out[10], tmp[10] = x10, x10
out[11], tmp[11] = x11, x11
out[12], tmp[12] = x12, x12
out[13], tmp[13] = x13, x13
out[14], tmp[14] = x14, x14
out[15], tmp[15] = x15, x15
}
func blockMix(tmp *[16]uint32, in, out []uint32, r int) {
blockCopy(tmp[:], in[(2*r-1)*16:], 16)
for i := 0; i < 2*r; i += 2 {
salsaXOR(tmp, in[i*16:], out[i*8:])
salsaXOR(tmp, in[i*16+16:], out[i*8+r*16:])
}
}
func integer(b []uint32, r int) uint64 {
j := (2*r - 1) * 16
return uint64(b[j]) | uint64(b[j+1])<<32
}
func smix(b []byte, r, N int, v, xy []uint32) {
var tmp [16]uint32
x := xy
y := xy[32*r:]
j := 0
for i := 0; i < 32*r; i++ {
x[i] = uint32(b[j]) | uint32(b[j+1])<<8 | uint32(b[j+2])<<16 | uint32(b[j+3])<<24
j += 4
}
for i := 0; i < N; i += 2 {
blockCopy(v[i*(32*r):], x, 32*r)
blockMix(&tmp, x, y, r)
blockCopy(v[(i+1)*(32*r):], y, 32*r)
blockMix(&tmp, y, x, r)
}
for i := 0; i < N; i += 2 {
j := int(integer(x, r) & uint64(N-1))
blockXOR(x, v[j*(32*r):], 32*r)
blockMix(&tmp, x, y, r)
j = int(integer(y, r) & uint64(N-1))
blockXOR(y, v[j*(32*r):], 32*r)
blockMix(&tmp, y, x, r)
}
j = 0
for _, v := range x[:32*r] {
b[j+0] = byte(v >> 0)
b[j+1] = byte(v >> 8)
b[j+2] = byte(v >> 16)
b[j+3] = byte(v >> 24)
j += 4
}
}
// Key derives a key from the password, salt, and cost parameters, returning
// a byte slice of length keyLen that can be used as cryptographic key.
//
// N is a CPU/memory cost parameter, which must be a power of two greater than 1.
// r and p must satisfy r * p < 2³⁰. If the parameters do not satisfy the
// limits, the function returns a nil byte slice and an error.
//
// For example, you can get a derived key for e.g. AES-256 (which needs a
// 32-byte key) by doing:
//
// dk := scrypt.Key([]byte("some password"), salt, 16384, 8, 1, 32)
//
// The recommended parameters for interactive logins as of 2009 are N=16384,
// r=8, p=1. They should be increased as memory latency and CPU parallelism
// increases. Remember to get a good random salt.
func Key(password, salt []byte, N, r, p, keyLen int) ([]byte, error) {
if N <= 1 || N&(N-1) != 0 {
return nil, errors.New("scrypt: N must be > 1 and a power of 2")
}
if uint64(r)*uint64(p) >= 1<<30 || r > maxInt/128/p || r > maxInt/256 || N > maxInt/128/r {
return nil, errors.New("scrypt: parameters are too large")
}
xy := make([]uint32, 64*r)
v := make([]uint32, 32*N*r)
b := pbkdf2.Key(password, salt, 1, p*128*r, sha256.New)
for i := 0; i < p; i++ {
smix(b[i*128*r:], r, N, v, xy)
}
return pbkdf2.Key(password, b, 1, keyLen, sha256.New), nil
}

View File

@@ -0,0 +1,160 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package scrypt
import (
"bytes"
"testing"
)
type testVector struct {
password string
salt string
N, r, p int
output []byte
}
var good = []testVector{
{
"password",
"salt",
2, 10, 10,
[]byte{
0x48, 0x2c, 0x85, 0x8e, 0x22, 0x90, 0x55, 0xe6, 0x2f,
0x41, 0xe0, 0xec, 0x81, 0x9a, 0x5e, 0xe1, 0x8b, 0xdb,
0x87, 0x25, 0x1a, 0x53, 0x4f, 0x75, 0xac, 0xd9, 0x5a,
0xc5, 0xe5, 0xa, 0xa1, 0x5f,
},
},
{
"password",
"salt",
16, 100, 100,
[]byte{
0x88, 0xbd, 0x5e, 0xdb, 0x52, 0xd1, 0xdd, 0x0, 0x18,
0x87, 0x72, 0xad, 0x36, 0x17, 0x12, 0x90, 0x22, 0x4e,
0x74, 0x82, 0x95, 0x25, 0xb1, 0x8d, 0x73, 0x23, 0xa5,
0x7f, 0x91, 0x96, 0x3c, 0x37,
},
},
{
"this is a long \000 password",
"and this is a long \000 salt",
16384, 8, 1,
[]byte{
0xc3, 0xf1, 0x82, 0xee, 0x2d, 0xec, 0x84, 0x6e, 0x70,
0xa6, 0x94, 0x2f, 0xb5, 0x29, 0x98, 0x5a, 0x3a, 0x09,
0x76, 0x5e, 0xf0, 0x4c, 0x61, 0x29, 0x23, 0xb1, 0x7f,
0x18, 0x55, 0x5a, 0x37, 0x07, 0x6d, 0xeb, 0x2b, 0x98,
0x30, 0xd6, 0x9d, 0xe5, 0x49, 0x26, 0x51, 0xe4, 0x50,
0x6a, 0xe5, 0x77, 0x6d, 0x96, 0xd4, 0x0f, 0x67, 0xaa,
0xee, 0x37, 0xe1, 0x77, 0x7b, 0x8a, 0xd5, 0xc3, 0x11,
0x14, 0x32, 0xbb, 0x3b, 0x6f, 0x7e, 0x12, 0x64, 0x40,
0x18, 0x79, 0xe6, 0x41, 0xae,
},
},
{
"p",
"s",
2, 1, 1,
[]byte{
0x48, 0xb0, 0xd2, 0xa8, 0xa3, 0x27, 0x26, 0x11, 0x98,
0x4c, 0x50, 0xeb, 0xd6, 0x30, 0xaf, 0x52,
},
},
{
"",
"",
16, 1, 1,
[]byte{
0x77, 0xd6, 0x57, 0x62, 0x38, 0x65, 0x7b, 0x20, 0x3b,
0x19, 0xca, 0x42, 0xc1, 0x8a, 0x04, 0x97, 0xf1, 0x6b,
0x48, 0x44, 0xe3, 0x07, 0x4a, 0xe8, 0xdf, 0xdf, 0xfa,
0x3f, 0xed, 0xe2, 0x14, 0x42, 0xfc, 0xd0, 0x06, 0x9d,
0xed, 0x09, 0x48, 0xf8, 0x32, 0x6a, 0x75, 0x3a, 0x0f,
0xc8, 0x1f, 0x17, 0xe8, 0xd3, 0xe0, 0xfb, 0x2e, 0x0d,
0x36, 0x28, 0xcf, 0x35, 0xe2, 0x0c, 0x38, 0xd1, 0x89,
0x06,
},
},
{
"password",
"NaCl",
1024, 8, 16,
[]byte{
0xfd, 0xba, 0xbe, 0x1c, 0x9d, 0x34, 0x72, 0x00, 0x78,
0x56, 0xe7, 0x19, 0x0d, 0x01, 0xe9, 0xfe, 0x7c, 0x6a,
0xd7, 0xcb, 0xc8, 0x23, 0x78, 0x30, 0xe7, 0x73, 0x76,
0x63, 0x4b, 0x37, 0x31, 0x62, 0x2e, 0xaf, 0x30, 0xd9,
0x2e, 0x22, 0xa3, 0x88, 0x6f, 0xf1, 0x09, 0x27, 0x9d,
0x98, 0x30, 0xda, 0xc7, 0x27, 0xaf, 0xb9, 0x4a, 0x83,
0xee, 0x6d, 0x83, 0x60, 0xcb, 0xdf, 0xa2, 0xcc, 0x06,
0x40,
},
},
{
"pleaseletmein", "SodiumChloride",
16384, 8, 1,
[]byte{
0x70, 0x23, 0xbd, 0xcb, 0x3a, 0xfd, 0x73, 0x48, 0x46,
0x1c, 0x06, 0xcd, 0x81, 0xfd, 0x38, 0xeb, 0xfd, 0xa8,
0xfb, 0xba, 0x90, 0x4f, 0x8e, 0x3e, 0xa9, 0xb5, 0x43,
0xf6, 0x54, 0x5d, 0xa1, 0xf2, 0xd5, 0x43, 0x29, 0x55,
0x61, 0x3f, 0x0f, 0xcf, 0x62, 0xd4, 0x97, 0x05, 0x24,
0x2a, 0x9a, 0xf9, 0xe6, 0x1e, 0x85, 0xdc, 0x0d, 0x65,
0x1e, 0x40, 0xdf, 0xcf, 0x01, 0x7b, 0x45, 0x57, 0x58,
0x87,
},
},
/*
// Disabled: needs 1 GiB RAM and takes too long for a simple test.
{
"pleaseletmein", "SodiumChloride",
1048576, 8, 1,
[]byte{
0x21, 0x01, 0xcb, 0x9b, 0x6a, 0x51, 0x1a, 0xae, 0xad,
0xdb, 0xbe, 0x09, 0xcf, 0x70, 0xf8, 0x81, 0xec, 0x56,
0x8d, 0x57, 0x4a, 0x2f, 0xfd, 0x4d, 0xab, 0xe5, 0xee,
0x98, 0x20, 0xad, 0xaa, 0x47, 0x8e, 0x56, 0xfd, 0x8f,
0x4b, 0xa5, 0xd0, 0x9f, 0xfa, 0x1c, 0x6d, 0x92, 0x7c,
0x40, 0xf4, 0xc3, 0x37, 0x30, 0x40, 0x49, 0xe8, 0xa9,
0x52, 0xfb, 0xcb, 0xf4, 0x5c, 0x6f, 0xa7, 0x7a, 0x41,
0xa4,
},
},
*/
}
var bad = []testVector{
{"p", "s", 0, 1, 1, nil}, // N == 0
{"p", "s", 1, 1, 1, nil}, // N == 1
{"p", "s", 7, 8, 1, nil}, // N is not power of 2
{"p", "s", 16, maxInt / 2, maxInt / 2, nil}, // p * r too large
}
func TestKey(t *testing.T) {
for i, v := range good {
k, err := Key([]byte(v.password), []byte(v.salt), v.N, v.r, v.p, len(v.output))
if err != nil {
t.Errorf("%d: got unexpected error: %s", i, err)
}
if !bytes.Equal(k, v.output) {
t.Errorf("%d: expected %x, got %x", i, v.output, k)
}
}
for i, v := range bad {
_, err := Key([]byte(v.password), []byte(v.salt), v.N, v.r, v.p, 32)
if err == nil {
t.Errorf("%d: expected error, got nil", i)
}
}
}
func BenchmarkKey(b *testing.B) {
for i := 0; i < b.N; i++ {
Key([]byte("password"), []byte("salt"), 16384, 8, 1, 64)
}
}

View File

@@ -0,0 +1,98 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
)
// DialError is an error that occurs while dialling a websocket server.
type DialError struct {
*Config
Err error
}
func (e *DialError) Error() string {
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
}
// NewConfig creates a new WebSocket config for client connection.
func NewConfig(server, origin string) (config *Config, err error) {
config = new(Config)
config.Version = ProtocolVersionHybi13
config.Location, err = url.ParseRequestURI(server)
if err != nil {
return
}
config.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
// NewClient creates a new WebSocket client connection over rwc.
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
err = hybiClientHandshake(config, br, bw)
if err != nil {
return
}
buf := bufio.NewReadWriter(br, bw)
ws = newHybiClientConn(config, buf, rwc)
return
}
// Dial opens a new client connection to a WebSocket.
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
config, err := NewConfig(url_, origin)
if err != nil {
return nil, err
}
if protocol != "" {
config.Protocol = []string{protocol}
}
return DialConfig(config)
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
switch config.Location.Scheme {
case "ws":
client, err = net.Dial("tcp", config.Location.Host)
case "wss":
client, err = tls.Dial("tcp", config.Location.Host, config.TlsConfig)
default:
err = ErrBadScheme
}
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
goto Error
}
return
Error:
return nil, &DialError{config, err}
}

View File

@@ -0,0 +1,31 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"fmt"
"log"
"code.google.com/p/go.net/websocket"
)
// This example demonstrates a trivial client.
func ExampleDial() {
origin := "http://localhost/"
url := "ws://localhost:12345/ws"
ws, err := websocket.Dial(url, "", origin)
if err != nil {
log.Fatal(err)
}
if _, err := ws.Write([]byte("hello, world!\n")); err != nil {
log.Fatal(err)
}
var msg = make([]byte, 512)
var n int
if n, err = ws.Read(msg); err != nil {
log.Fatal(err)
}
fmt.Printf("Received: %s.\n", msg[:n])
}

View File

@@ -0,0 +1,26 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"io"
"net/http"
"code.google.com/p/go.net/websocket"
)
// Echo the data received on the WebSocket.
func EchoServer(ws *websocket.Conn) {
io.Copy(ws, ws)
}
// This example demonstrates a trivial echo server.
func ExampleHandler() {
http.Handle("/echo", websocket.Handler(EchoServer))
err := http.ListenAndServe(":12345", nil)
if err != nil {
panic("ListenAndServe: " + err.Error())
}
}

View File

@@ -0,0 +1,564 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
// This file implements a protocol of hybi draft.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
const (
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
closeStatusNormal = 1000
closeStatusGoingAway = 1001
closeStatusProtocolError = 1002
closeStatusUnsupportedData = 1003
closeStatusFrameTooLarge = 1004
closeStatusNoStatusRcvd = 1005
closeStatusAbnormalClosure = 1006
closeStatusBadMessageData = 1007
closeStatusPolicyViolation = 1008
closeStatusTooBigData = 1009
closeStatusExtensionMismatch = 1010
maxControlFramePayloadLength = 125
)
var (
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
ErrBadPongMessage = &ProtocolError{"bad pong message"}
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
handshakeHeader = map[string]bool{
"Host": true,
"Upgrade": true,
"Connection": true,
"Sec-Websocket-Key": true,
"Sec-Websocket-Origin": true,
"Sec-Websocket-Version": true,
"Sec-Websocket-Protocol": true,
"Sec-Websocket-Accept": true,
}
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
type hybiFrameHeader struct {
Fin bool
Rsv [3]bool
OpCode byte
Length int64
MaskingKey []byte
data *bytes.Buffer
}
// A hybiFrameReader is a reader for hybi frame.
type hybiFrameReader struct {
reader io.Reader
header hybiFrameHeader
pos int64
length int
}
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
n, err = frame.reader.Read(msg)
if err != nil {
return 0, err
}
if frame.header.MaskingKey != nil {
for i := 0; i < n; i++ {
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
frame.pos++
}
}
return n, err
}
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
func (frame *hybiFrameReader) HeaderReader() io.Reader {
if frame.header.data == nil {
return nil
}
if frame.header.data.Len() == 0 {
return nil
}
return frame.header.data
}
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
type hybiFrameReaderFactory struct {
*bufio.Reader
}
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
// See Section 5.2 Base Framing protocol for detail.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
hybiFrame := new(hybiFrameReader)
frame = hybiFrame
var header []byte
var b byte
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
for i := 0; i < 3; i++ {
j := uint(6 - i)
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
}
hybiFrame.header.OpCode = header[0] & 0x0f
// Second byte. Mask/Payload len(7bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
mask := (b & 0x80) != 0
b &= 0x7f
lengthFields := 0
switch {
case b <= 125: // Payload length 7bits.
hybiFrame.header.Length = int64(b)
case b == 126: // Payload length 7+16bits
lengthFields = 2
case b == 127: // Payload length 7+64bits
lengthFields = 8
}
for i := 0; i < lengthFields; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
}
if mask {
// Masking key. 4 bytes.
for i := 0; i < 4; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
}
}
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
hybiFrame.header.data = bytes.NewBuffer(header)
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
return
}
// A HybiFrameWriter is a writer for hybi frame.
type hybiFrameWriter struct {
writer *bufio.Writer
header *hybiFrameHeader
}
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
var header []byte
var b byte
if frame.header.Fin {
b |= 0x80
}
for i := 0; i < 3; i++ {
if frame.header.Rsv[i] {
j := uint(6 - i)
b |= 1 << j
}
}
b |= frame.header.OpCode
header = append(header, b)
if frame.header.MaskingKey != nil {
b = 0x80
} else {
b = 0
}
lengthFields := 0
length := len(msg)
switch {
case length <= 125:
b |= byte(length)
case length < 65536:
b |= 126
lengthFields = 2
default:
b |= 127
lengthFields = 8
}
header = append(header, b)
for i := 0; i < lengthFields; i++ {
j := uint((lengthFields - i - 1) * 8)
b = byte((length >> j) & 0xff)
header = append(header, b)
}
if frame.header.MaskingKey != nil {
if len(frame.header.MaskingKey) != 4 {
return 0, ErrBadMaskingKey
}
header = append(header, frame.header.MaskingKey...)
frame.writer.Write(header)
data := make([]byte, length)
for i := range data {
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
}
frame.writer.Write(data)
err = frame.writer.Flush()
return length, err
}
frame.writer.Write(header)
frame.writer.Write(msg)
err = frame.writer.Flush()
return length, err
}
func (frame *hybiFrameWriter) Close() error { return nil }
type hybiFrameWriterFactory struct {
*bufio.Writer
needMaskingKey bool
}
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
if buf.needMaskingKey {
frameHeader.MaskingKey, err = generateMaskingKey()
if err != nil {
return nil, err
}
}
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
}
type hybiFrameHandler struct {
conn *Conn
payloadType byte
}
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) {
if handler.conn.IsServerConn() {
// The client MUST mask all frames sent to the server.
if frame.(*hybiFrameReader).header.MaskingKey == nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
} else {
// The server MUST NOT mask all frames.
if frame.(*hybiFrameReader).header.MaskingKey != nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
}
if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
case TextFrame, BinaryFrame:
handler.payloadType = frame.PayloadType()
case CloseFrame:
return nil, io.EOF
case PingFrame:
pingMsg := make([]byte, maxControlFramePayloadLength)
n, err := io.ReadFull(frame, pingMsg)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
n, err = handler.WritePong(pingMsg[:n])
if err != nil {
return nil, err
}
return nil, nil
case PongFrame:
return nil, ErrNotImplemented
}
return frame, nil
}
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
if err != nil {
return err
}
msg := make([]byte, 2)
binary.BigEndian.PutUint16(msg, uint16(status))
_, err = w.Write(msg)
w.Close()
return err
}
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
if buf == nil {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
frameWriterFactory: hybiFrameWriterFactory{
buf.Writer, request == nil},
PayloadType: TextFrame,
defaultCloseStatus: closeStatusNormal}
ws.frameHandler = &hybiFrameHandler{conn: ws}
return ws
}
// generateMaskingKey generates a masking key for a frame.
func generateMaskingKey() (maskingKey []byte, err error) {
maskingKey = make([]byte, 4)
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
return
}
return
}
// generateNonce generates a nonce consisting of a randomly selected 16-byte
// value that has been base64-encoded.
func generateNonce() (nonce []byte) {
key := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
}
nonce = make([]byte, 24)
base64.StdEncoding.Encode(nonce, key)
return
}
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
func getNonceAccept(nonce []byte) (expected []byte, err error) {
h := sha1.New()
if _, err = h.Write(nonce); err != nil {
return
}
if _, err = h.Write([]byte(websocketGUID)); err != nil {
return
}
expected = make([]byte, 28)
base64.StdEncoding.Encode(expected, h.Sum(nil))
return
}
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
bw.WriteString("Host: " + config.Location.Host + "\r\n")
bw.WriteString("Upgrade: websocket\r\n")
bw.WriteString("Connection: Upgrade\r\n")
nonce := generateNonce()
if config.handshakeData != nil {
nonce = []byte(config.handshakeData["key"])
}
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
if config.Version != ProtocolVersionHybi13 {
return ErrBadProtocolVersion
}
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
err = config.Header.WriteSubset(bw, handshakeHeader)
if err != nil {
return err
}
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
return err
}
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
if err != nil {
return err
}
if resp.StatusCode != 101 {
return ErrBadStatus
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade
}
expectedAccept, err := getNonceAccept(nonce)
if err != nil {
return err
}
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
return ErrChallengeResponse
}
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
return ErrUnsupportedExtensions
}
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
if offeredProtocol != "" {
protocolMatched := false
for i := 0; i < len(config.Protocol); i++ {
if config.Protocol[i] == offeredProtocol {
protocolMatched = true
break
}
}
if !protocolMatched {
return ErrBadWebSocketProtocol
}
config.Protocol = []string{offeredProtocol}
}
return nil
}
// newHybiClientConn creates a client WebSocket connection after handshake.
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
return newHybiConn(config, buf, rwc, nil)
}
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
type hybiServerHandshaker struct {
*Config
accept []byte
}
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
c.Version = ProtocolVersionHybi13
if req.Method != "GET" {
return http.StatusMethodNotAllowed, ErrBadRequestMethod
}
// HTTP version can be safely ignored.
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
return http.StatusBadRequest, ErrNotWebSocket
}
key := req.Header.Get("Sec-Websocket-Key")
if key == "" {
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
switch version {
case "13":
c.Version = ProtocolVersionHybi13
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
var scheme string
if req.TLS != nil {
scheme = "wss"
} else {
scheme = "ws"
}
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
if err != nil {
return http.StatusBadRequest, err
}
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if protocol != "" {
protocols := strings.Split(protocol, ",")
for i := 0; i < len(protocols); i++ {
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
}
}
c.accept, err = getNonceAccept([]byte(key))
if err != nil {
return http.StatusInternalServerError, err
}
return http.StatusSwitchingProtocols, nil
}
// Origin parses Origin header in "req".
// If origin is "null", returns (nil, nil).
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
}
if origin == "null" {
return nil, nil
}
return url.ParseRequestURI(origin)
}
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
// You need choose a Protocol in Handshake func in Server.
return ErrBadWebSocketProtocol
}
}
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
if c.Header != nil {
err := c.Header.WriteSubset(buf, handshakeHeader)
if err != nil {
return err
}
}
buf.WriteString("\r\n")
return buf.Flush()
}
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiServerConn(c.Config, buf, rwc, request)
}
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiConn(config, buf, rwc, request)
}

View File

@@ -0,0 +1,590 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
)
// Test the getNonceAccept function with values in
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
func TestSecWebSocketAccept(t *testing.T) {
nonce := []byte("dGhlIHNhbXBsZSBub25jZQ==")
expected := []byte("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")
accept, err := getNonceAccept(nonce)
if err != nil {
t.Errorf("getNonceAccept: returned error %v", err)
return
}
if !bytes.Equal(expected, accept) {
t.Errorf("getNonceAccept: expected %q got %q", expected, accept)
}
}
func TestHybiClientHandshake(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiClientHandshakeWithHeader(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.Header = http.Header(make(map[string][]string))
config.Header.Add("User-Agent", "test")
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
"User-Agent": "test",
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiServerHandshake(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
expectedProtocols := []string{"chat", "superchat"}
if fmt.Sprintf("%v", config.Protocol) != fmt.Sprintf("%v", expectedProtocols) {
t.Errorf("protocol expected %q but got %q", expectedProtocols, config.Protocol)
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
config.Protocol = config.Protocol[:1]
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"Sec-WebSocket-Protocol: chat",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}
func TestHybiServerHandshakeNoSubProtocol(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
if len(config.Protocol) != 0 {
t.Errorf("len(config.Protocol) expected 0, but got %q", len(config.Protocol))
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}
func TestHybiServerHandshakeHybiBadVersion(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 9
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != ErrBadWebSocketVersion {
t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err)
}
if code != http.StatusBadRequest {
t.Errorf("status expected %q but got %q", http.StatusBadRequest, code)
}
}
func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []byte, frameHeader *hybiFrameHeader) {
b := bytes.NewBuffer([]byte{})
frameWriterFactory := &hybiFrameWriterFactory{bufio.NewWriter(b), false}
w, _ := frameWriterFactory.NewFrameWriter(TextFrame)
w.(*hybiFrameWriter).header = frameHeader
_, err := w.Write(testPayload)
w.Close()
if err != nil {
t.Errorf("Write error %q", err)
}
var expectedFrame []byte
expectedFrame = append(expectedFrame, testHeader...)
expectedFrame = append(expectedFrame, testMaskedPayload...)
if !bytes.Equal(expectedFrame, b.Bytes()) {
t.Errorf("frame expected %q got %q", expectedFrame, b.Bytes())
}
frameReaderFactory := &hybiFrameReaderFactory{bufio.NewReader(b)}
r, err := frameReaderFactory.NewFrameReader()
if err != nil {
t.Errorf("Read error %q", err)
}
if header := r.HeaderReader(); header == nil {
t.Errorf("no header")
} else {
actualHeader := make([]byte, r.Len())
n, err := header.Read(actualHeader)
if err != nil {
t.Errorf("Read header error %q", err)
} else {
if n < len(testHeader) {
t.Errorf("header too short %q got %q", testHeader, actualHeader[:n])
}
if !bytes.Equal(testHeader, actualHeader[:n]) {
t.Errorf("header expected %q got %q", testHeader, actualHeader[:n])
}
}
}
if trailer := r.TrailerReader(); trailer != nil {
t.Errorf("unexpected trailer %q", trailer)
}
frame := r.(*hybiFrameReader)
if frameHeader.Fin != frame.header.Fin ||
frameHeader.OpCode != frame.header.OpCode ||
len(testPayload) != int(frame.header.Length) {
t.Errorf("mismatch %v (%d) vs %v", frameHeader, len(testPayload), frame)
}
payload := make([]byte, len(testPayload))
_, err = r.Read(payload)
if err != nil {
t.Errorf("read %v", err)
}
if !bytes.Equal(testPayload, payload) {
t.Errorf("payload %q vs %q", testPayload, payload)
}
}
func TestHybiShortTextFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x81, 0x05}, payload, payload, frameHeader)
payload = make([]byte, 125)
testHybiFrame(t, []byte{0x81, 125}, payload, payload, frameHeader)
}
func TestHybiShortMaskedTextFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame,
MaskingKey: []byte{0xcc, 0x55, 0x80, 0x20}}
payload := []byte("hello")
maskedPayload := []byte{0xa4, 0x30, 0xec, 0x4c, 0xa3}
header := []byte{0x81, 0x85}
header = append(header, frameHeader.MaskingKey...)
testHybiFrame(t, header, payload, maskedPayload, frameHeader)
}
func TestHybiShortBinaryFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: BinaryFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x82, 0x05}, payload, payload, frameHeader)
payload = make([]byte, 125)
testHybiFrame(t, []byte{0x82, 125}, payload, payload, frameHeader)
}
func TestHybiControlFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader)
frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader)
frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame}
payload = []byte{0x03, 0xe8} // 1000
testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader)
}
func TestHybiLongFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame}
payload := make([]byte, 126)
testHybiFrame(t, []byte{0x81, 126, 0x00, 126}, payload, payload, frameHeader)
payload = make([]byte, 65535)
testHybiFrame(t, []byte{0x81, 126, 0xff, 0xff}, payload, payload, frameHeader)
payload = make([]byte, 65536)
testHybiFrame(t, []byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, payload, payload, frameHeader)
}
func TestHybiClientRead(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o',
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
msg := make([]byte, 512)
n, err := conn.Read(msg)
if err != nil {
t.Errorf("read 1st frame, error %q", err)
}
if n != 5 {
t.Errorf("read 1st frame, expect 5, got %d", n)
}
if !bytes.Equal(wireData[2:7], msg[:n]) {
t.Errorf("read 1st frame %v, got %v", wireData[2:7], msg[:n])
}
n, err = conn.Read(msg)
if err != nil {
t.Errorf("read 2nd frame, error %q", err)
}
if n != 5 {
t.Errorf("read 2nd frame, expect 5, got %d", n)
}
if !bytes.Equal(wireData[16:21], msg[:n]) {
t.Errorf("read 2nd frame %v, got %v", wireData[16:21], msg[:n])
}
n, err = conn.Read(msg)
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
}
func TestHybiShortRead(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o',
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
step := 0
pos := 0
expectedPos := []int{2, 5, 16, 19}
expectedLen := []int{3, 2, 3, 2}
for {
msg := make([]byte, 3)
n, err := conn.Read(msg)
if step >= len(expectedPos) {
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
return
}
pos = expectedPos[step]
endPos := pos + expectedLen[step]
if err != nil {
t.Errorf("read from %d, got error %q", pos, err)
return
}
if n != endPos-pos {
t.Errorf("read from %d, expect %d, got %d", pos, endPos-pos, n)
}
if !bytes.Equal(wireData[pos:endPos], msg[:n]) {
t.Errorf("read from %d, frame %v, got %v", pos, wireData[pos:endPos], msg[:n])
}
step++
}
}
func TestHybiServerRead(t *testing.T) {
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
0x89, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // ping: hello
0x81, 0x85, 0xed, 0x83, 0xb4, 0x24,
0x9a, 0xec, 0xc6, 0x48, 0x89, // world
}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request))
expected := [][]byte{[]byte("hello"), []byte("world")}
msg := make([]byte, 512)
n, err := conn.Read(msg)
if err != nil {
t.Errorf("read 1st frame, error %q", err)
}
if n != 5 {
t.Errorf("read 1st frame, expect 5, got %d", n)
}
if !bytes.Equal(expected[0], msg[:n]) {
t.Errorf("read 1st frame %q, got %q", expected[0], msg[:n])
}
n, err = conn.Read(msg)
if err != nil {
t.Errorf("read 2nd frame, error %q", err)
}
if n != 5 {
t.Errorf("read 2nd frame, expect 5, got %d", n)
}
if !bytes.Equal(expected[1], msg[:n]) {
t.Errorf("read 2nd frame %q, got %q", expected[1], msg[:n])
}
n, err = conn.Read(msg)
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
}
func TestHybiServerReadWithoutMasking(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request))
// server MUST close the connection upon receiving a non-masked frame.
msg := make([]byte, 512)
_, err := conn.Read(msg)
if err != io.EOF {
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err)
}
}
func TestHybiClientReadWithMasking(t *testing.T) {
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
// client MUST close the connection upon receiving a masked frame.
msg := make([]byte, 512)
_, err := conn.Read(msg)
if err != io.EOF {
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err)
}
}
// Test the hybiServerHandshaker supports firefox implementation and
// checks Connection request header include (but it's not necessary
// equal to) "upgrade"
func TestHybiServerFirefoxHandshake(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: keep-alive, upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
config.Protocol = []string{"chat"}
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"Sec-WebSocket-Protocol: chat",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}

View File

@@ -0,0 +1,114 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"fmt"
"io"
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if err != nil {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
conn = hs.NewServerConn(buf, rwc, req)
return
}
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
return
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser client, which doesn't send Origin header, you could use Server
//. that doesn't check origin in its Handshake.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}

View File

@@ -0,0 +1,411 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements a client and server for the WebSocket protocol
// as specified in RFC 6455.
package websocket
import (
"bufio"
"crypto/tls"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
const (
ProtocolVersionHybi13 = 13
ProtocolVersionHybi = ProtocolVersionHybi13
SupportedProtocolVersion = "13"
ContinuationFrame = 0
TextFrame = 1
BinaryFrame = 2
CloseFrame = 8
PingFrame = 9
PongFrame = 10
UnknownFrame = 255
)
// ProtocolError represents WebSocket protocol errors.
type ProtocolError struct {
ErrorString string
}
func (err *ProtocolError) Error() string { return err.ErrorString }
var (
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
ErrBadScheme = &ProtocolError{"bad scheme"}
ErrBadStatus = &ProtocolError{"bad status"}
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
ErrBadFrame = &ProtocolError{"bad frame"}
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
ErrBadRequestMethod = &ProtocolError{"bad method"}
ErrNotSupported = &ProtocolError{"not supported"}
)
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
}
// Network returns the network type for a WebSocket, "websocket".
func (addr *Addr) Network() string { return "websocket" }
// Config is a WebSocket configuration
type Config struct {
// A WebSocket server address.
Location *url.URL
// A Websocket client origin.
Origin *url.URL
// WebSocket subprotocols.
Protocol []string
// WebSocket protocol version.
Version int
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
handshakeData map[string]string
}
// serverHandshaker is an interface to handle WebSocket server side handshake.
type serverHandshaker interface {
// ReadHandshake reads handshake request message from client.
// Returns http response code and error if any.
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
// AcceptHandshake accepts the client handshake request and sends
// handshake response back to client.
AcceptHandshake(buf *bufio.Writer) (err error)
// NewServerConn creates a new WebSocket connection.
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
// frameReader is an interface to read a WebSocket frame.
type frameReader interface {
// Reader is to read payload of the frame.
io.Reader
// PayloadType returns payload type.
PayloadType() byte
// HeaderReader returns a reader to read header of the frame.
HeaderReader() io.Reader
// TrailerReader returns a reader to read trailer of the frame.
// If it returns nil, there is no trailer in the frame.
TrailerReader() io.Reader
// Len returns total length of the frame, including header and trailer.
Len() int
}
// frameReaderFactory is an interface to creates new frame reader.
type frameReaderFactory interface {
NewFrameReader() (r frameReader, err error)
}
// frameWriter is an interface to write a WebSocket frame.
type frameWriter interface {
// Writer is to write payload of the frame.
io.WriteCloser
}
// frameWriterFactory is an interface to create new frame writer.
type frameWriterFactory interface {
NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
type frameHandler interface {
HandleFrame(frame frameReader) (r frameReader, err error)
WriteClose(status int) (err error)
}
// Conn represents a WebSocket connection.
type Conn struct {
config *Config
request *http.Request
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
rio sync.Mutex
frameReaderFactory
frameReader
wio sync.Mutex
frameWriterFactory
frameHandler
PayloadType byte
defaultCloseStatus int
}
// Read implements the io.Reader interface:
// it reads data of a frame from the WebSocket connection.
// if msg is not large enough for the frame data, it fills the msg and next Read
// will read the rest of the frame data.
// it reads Text frame or Binary frame.
func (ws *Conn) Read(msg []byte) (n int, err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
again:
if ws.frameReader == nil {
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return 0, err
}
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return 0, err
}
if ws.frameReader == nil {
goto again
}
}
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer)
}
ws.frameReader = nil
goto again
}
return n, err
}
// Write implements the io.Writer interface:
// it writes data as a frame to the WebSocket connection.
func (ws *Conn) Write(msg []byte) (n int, err error) {
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
if err != nil {
return n, err
}
return n, err
}
// Close implements the io.Closer interface.
func (ws *Conn) Close() error {
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
if err != nil {
return err
}
return ws.rwc.Close()
}
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
// LocalAddr returns the WebSocket Origin for the connection for client, or
// the WebSocket location for server.
func (ws *Conn) LocalAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Origin}
}
return &Addr{ws.config.Location}
}
// RemoteAddr returns the WebSocket location for the connection for client, or
// the Websocket Origin for server.
func (ws *Conn) RemoteAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Location}
}
return &Addr{ws.config.Origin}
}
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
// SetDeadline sets the connection's network read & write deadlines.
func (ws *Conn) SetDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetDeadline(t)
}
return errSetDeadline
}
// SetReadDeadline sets the connection's network read deadline.
func (ws *Conn) SetReadDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetReadDeadline(t)
}
return errSetDeadline
}
// SetWriteDeadline sets the connection's network write deadline.
func (ws *Conn) SetWriteDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetWriteDeadline(t)
}
return errSetDeadline
}
// Config returns the WebSocket config.
func (ws *Conn) Config() *Config { return ws.config }
// Request returns the http request upgraded to the WebSocket.
// It is nil for client side.
func (ws *Conn) Request() *http.Request { return ws.request }
// Codec represents a symmetric pair of functions that implement a codec.
type Codec struct {
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
// Send sends v marshaled by cd.Marshal as single frame to ws.
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
data, payloadType, err := cd.Marshal(v)
if err != nil {
return err
}
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
if err != nil {
return err
}
_, err = w.Write(data)
w.Close()
return err
}
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader)
if err != nil {
return err
}
ws.frameReader = nil
}
again:
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return err
}
frame, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return err
}
if frame == nil {
goto again
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
if err != nil {
return err
}
return cd.Unmarshal(data, payloadType, v)
}
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
switch data := v.(type) {
case string:
return []byte(data), TextFrame, nil
case []byte:
return data, BinaryFrame, nil
}
return nil, UnknownFrame, ErrNotSupported
}
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
switch data := v.(type) {
case *string:
*data = string(msg)
return nil
case *[]byte:
*data = msg
return nil
}
return ErrNotSupported
}
/*
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
To send/receive text frame, use string type.
To send/receive binary frame, use []byte type.
Trivial usage:
import "websocket"
// receive text frame
var message string
websocket.Message.Receive(ws, &message)
// send text frame
message = "hello"
websocket.Message.Send(ws, message)
// receive binary frame
var data []byte
websocket.Message.Receive(ws, &data)
// send binary frame
data = []byte{0, 1, 2}
websocket.Message.Send(ws, data)
*/
var Message = Codec{marshal, unmarshal}
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
msg, err = json.Marshal(v)
return msg, TextFrame, err
}
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
return json.Unmarshal(msg, v)
}
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Trivial usage:
import "websocket"
type T struct {
Msg string
Count int
}
// receive JSON type T
var data T
websocket.JSON.Receive(ws, &data)
// send JSON type T
websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}

View File

@@ -0,0 +1,341 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
)
var serverAddr string
var once sync.Once
func echoServer(ws *Conn) { io.Copy(ws, ws) }
type Count struct {
S string
N int
}
func countServer(ws *Conn) {
for {
var count Count
err := JSON.Receive(ws, &count)
if err != nil {
return
}
count.N++
count.S = strings.Repeat(count.S, count.N)
err = JSON.Send(ws, count)
if err != nil {
return
}
}
}
func subProtocolHandshake(config *Config, req *http.Request) error {
for _, proto := range config.Protocol {
if proto == "chat" {
config.Protocol = []string{proto}
return nil
}
}
return ErrBadWebSocketProtocol
}
func subProtoServer(ws *Conn) {
for _, proto := range ws.Config().Protocol {
io.WriteString(ws, proto)
}
}
func startServer() {
http.Handle("/echo", Handler(echoServer))
http.Handle("/count", Handler(countServer))
subproto := Server{
Handshake: subProtocolHandshake,
Handler: Handler(subProtoServer),
}
http.Handle("/subproto", subproto)
server := httptest.NewServer(nil)
serverAddr = server.Listener.Addr().String()
log.Print("Test WebSocket server listening on ", serverAddr)
}
func newConfig(t *testing.T, path string) *Config {
config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
return config
}
func TestEcho(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := conn.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var actual_msg = make([]byte, 512)
n, err := conn.Read(actual_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
actual_msg = actual_msg[0:n]
if !bytes.Equal(msg, actual_msg) {
t.Errorf("Echo: expected %q got %q", msg, actual_msg)
}
conn.Close()
}
func TestAddr(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
ra := conn.RemoteAddr().String()
if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
t.Errorf("Bad remote addr: %v", ra)
}
la := conn.LocalAddr().String()
if !strings.HasPrefix(la, "http://") {
t.Errorf("Bad local addr: %v", la)
}
conn.Close()
}
func TestCount(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/count"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
var count Count
count.S = "hello"
if err := JSON.Send(conn, count); err != nil {
t.Errorf("Write: %v", err)
}
if err := JSON.Receive(conn, &count); err != nil {
t.Errorf("Read: %v", err)
}
if count.N != 1 {
t.Errorf("count: expected %d got %d", 1, count.N)
}
if count.S != "hello" {
t.Errorf("count: expected %q got %q", "hello", count.S)
}
if err := JSON.Send(conn, count); err != nil {
t.Errorf("Write: %v", err)
}
if err := JSON.Receive(conn, &count); err != nil {
t.Errorf("Read: %v", err)
}
if count.N != 2 {
t.Errorf("count: expected %d got %d", 2, count.N)
}
if count.S != "hellohello" {
t.Errorf("count: expected %q got %q", "hellohello", count.S)
}
conn.Close()
}
func TestWithQuery(t *testing.T) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
config := newConfig(t, "/echo")
config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
if err != nil {
t.Fatal("location url", err)
}
ws, err := NewClient(config, client)
if err != nil {
t.Errorf("WebSocket handshake: %v", err)
return
}
ws.Close()
}
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
config := newConfig(t, "/subproto")
config.Protocol = subproto
ws, err := NewClient(config, client)
if err != nil {
return "", err
}
msg := make([]byte, 16)
n, err := ws.Read(msg)
if err != nil {
return "", err
}
ws.Close()
return string(msg[:n]), nil
}
func TestWithProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithTwoProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"test", "chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithBadProtocol(t *testing.T) {
_, err := testWithProtocol(t, []string{"test"})
if err != ErrBadStatus {
t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
}
}
func TestHTTP(t *testing.T) {
once.Do(startServer)
// If the client did not send a handshake that matches the protocol
// specification, the server MUST return an HTTP response with an
// appropriate error code (such as 400 Bad Request)
resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
if err != nil {
t.Errorf("Get: error %#v", err)
return
}
if resp == nil {
t.Error("Get: resp is null")
return
}
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
}
}
func TestTrailingSpaces(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=955
// The last runs of this create keys with trailing spaces that should not be
// generated by the client.
once.Do(startServer)
config := newConfig(t, "/echo")
for i := 0; i < 30; i++ {
// body
ws, err := DialConfig(config)
if err != nil {
t.Errorf("Dial #%d failed: %v", i, err)
break
}
ws.Close()
}
}
func TestDialConfigBadVersion(t *testing.T) {
once.Do(startServer)
config := newConfig(t, "/echo")
config.Version = 1234
_, err := DialConfig(config)
if dialerr, ok := err.(*DialError); ok {
if dialerr.Err != ErrBadProtocolVersion {
t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
}
}
}
func TestSmallBuffer(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=1145
// Read should be able to handle reading a fragment of a frame.
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := conn.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var small_msg = make([]byte, 8)
n, err := conn.Read(small_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
}
var second_msg = make([]byte, len(msg))
n, err = conn.Read(second_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
second_msg = second_msg[0:n]
if !bytes.Equal(msg[len(small_msg):], second_msg) {
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
}
conn.Close()
}

View File

@@ -0,0 +1,124 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
)
// ErrCorrupt reports that the input is invalid.
var ErrCorrupt = errors.New("snappy: corrupt input")
// DecodedLen returns the length of the decoded block.
func DecodedLen(src []byte) (int, error) {
v, _, err := decodedLen(src)
return v, err
}
// decodedLen returns the length of the decoded block and the number of bytes
// that the length header occupied.
func decodedLen(src []byte) (blockLen, headerLen int, err error) {
v, n := binary.Uvarint(src)
if n == 0 {
return 0, 0, ErrCorrupt
}
if uint64(int(v)) != v {
return 0, 0, errors.New("snappy: decoded block is too large")
}
return int(v), n, nil
}
// Decode returns the decoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire decoded block.
// Otherwise, a newly allocated slice will be returned.
// It is valid to pass a nil dst.
func Decode(dst, src []byte) ([]byte, error) {
dLen, s, err := decodedLen(src)
if err != nil {
return nil, err
}
if len(dst) < dLen {
dst = make([]byte, dLen)
}
var d, offset, length int
for s < len(src) {
switch src[s] & 0x03 {
case tagLiteral:
x := uint(src[s] >> 2)
switch {
case x < 60:
s += 1
case x == 60:
s += 2
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-1])
case x == 61:
s += 3
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-2]) | uint(src[s-1])<<8
case x == 62:
s += 4
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-3]) | uint(src[s-2])<<8 | uint(src[s-1])<<16
case x == 63:
s += 5
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-4]) | uint(src[s-3])<<8 | uint(src[s-2])<<16 | uint(src[s-1])<<24
}
length = int(x + 1)
if length <= 0 {
return nil, errors.New("snappy: unsupported literal length")
}
if length > len(dst)-d || length > len(src)-s {
return nil, ErrCorrupt
}
copy(dst[d:], src[s:s+length])
d += length
s += length
continue
case tagCopy1:
s += 2
if s > len(src) {
return nil, ErrCorrupt
}
length = 4 + int(src[s-2])>>2&0x7
offset = int(src[s-2])&0xe0<<3 | int(src[s-1])
case tagCopy2:
s += 3
if s > len(src) {
return nil, ErrCorrupt
}
length = 1 + int(src[s-3])>>2
offset = int(src[s-2]) | int(src[s-1])<<8
case tagCopy4:
return nil, errors.New("snappy: unsupported COPY_4 tag")
}
end := d + length
if offset > d || end > len(dst) {
return nil, ErrCorrupt
}
for ; d < end; d++ {
dst[d] = dst[d-offset]
}
}
if d != dLen {
return nil, ErrCorrupt
}
return dst[:d], nil
}

View File

@@ -0,0 +1,174 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
)
// We limit how far copy back-references can go, the same as the C++ code.
const maxOffset = 1 << 15
// emitLiteral writes a literal chunk and returns the number of bytes written.
func emitLiteral(dst, lit []byte) int {
i, n := 0, uint(len(lit)-1)
switch {
case n < 60:
dst[0] = uint8(n)<<2 | tagLiteral
i = 1
case n < 1<<8:
dst[0] = 60<<2 | tagLiteral
dst[1] = uint8(n)
i = 2
case n < 1<<16:
dst[0] = 61<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
i = 3
case n < 1<<24:
dst[0] = 62<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
dst[3] = uint8(n >> 16)
i = 4
case int64(n) < 1<<32:
dst[0] = 63<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
dst[3] = uint8(n >> 16)
dst[4] = uint8(n >> 24)
i = 5
default:
panic("snappy: source buffer is too long")
}
if copy(dst[i:], lit) != len(lit) {
panic("snappy: destination buffer is too short")
}
return i + len(lit)
}
// emitCopy writes a copy chunk and returns the number of bytes written.
func emitCopy(dst []byte, offset, length int) int {
i := 0
for length > 0 {
x := length - 4
if 0 <= x && x < 1<<3 && offset < 1<<11 {
dst[i+0] = uint8(offset>>8)&0x07<<5 | uint8(x)<<2 | tagCopy1
dst[i+1] = uint8(offset)
i += 2
break
}
x = length
if x > 1<<6 {
x = 1 << 6
}
dst[i+0] = uint8(x-1)<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
i += 3
length -= x
}
return i
}
// Encode returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
// It is valid to pass a nil dst.
func Encode(dst, src []byte) ([]byte, error) {
if n := MaxEncodedLen(len(src)); len(dst) < n {
dst = make([]byte, n)
}
// The block starts with the varint-encoded length of the decompressed bytes.
d := binary.PutUvarint(dst, uint64(len(src)))
// Return early if src is short.
if len(src) <= 4 {
if len(src) != 0 {
d += emitLiteral(dst[d:], src)
}
return dst[:d], nil
}
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
const maxTableSize = 1 << 14
shift, tableSize := uint(32-8), 1<<8
for tableSize < maxTableSize && tableSize < len(src) {
shift--
tableSize *= 2
}
var table [maxTableSize]int
// Iterate over the source bytes.
var (
s int // The iterator position.
t int // The last position with the same hash as s.
lit int // The start position of any pending literal bytes.
)
for s+3 < len(src) {
// Update the hash table.
b0, b1, b2, b3 := src[s], src[s+1], src[s+2], src[s+3]
h := uint32(b0) | uint32(b1)<<8 | uint32(b2)<<16 | uint32(b3)<<24
p := &table[(h*0x1e35a7bd)>>shift]
// We need to to store values in [-1, inf) in table. To save
// some initialization time, (re)use the table's zero value
// and shift the values against this zero: add 1 on writes,
// subtract 1 on reads.
t, *p = *p-1, s+1
// If t is invalid or src[s:s+4] differs from src[t:t+4], accumulate a literal byte.
if t < 0 || s-t >= maxOffset || b0 != src[t] || b1 != src[t+1] || b2 != src[t+2] || b3 != src[t+3] {
s++
continue
}
// Otherwise, we have a match. First, emit any pending literal bytes.
if lit != s {
d += emitLiteral(dst[d:], src[lit:s])
}
// Extend the match to be as long as possible.
s0 := s
s, t = s+4, t+4
for s < len(src) && src[s] == src[t] {
s++
t++
}
// Emit the copied bytes.
d += emitCopy(dst[d:], s-t, s-s0)
lit = s
}
// Emit any final pending literal bytes and return.
if lit != len(src) {
d += emitLiteral(dst[d:], src[lit:])
}
return dst[:d], nil
}
// MaxEncodedLen returns the maximum length of a snappy block, given its
// uncompressed length.
func MaxEncodedLen(srcLen int) int {
// Compressed data can be defined as:
// compressed := item* literal*
// item := literal* copy
//
// The trailing literal sequence has a space blowup of at most 62/60
// since a literal of length 60 needs one tag byte + one extra byte
// for length information.
//
// Item blowup is trickier to measure. Suppose the "copy" op copies
// 4 bytes of data. Because of a special check in the encoding code,
// we produce a 4-byte copy only if the offset is < 65536. Therefore
// the copy op takes 3 bytes to encode, and this type of item leads
// to at most the 62/60 blowup for representing literals.
//
// Suppose the "copy" op copies 5 bytes of data. If the offset is big
// enough, it will take 5 bytes to encode the copy op. Therefore the
// worst case here is a one-byte literal followed by a five-byte copy.
// That is, 6 bytes of input turn into 7 bytes of "compressed" data.
//
// This last factor dominates the blowup, so the final estimate is:
return 32 + srcLen + srcLen/6
}

View File

@@ -0,0 +1,38 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package snappy implements the snappy block-based compression format.
// It aims for very high speeds and reasonable compression.
//
// The C++ snappy implementation is at http://code.google.com/p/snappy/
package snappy
/*
Each encoded block begins with the varint-encoded length of the decoded data,
followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
first byte of each chunk is broken into its 2 least and 6 most significant bits
called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag.
Zero means a literal tag. All other values mean a copy tag.
For literal tags:
- If m < 60, the next 1 + m bytes are literal bytes.
- Otherwise, let n be the little-endian unsigned integer denoted by the next
m - 59 bytes. The next 1 + n bytes after that are literal bytes.
For copy tags, length bytes are copied from offset bytes ago, in the style of
Lempel-Ziv compression algorithms. In particular:
- For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12).
The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10
of the offset. The next byte is bits 0-7 of the offset.
- For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65).
The length is 1 + m. The offset is the little-endian unsigned integer
denoted by the next 2 bytes.
- For l == 3, this tag is a legacy format that is no longer supported.
*/
const (
tagLiteral = 0x00
tagCopy1 = 0x01
tagCopy2 = 0x02
tagCopy4 = 0x03
)

View File

@@ -0,0 +1,261 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"bytes"
"flag"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
)
var download = flag.Bool("download", false, "If true, download any missing files before running benchmarks")
func roundtrip(b, ebuf, dbuf []byte) error {
e, err := Encode(ebuf, b)
if err != nil {
return fmt.Errorf("encoding error: %v", err)
}
d, err := Decode(dbuf, e)
if err != nil {
return fmt.Errorf("decoding error: %v", err)
}
if !bytes.Equal(b, d) {
return fmt.Errorf("roundtrip mismatch:\n\twant %v\n\tgot %v", b, d)
}
return nil
}
func TestEmpty(t *testing.T) {
if err := roundtrip(nil, nil, nil); err != nil {
t.Fatal(err)
}
}
func TestSmallCopy(t *testing.T) {
for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
for i := 0; i < 32; i++ {
s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb"
if err := roundtrip([]byte(s), ebuf, dbuf); err != nil {
t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err)
}
}
}
}
}
func TestSmallRand(t *testing.T) {
rand.Seed(27354294)
for n := 1; n < 20000; n += 23 {
b := make([]byte, n)
for i, _ := range b {
b[i] = uint8(rand.Uint32())
}
if err := roundtrip(b, nil, nil); err != nil {
t.Fatal(err)
}
}
}
func TestSmallRegular(t *testing.T) {
for n := 1; n < 20000; n += 23 {
b := make([]byte, n)
for i, _ := range b {
b[i] = uint8(i%10 + 'a')
}
if err := roundtrip(b, nil, nil); err != nil {
t.Fatal(err)
}
}
}
func benchDecode(b *testing.B, src []byte) {
encoded, err := Encode(nil, src)
if err != nil {
b.Fatal(err)
}
// Bandwidth is in amount of uncompressed data.
b.SetBytes(int64(len(src)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
Decode(src, encoded)
}
}
func benchEncode(b *testing.B, src []byte) {
// Bandwidth is in amount of uncompressed data.
b.SetBytes(int64(len(src)))
dst := make([]byte, MaxEncodedLen(len(src)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
Encode(dst, src)
}
}
func readFile(b *testing.B, filename string) []byte {
src, err := ioutil.ReadFile(filename)
if err != nil {
b.Fatalf("failed reading %s: %s", filename, err)
}
if len(src) == 0 {
b.Fatalf("%s has zero length", filename)
}
return src
}
// expand returns a slice of length n containing repeated copies of src.
func expand(src []byte, n int) []byte {
dst := make([]byte, n)
for x := dst; len(x) > 0; {
i := copy(x, src)
x = x[i:]
}
return dst
}
func benchWords(b *testing.B, n int, decode bool) {
// Note: the file is OS-language dependent so the resulting values are not
// directly comparable for non-US-English OS installations.
data := expand(readFile(b, "/usr/share/dict/words"), n)
if decode {
benchDecode(b, data)
} else {
benchEncode(b, data)
}
}
func BenchmarkWordsDecode1e3(b *testing.B) { benchWords(b, 1e3, true) }
func BenchmarkWordsDecode1e4(b *testing.B) { benchWords(b, 1e4, true) }
func BenchmarkWordsDecode1e5(b *testing.B) { benchWords(b, 1e5, true) }
func BenchmarkWordsDecode1e6(b *testing.B) { benchWords(b, 1e6, true) }
func BenchmarkWordsEncode1e3(b *testing.B) { benchWords(b, 1e3, false) }
func BenchmarkWordsEncode1e4(b *testing.B) { benchWords(b, 1e4, false) }
func BenchmarkWordsEncode1e5(b *testing.B) { benchWords(b, 1e5, false) }
func BenchmarkWordsEncode1e6(b *testing.B) { benchWords(b, 1e6, false) }
// testFiles' values are copied directly from
// https://code.google.com/p/snappy/source/browse/trunk/snappy_unittest.cc.
// The label field is unused in snappy-go.
var testFiles = []struct {
label string
filename string
}{
{"html", "html"},
{"urls", "urls.10K"},
{"jpg", "house.jpg"},
{"pdf", "mapreduce-osdi-1.pdf"},
{"html4", "html_x_4"},
{"cp", "cp.html"},
{"c", "fields.c"},
{"lsp", "grammar.lsp"},
{"xls", "kennedy.xls"},
{"txt1", "alice29.txt"},
{"txt2", "asyoulik.txt"},
{"txt3", "lcet10.txt"},
{"txt4", "plrabn12.txt"},
{"bin", "ptt5"},
{"sum", "sum"},
{"man", "xargs.1"},
{"pb", "geo.protodata"},
{"gaviota", "kppkn.gtb"},
}
// The test data files are present at this canonical URL.
const baseURL = "https://snappy.googlecode.com/svn/trunk/testdata/"
func downloadTestdata(basename string) (errRet error) {
filename := filepath.Join("testdata", basename)
f, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create %s: %s", filename, err)
}
defer f.Close()
defer func() {
if errRet != nil {
os.Remove(filename)
}
}()
resp, err := http.Get(baseURL + basename)
if err != nil {
return fmt.Errorf("failed to download %s: %s", baseURL+basename, err)
}
defer resp.Body.Close()
_, err = io.Copy(f, resp.Body)
if err != nil {
return fmt.Errorf("failed to write %s: %s", filename, err)
}
return nil
}
func benchFile(b *testing.B, n int, decode bool) {
filename := filepath.Join("testdata", testFiles[n].filename)
if stat, err := os.Stat(filename); err != nil || stat.Size() == 0 {
if !*download {
b.Fatal("test data not found; skipping benchmark without the -download flag")
}
// Download the official snappy C++ implementation reference test data
// files for benchmarking.
if err := os.Mkdir("testdata", 0777); err != nil && !os.IsExist(err) {
b.Fatalf("failed to create testdata: %s", err)
}
for _, tf := range testFiles {
if err := downloadTestdata(tf.filename); err != nil {
b.Fatalf("failed to download testdata: %s", err)
}
}
}
data := readFile(b, filename)
if decode {
benchDecode(b, data)
} else {
benchEncode(b, data)
}
}
// Naming convention is kept similar to what snappy's C++ implementation uses.
func Benchmark_UFlat0(b *testing.B) { benchFile(b, 0, true) }
func Benchmark_UFlat1(b *testing.B) { benchFile(b, 1, true) }
func Benchmark_UFlat2(b *testing.B) { benchFile(b, 2, true) }
func Benchmark_UFlat3(b *testing.B) { benchFile(b, 3, true) }
func Benchmark_UFlat4(b *testing.B) { benchFile(b, 4, true) }
func Benchmark_UFlat5(b *testing.B) { benchFile(b, 5, true) }
func Benchmark_UFlat6(b *testing.B) { benchFile(b, 6, true) }
func Benchmark_UFlat7(b *testing.B) { benchFile(b, 7, true) }
func Benchmark_UFlat8(b *testing.B) { benchFile(b, 8, true) }
func Benchmark_UFlat9(b *testing.B) { benchFile(b, 9, true) }
func Benchmark_UFlat10(b *testing.B) { benchFile(b, 10, true) }
func Benchmark_UFlat11(b *testing.B) { benchFile(b, 11, true) }
func Benchmark_UFlat12(b *testing.B) { benchFile(b, 12, true) }
func Benchmark_UFlat13(b *testing.B) { benchFile(b, 13, true) }
func Benchmark_UFlat14(b *testing.B) { benchFile(b, 14, true) }
func Benchmark_UFlat15(b *testing.B) { benchFile(b, 15, true) }
func Benchmark_UFlat16(b *testing.B) { benchFile(b, 16, true) }
func Benchmark_UFlat17(b *testing.B) { benchFile(b, 17, true) }
func Benchmark_ZFlat0(b *testing.B) { benchFile(b, 0, false) }
func Benchmark_ZFlat1(b *testing.B) { benchFile(b, 1, false) }
func Benchmark_ZFlat2(b *testing.B) { benchFile(b, 2, false) }
func Benchmark_ZFlat3(b *testing.B) { benchFile(b, 3, false) }
func Benchmark_ZFlat4(b *testing.B) { benchFile(b, 4, false) }
func Benchmark_ZFlat5(b *testing.B) { benchFile(b, 5, false) }
func Benchmark_ZFlat6(b *testing.B) { benchFile(b, 6, false) }
func Benchmark_ZFlat7(b *testing.B) { benchFile(b, 7, false) }
func Benchmark_ZFlat8(b *testing.B) { benchFile(b, 8, false) }
func Benchmark_ZFlat9(b *testing.B) { benchFile(b, 9, false) }
func Benchmark_ZFlat10(b *testing.B) { benchFile(b, 10, false) }
func Benchmark_ZFlat11(b *testing.B) { benchFile(b, 11, false) }
func Benchmark_ZFlat12(b *testing.B) { benchFile(b, 12, false) }
func Benchmark_ZFlat13(b *testing.B) { benchFile(b, 13, false) }
func Benchmark_ZFlat14(b *testing.B) { benchFile(b, 14, false) }
func Benchmark_ZFlat15(b *testing.B) { benchFile(b, 15, false) }
func Benchmark_ZFlat16(b *testing.B) { benchFile(b, 16, false) }
func Benchmark_ZFlat17(b *testing.B) { benchFile(b, 17, false) }