rpc: migrated the RPC insterface to a new reflection based RPC layer

This commit is contained in:
Bas van Kervel
2015-12-16 10:58:01 +01:00
committed by Jeffrey Wilcke
parent f2ab351e8d
commit 19b2640e89
132 changed files with 4711 additions and 14320 deletions

14
Godeps/Godeps.json generated
View File

@ -1,6 +1,6 @@
{
"ImportPath": "github.com/ethereum/go-ethereum",
"GoVersion": "go1.4",
"GoVersion": "go1.5.2",
"Packages": [
"./..."
],
@ -40,10 +40,6 @@
"ImportPath": "github.com/jackpal/go-nat-pmp",
"Rev": "a45aa3d54aef73b504e15eb71bea0e5565b5e6e1"
},
{
"ImportPath": "github.com/kardianos/osext",
"Rev": "ccfcd0245381f0c94c68f50626665eed3c6b726a"
},
{
"ImportPath": "github.com/mattn/go-isatty",
"Rev": "7fcbc72f853b92b5720db4a6b8482be612daef24"
@ -73,10 +69,6 @@
"ImportPath": "github.com/robertkrimen/otto",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/rs/cors",
"Rev": "6e0c3cb65fc0fdb064c743d176a620e3ca446dfb"
},
{
"ImportPath": "github.com/shiena/ansicolor",
"Rev": "a5e2b567a4dd6cc74545b8a4f27c9d63b9e7735b"
@ -109,6 +101,10 @@
"ImportPath": "golang.org/x/net/html",
"Rev": "e0403b4e005737430c05a57aac078479844f919c"
},
{
"ImportPath": "golang.org/x/net/websocket",
"Rev": "e0403b4e005737430c05a57aac078479844f919c"
},
{
"ImportPath": "golang.org/x/text/encoding",
"Rev": "c93e7c9fff19fb9139b5ab04ce041833add0134e"

View File

@ -1,27 +0,0 @@
Copyright (c) 2012 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,14 +0,0 @@
### Extensions to the "os" package.
## Find the current Executable and ExecutableFolder.
There is sometimes utility in finding the current executable file
that is running. This can be used for upgrading the current executable
or finding resources located relative to the executable file.
Multi-platform and supports:
* Linux
* OS X
* Windows
* Plan 9
* BSDs.

View File

@ -1,27 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Extensions to the standard "os" package.
package osext
import "path/filepath"
// Executable returns an absolute path that can be used to
// re-invoke the current program.
// It may not be valid after the current program exits.
func Executable() (string, error) {
p, err := executable()
return filepath.Clean(p), err
}
// Returns same path as Executable, returns just the folder
// path. Excludes the executable name.
func ExecutableFolder() (string, error) {
p, err := Executable()
if err != nil {
return "", err
}
folder, _ := filepath.Split(p)
return folder, nil
}

View File

@ -1,20 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package osext
import (
"os"
"strconv"
"syscall"
)
func executable() (string, error) {
f, err := os.Open("/proc/" + strconv.Itoa(os.Getpid()) + "/text")
if err != nil {
return "", err
}
defer f.Close()
return syscall.Fd2path(int(f.Fd()))
}

View File

@ -1,28 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux netbsd openbsd solaris dragonfly
package osext
import (
"errors"
"fmt"
"os"
"runtime"
)
func executable() (string, error) {
switch runtime.GOOS {
case "linux":
return os.Readlink("/proc/self/exe")
case "netbsd":
return os.Readlink("/proc/curproc/exe")
case "openbsd", "dragonfly":
return os.Readlink("/proc/curproc/file")
case "solaris":
return os.Readlink(fmt.Sprintf("/proc/%d/path/a.out", os.Getpid()))
}
return "", errors.New("ExecPath not implemented for " + runtime.GOOS)
}

View File

@ -1,79 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin freebsd
package osext
import (
"os"
"path/filepath"
"runtime"
"syscall"
"unsafe"
)
var initCwd, initCwdErr = os.Getwd()
func executable() (string, error) {
var mib [4]int32
switch runtime.GOOS {
case "freebsd":
mib = [4]int32{1 /* CTL_KERN */, 14 /* KERN_PROC */, 12 /* KERN_PROC_PATHNAME */, -1}
case "darwin":
mib = [4]int32{1 /* CTL_KERN */, 38 /* KERN_PROCARGS */, int32(os.Getpid()), -1}
}
n := uintptr(0)
// Get length.
_, _, errNum := syscall.Syscall6(syscall.SYS___SYSCTL, uintptr(unsafe.Pointer(&mib[0])), 4, 0, uintptr(unsafe.Pointer(&n)), 0, 0)
if errNum != 0 {
return "", errNum
}
if n == 0 { // This shouldn't happen.
return "", nil
}
buf := make([]byte, n)
_, _, errNum = syscall.Syscall6(syscall.SYS___SYSCTL, uintptr(unsafe.Pointer(&mib[0])), 4, uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&n)), 0, 0)
if errNum != 0 {
return "", errNum
}
if n == 0 { // This shouldn't happen.
return "", nil
}
for i, v := range buf {
if v == 0 {
buf = buf[:i]
break
}
}
var err error
execPath := string(buf)
// execPath will not be empty due to above checks.
// Try to get the absolute path if the execPath is not rooted.
if execPath[0] != '/' {
execPath, err = getAbs(execPath)
if err != nil {
return execPath, err
}
}
// For darwin KERN_PROCARGS may return the path to a symlink rather than the
// actual executable.
if runtime.GOOS == "darwin" {
if execPath, err = filepath.EvalSymlinks(execPath); err != nil {
return execPath, err
}
}
return execPath, nil
}
func getAbs(execPath string) (string, error) {
if initCwdErr != nil {
return execPath, initCwdErr
}
// The execPath may begin with a "../" or a "./" so clean it first.
// Join the two paths, trailing and starting slashes undetermined, so use
// the generic Join function.
return filepath.Join(initCwd, filepath.Clean(execPath)), nil
}

View File

@ -1,79 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin linux freebsd netbsd windows
package osext
import (
"fmt"
"os"
oexec "os/exec"
"path/filepath"
"runtime"
"testing"
)
const execPath_EnvVar = "OSTEST_OUTPUT_EXECPATH"
func TestExecPath(t *testing.T) {
ep, err := Executable()
if err != nil {
t.Fatalf("ExecPath failed: %v", err)
}
// we want fn to be of the form "dir/prog"
dir := filepath.Dir(filepath.Dir(ep))
fn, err := filepath.Rel(dir, ep)
if err != nil {
t.Fatalf("filepath.Rel: %v", err)
}
cmd := &oexec.Cmd{}
// make child start with a relative program path
cmd.Dir = dir
cmd.Path = fn
// forge argv[0] for child, so that we can verify we could correctly
// get real path of the executable without influenced by argv[0].
cmd.Args = []string{"-", "-test.run=XXXX"}
cmd.Env = []string{fmt.Sprintf("%s=1", execPath_EnvVar)}
out, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("exec(self) failed: %v", err)
}
outs := string(out)
if !filepath.IsAbs(outs) {
t.Fatalf("Child returned %q, want an absolute path", out)
}
if !sameFile(outs, ep) {
t.Fatalf("Child returned %q, not the same file as %q", out, ep)
}
}
func sameFile(fn1, fn2 string) bool {
fi1, err := os.Stat(fn1)
if err != nil {
return false
}
fi2, err := os.Stat(fn2)
if err != nil {
return false
}
return os.SameFile(fi1, fi2)
}
func init() {
if e := os.Getenv(execPath_EnvVar); e != "" {
// first chdir to another path
dir := "/"
if runtime.GOOS == "windows" {
dir = filepath.VolumeName(".")
}
os.Chdir(dir)
if ep, err := Executable(); err != nil {
fmt.Fprint(os.Stderr, "ERROR: ", err)
} else {
fmt.Fprint(os.Stderr, ep)
}
os.Exit(0)
}
}

View File

@ -1,34 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package osext
import (
"syscall"
"unicode/utf16"
"unsafe"
)
var (
kernel = syscall.MustLoadDLL("kernel32.dll")
getModuleFileNameProc = kernel.MustFindProc("GetModuleFileNameW")
)
// GetModuleFileName() with hModule = NULL
func executable() (exePath string, err error) {
return getModuleFileName()
}
func getModuleFileName() (string, error) {
var n uint32
b := make([]uint16, syscall.MAX_PATH)
size := uint32(len(b))
r0, _, e1 := getModuleFileNameProc.Call(0, uintptr(unsafe.Pointer(&b[0])), uintptr(size))
n = uint32(r0)
if n == 0 {
return "", e1
}
return string(utf16.Decode(b[0:n])), nil
}

View File

@ -1,4 +0,0 @@
language: go
go:
- 1.3
- 1.4

View File

@ -1,19 +0,0 @@
Copyright (c) 2014 Olivier Poitrey <rs@dailymotion.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is furnished
to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -1,84 +0,0 @@
# Go CORS handler [![godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/rs/cors) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/rs/cors/master/LICENSE) [![build](https://img.shields.io/travis/rs/cors.svg?style=flat)](https://travis-ci.org/rs/cors)
CORS is a `net/http` handler implementing [Cross Origin Resource Sharing W3 specification](http://www.w3.org/TR/cors/) in Golang.
## Getting Started
After installing Go and setting up your [GOPATH](http://golang.org/doc/code.html#GOPATH), create your first `.go` file. We'll call it `server.go`.
```go
package main
import (
"net/http"
"github.com/rs/cors"
)
func main() {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
// cors.Default() setup the middleware with default options being
// all origins accepted with simple methods (GET, POST). See
// documentation below for more options.
handler = cors.Default().Handler(h)
http.ListenAndServe(":8080", handler)
}
```
Install `cors`:
go get github.com/rs/cors
Then run your server:
go run server.go
The server now runs on `localhost:8080`:
$ curl -D - -H 'Origin: http://foo.com' http://localhost:8080/
HTTP/1.1 200 OK
Access-Control-Allow-Origin: foo.com
Content-Type: application/json
Date: Sat, 25 Oct 2014 03:43:57 GMT
Content-Length: 18
{"hello": "world"}
### More Examples
* `net/http`: [examples/nethttp/server.go](https://github.com/rs/cors/blob/master/examples/nethttp/server.go)
* [Goji](https://goji.io): [examples/goji/server.go](https://github.com/rs/cors/blob/master/examples/goji/server.go)
* [Martini](http://martini.codegangsta.io): [examples/martini/server.go](https://github.com/rs/cors/blob/master/examples/martini/server.go)
* [Negroni](https://github.com/codegangsta/negroni): [examples/negroni/server.go](https://github.com/rs/cors/blob/master/examples/negroni/server.go)
* [Alice](https://github.com/justinas/alice): [examples/alice/server.go](https://github.com/rs/cors/blob/master/examples/alice/server.go)
## Parameters
Parameters are passed to the middleware thru the `cors.New` method as follow:
```go
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://foo.com"},
AllowCredentials: true,
})
// Insert the middleware
handler = c.Handler(handler)
```
* **AllowedOrigins** `[]string`: A list of origins a cross-domain request can be executed from. If the special `*` value is present in the list, all origins will be allowed. The default value is `*`.
* **AllowedMethods** `[]string`: A list of methods the client is allowed to use with cross-domain requests.
* **AllowedHeaders** `[]string`: A list of non simple headers the client is allowed to use with cross-domain requests. Default value is simple methods (`GET` and `POST`)
* **ExposedHeaders** `[]string`: Indicates which headers are safe to expose to the API of a CORS API specification
* **AllowCredentials** `bool`: Indicates whether the request can include user credentials like cookies, HTTP authentication or client side SSL certificates. The default is `false`.
* **MaxAge** `int`: Indicates how long (in seconds) the results of a preflight request can be cached. The default is `0` which stands for no max age.
See [API documentation](http://godoc.org/github.com/rs/cors) for more info.
## Licenses
All source code is licensed under the [MIT License](https://raw.github.com/rs/cors/master/LICENSE).

View File

@ -1,37 +0,0 @@
package cors
import (
"net/http"
"net/http/httptest"
"testing"
)
func BenchmarkWithout(b *testing.B) {
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
for i := 0; i < b.N; i++ {
testHandler.ServeHTTP(res, req)
}
}
func BenchmarkDefault(b *testing.B) {
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
handler := Default()
for i := 0; i < b.N; i++ {
handler.Handler(testHandler).ServeHTTP(res, req)
}
}
func BenchmarkPreflight(b *testing.B) {
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Access-Control-Request-Method", "GET")
handler := Default()
for i := 0; i < b.N; i++ {
handler.Handler(testHandler).ServeHTTP(res, req)
}
}

View File

@ -1,308 +0,0 @@
/*
Package cors is net/http handler to handle CORS related requests
as defined by http://www.w3.org/TR/cors/
You can configure it by passing an option struct to cors.New:
c := cors.New(cors.Options{
AllowedOrigins: []string{"foo.com"},
AllowedMethods: []string{"GET", "POST", "DELETE"},
AllowCredentials: true,
})
Then insert the handler in the chain:
handler = c.Handler(handler)
See Options documentation for more options.
The resulting handler is a standard net/http handler.
*/
package cors
import (
"log"
"net/http"
"os"
"strconv"
"strings"
)
// Options is a configuration container to setup the CORS middleware.
type Options struct {
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// Default value is ["*"]
AllowedOrigins []string
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (GET and POST)
AllowedMethods []string
// AllowedHeaders is list of non simple headers the client is allowed to use with
// cross-domain requests.
// If the special "*" value is present in the list, all headers will be allowed.
// Default value is [] but "Origin" is always appended to the list.
AllowedHeaders []string
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
// API specification
ExposedHeaders []string
// AllowCredentials indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
AllowCredentials bool
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached
MaxAge int
// Debugging flag adds additional output to debug server side CORS issues
Debug bool
// log object to use when debugging
log *log.Logger
}
type Cors struct {
// The CORS Options
options Options
}
// New creates a new Cors handler with the provided options.
func New(options Options) *Cors {
// Normalize options
// Note: for origins and methods matching, the spec requires a case-sensitive matching.
// As it may error prone, we chose to ignore the spec here.
normOptions := Options{
AllowedOrigins: convert(options.AllowedOrigins, strings.ToLower),
AllowedMethods: convert(options.AllowedMethods, strings.ToUpper),
// Origin is always appended as some browsers will always request
// for this header at preflight
AllowedHeaders: convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey),
ExposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
AllowCredentials: options.AllowCredentials,
MaxAge: options.MaxAge,
Debug: options.Debug,
log: log.New(os.Stdout, "[cors] ", log.LstdFlags),
}
if len(normOptions.AllowedOrigins) == 0 {
// Default is all origins
normOptions.AllowedOrigins = []string{"*"}
}
if len(normOptions.AllowedHeaders) == 1 {
// Add some sensible defaults
normOptions.AllowedHeaders = []string{"Origin", "Accept", "Content-Type"}
}
if len(normOptions.AllowedMethods) == 0 {
// Default is simple methods
normOptions.AllowedMethods = []string{"GET", "POST"}
}
if normOptions.Debug {
normOptions.log.Printf("Options: %v", normOptions)
}
return &Cors{
options: normOptions,
}
}
// Default creates a new Cors handler with default options
func Default() *Cors {
return New(Options{})
}
// Handler apply the CORS specification on the request, and add relevant CORS headers
// as necessary.
func (cors *Cors) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" {
cors.logf("Handler: Preflight request")
cors.handlePreflight(w, r)
// Preflight requests are standalone and should stop the chain as some other
// middleware may not handle OPTIONS requests correctly. One typical example
// is authentication middleware ; OPTIONS requests won't carry authentication
// headers (see #1)
} else {
cors.logf("Handler: Actual request")
cors.handleActualRequest(w, r)
h.ServeHTTP(w, r)
}
})
}
// Martini compatible handler
func (cors *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" {
cors.logf("HandlerFunc: Preflight request")
cors.handlePreflight(w, r)
} else {
cors.logf("HandlerFunc: Actual request")
cors.handleActualRequest(w, r)
}
}
// Negroni compatible interface
func (cors *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if r.Method == "OPTIONS" {
cors.logf("ServeHTTP: Preflight request")
cors.handlePreflight(w, r)
// Preflight requests are standalone and should stop the chain as some other
// middleware may not handle OPTIONS requests correctly. One typical example
// is authentication middleware ; OPTIONS requests won't carry authentication
// headers (see #1)
} else {
cors.logf("ServeHTTP: Actual request")
cors.handleActualRequest(w, r)
next(w, r)
}
}
// handlePreflight handles pre-flight CORS requests
func (cors *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
options := cors.options
headers := w.Header()
origin := r.Header.Get("Origin")
if r.Method != "OPTIONS" {
cors.logf(" Preflight aborted: %s!=OPTIONS", r.Method)
return
}
if origin == "" {
cors.logf(" Preflight aborted: empty origin")
return
}
if !cors.isOriginAllowed(origin) {
cors.logf(" Preflight aborted: origin '%s' not allowed", origin)
return
}
reqMethod := r.Header.Get("Access-Control-Request-Method")
if !cors.isMethodAllowed(reqMethod) {
cors.logf(" Preflight aborted: method '%s' not allowed", reqMethod)
return
}
reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
if !cors.areHeadersAllowed(reqHeaders) {
cors.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
return
}
headers.Set("Access-Control-Allow-Origin", origin)
headers.Add("Vary", "Origin")
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
// by Access-Control-Request-Method (if supported) can be enough
headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
if len(reqHeaders) > 0 {
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
}
if options.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if options.MaxAge > 0 {
headers.Set("Access-Control-Max-Age", strconv.Itoa(options.MaxAge))
}
cors.logf(" Preflight response headers: %v", headers)
}
// handleActualRequest handles simple cross-origin requests, actual request or redirects
func (cors *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
options := cors.options
headers := w.Header()
origin := r.Header.Get("Origin")
if r.Method == "OPTIONS" {
cors.logf(" Actual request no headers added: method == %s", r.Method)
return
}
if origin == "" {
cors.logf(" Actual request no headers added: missing origin")
return
}
if !cors.isOriginAllowed(origin) {
cors.logf(" Actual request no headers added: origin '%s' not allowed", origin)
return
}
// Note that spec does define a way to specifically disallow a simple method like GET or
// POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
// spec doesn't instruct to check the allowed methods for simple cross-origin requests.
// We think it's a nice feature to be able to have control on those methods though.
if !cors.isMethodAllowed(r.Method) {
if cors.options.Debug {
cors.logf(" Actual request no headers added: method '%s' not allowed",
r.Method)
}
return
}
headers.Set("Access-Control-Allow-Origin", origin)
headers.Add("Vary", "Origin")
if len(options.ExposedHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(options.ExposedHeaders, ", "))
}
if options.AllowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
cors.logf(" Actual response added headers: %v", headers)
}
// convenience method. checks if debugging is turned on before printing
func (cors *Cors) logf(format string, a ...interface{}) {
if cors.options.Debug {
cors.options.log.Printf(format, a...)
}
}
// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
// on the endpoint
func (cors *Cors) isOriginAllowed(origin string) bool {
allowedOrigins := cors.options.AllowedOrigins
origin = strings.ToLower(origin)
for _, allowedOrigin := range allowedOrigins {
switch allowedOrigin {
case "*":
return true
case origin:
return true
}
}
return false
}
// isMethodAllowed checks if a given method can be used as part of a cross-domain request
// on the endpoing
func (cors *Cors) isMethodAllowed(method string) bool {
allowedMethods := cors.options.AllowedMethods
if len(allowedMethods) == 0 {
// If no method allowed, always return false, even for preflight request
return false
}
method = strings.ToUpper(method)
if method == "OPTIONS" {
// Always allow preflight requests
return true
}
for _, allowedMethod := range allowedMethods {
if allowedMethod == method {
return true
}
}
return false
}
// areHeadersAllowed checks if a given list of headers are allowed to used within
// a cross-domain request.
func (cors *Cors) areHeadersAllowed(requestedHeaders []string) bool {
if len(requestedHeaders) == 0 {
return true
}
for _, header := range requestedHeaders {
found := false
for _, allowedHeader := range cors.options.AllowedHeaders {
if allowedHeader == "*" || allowedHeader == header {
found = true
break
}
}
if !found {
return false
}
}
return true
}

View File

@ -1,288 +0,0 @@
package cors
import (
"net/http"
"net/http/httptest"
"testing"
)
var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
})
func assertHeaders(t *testing.T, resHeaders http.Header, reqHeaders map[string]string) {
for name, value := range reqHeaders {
if resHeaders.Get(name) != value {
t.Errorf("Invalid header `%s', wanted `%s', got `%s'", name, value, resHeaders.Get(name))
}
}
}
func TestNoConfig(t *testing.T) {
s := New(Options{
// Intentionally left blank.
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestWildcardOrigin(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"*"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestAllowedOrigin(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestDisallowedOrigin(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://barbaz.com")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestAllowedMethod(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowedMethods: []string{"PUT", "DELETE"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "PUT")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "PUT",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestDisallowedMethod(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowedMethods: []string{"PUT", "DELETE"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "PATCH")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestAllowedHeader(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowedHeaders: []string{"X-Header-1", "x-header-2"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "GET")
req.Header.Add("Access-Control-Request-Headers", "X-Header-2, X-HEADER-1")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestAllowedWildcardHeader(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowedHeaders: []string{"*"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "GET")
req.Header.Add("Access-Control-Request-Headers", "X-Header-2, X-HEADER-1")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "X-Header-2, X-Header-1",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestDisallowedHeader(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowedHeaders: []string{"X-Header-1", "x-header-2"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "GET")
req.Header.Add("Access-Control-Request-Headers", "X-Header-3, X-Header-1")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestOriginHeader(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "GET")
req.Header.Add("Access-Control-Request-Headers", "origin")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "Origin",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}
func TestExposedHeader(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
ExposedHeaders: []string{"X-Header-1", "x-header-2"},
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "X-Header-1, X-Header-2",
})
}
func TestAllowedCredentials(t *testing.T) {
s := New(Options{
AllowedOrigins: []string{"http://foobar.com"},
AllowCredentials: true,
})
res := httptest.NewRecorder()
req, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
req.Header.Add("Origin", "http://foobar.com")
req.Header.Add("Access-Control-Request-Method", "GET")
s.Handler(testHandler).ServeHTTP(res, req)
assertHeaders(t, res.Header(), map[string]string{
"Access-Control-Allow-Origin": "http://foobar.com",
"Access-Control-Allow-Methods": "GET",
"Access-Control-Allow-Headers": "",
"Access-Control-Allow-Credentials": "true",
"Access-Control-Max-Age": "",
"Access-Control-Expose-Headers": "",
})
}

View File

@ -1,24 +0,0 @@
package main
import (
"net/http"
"github.com/justinas/alice"
"github.com/rs/cors"
)
func main() {
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://foo.com"},
})
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
chain := alice.New(c.Handler).Then(mux)
http.ListenAndServe(":8080", chain)
}

View File

@ -1,18 +0,0 @@
package main
import (
"net/http"
"github.com/rs/cors"
)
func main() {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
// Use default options
handler := cors.Default().Handler(h)
http.ListenAndServe(":8080", handler)
}

View File

@ -1,22 +0,0 @@
package main
import (
"net/http"
"github.com/rs/cors"
"github.com/zenazn/goji"
)
func main() {
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://foo.com"},
})
goji.Use(c.Handler)
goji.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
goji.Serve()
}

View File

@ -1,23 +0,0 @@
package main
import (
"github.com/go-martini/martini"
"github.com/martini-contrib/render"
"github.com/rs/cors"
)
func main() {
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://foo.com"},
})
m := martini.Classic()
m.Use(render.Renderer())
m.Use(c.HandlerFunc)
m.Get("/", func(r render.Render) {
r.JSON(200, map[string]interface{}{"hello": "world"})
})
m.Run()
}

View File

@ -1,26 +0,0 @@
package main
import (
"net/http"
"github.com/codegangsta/negroni"
"github.com/rs/cors"
)
func main() {
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://foo.com"},
})
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
n := negroni.Classic()
n.Use(c)
n.UseHandler(mux)
n.Run(":3000")
}

View File

@ -1,20 +0,0 @@
package main
import (
"net/http"
"github.com/rs/cors"
)
func main() {
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://foo.com"},
})
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
http.ListenAndServe(":8080", c.Handler(handler))
}

View File

@ -1,22 +0,0 @@
package main
import (
"net/http"
"github.com/rs/cors"
)
func main() {
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE"},
AllowCredentials: true,
})
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte("{\"hello\": \"world\"}"))
})
http.ListenAndServe(":8080", c.Handler(h))
}

View File

@ -1,27 +0,0 @@
package cors
import (
"net/http"
"strings"
)
type converter func(string) string
// convert converts a list of string using the passed converter function
func convert(s []string, c converter) []string {
out := []string{}
for _, i := range s {
out = append(out, c(i))
}
return out
}
func parseHeaderList(headerList string) (headers []string) {
for _, header := range strings.Split(headerList, ",") {
header = http.CanonicalHeaderKey(strings.TrimSpace(header))
if header != "" {
headers = append(headers, header)
}
}
return headers
}

View File

@ -1,28 +0,0 @@
package cors
import (
"strings"
"testing"
)
func TestConvert(t *testing.T) {
s := convert([]string{"A", "b", "C"}, strings.ToLower)
e := []string{"a", "b", "c"}
if s[0] != e[0] || s[1] != e[1] || s[2] != e[2] {
t.Errorf("%v != %v", s, e)
}
}
func TestParseHeaderList(t *testing.T) {
h := parseHeaderList("header, second-header, THIRD-HEADER")
e := []string{"Header", "Second-Header", "Third-Header"}
if h[0] != e[0] || h[1] != e[1] || h[2] != e[2] {
t.Errorf("%v != %v", h, e)
}
}
func TestParseHeaderListEmpty(t *testing.T) {
if len(parseHeaderList("")) != 0 {
t.Error("should be empty sclice")
}
}

View File

@ -0,0 +1,113 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
)
// DialError is an error that occurs while dialling a websocket server.
type DialError struct {
*Config
Err error
}
func (e *DialError) Error() string {
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
}
// NewConfig creates a new WebSocket config for client connection.
func NewConfig(server, origin string) (config *Config, err error) {
config = new(Config)
config.Version = ProtocolVersionHybi13
config.Location, err = url.ParseRequestURI(server)
if err != nil {
return
}
config.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
// NewClient creates a new WebSocket client connection over rwc.
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
err = hybiClientHandshake(config, br, bw)
if err != nil {
return
}
buf := bufio.NewReadWriter(br, bw)
ws = newHybiClientConn(config, buf, rwc)
return
}
// Dial opens a new client connection to a WebSocket.
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
config, err := NewConfig(url_, origin)
if err != nil {
return nil, err
}
if protocol != "" {
config.Protocol = []string{protocol}
}
return DialConfig(config)
}
var portMap = map[string]string{
"ws": "80",
"wss": "443",
}
func parseAuthority(location *url.URL) string {
if _, ok := portMap[location.Scheme]; ok {
if _, _, err := net.SplitHostPort(location.Host); err != nil {
return net.JoinHostPort(location.Host, portMap[location.Scheme])
}
}
return location.Host
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
switch config.Location.Scheme {
case "ws":
client, err = net.Dial("tcp", parseAuthority(config.Location))
case "wss":
client, err = tls.Dial("tcp", parseAuthority(config.Location), config.TlsConfig)
default:
err = ErrBadScheme
}
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
client.Close()
goto Error
}
return
Error:
return nil, &DialError{config, err}
}

View File

@ -0,0 +1,31 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"fmt"
"log"
"golang.org/x/net/websocket"
)
// This example demonstrates a trivial client.
func ExampleDial() {
origin := "http://localhost/"
url := "ws://localhost:12345/ws"
ws, err := websocket.Dial(url, "", origin)
if err != nil {
log.Fatal(err)
}
if _, err := ws.Write([]byte("hello, world!\n")); err != nil {
log.Fatal(err)
}
var msg = make([]byte, 512)
var n int
if n, err = ws.Read(msg); err != nil {
log.Fatal(err)
}
fmt.Printf("Received: %s.\n", msg[:n])
}

View File

@ -0,0 +1,26 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"io"
"net/http"
"golang.org/x/net/websocket"
)
// Echo the data received on the WebSocket.
func EchoServer(ws *websocket.Conn) {
io.Copy(ws, ws)
}
// This example demonstrates a trivial echo server.
func ExampleHandler() {
http.Handle("/echo", websocket.Handler(EchoServer))
err := http.ListenAndServe(":12345", nil)
if err != nil {
panic("ListenAndServe: " + err.Error())
}
}

View File

@ -0,0 +1,564 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
// This file implements a protocol of hybi draft.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
const (
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
closeStatusNormal = 1000
closeStatusGoingAway = 1001
closeStatusProtocolError = 1002
closeStatusUnsupportedData = 1003
closeStatusFrameTooLarge = 1004
closeStatusNoStatusRcvd = 1005
closeStatusAbnormalClosure = 1006
closeStatusBadMessageData = 1007
closeStatusPolicyViolation = 1008
closeStatusTooBigData = 1009
closeStatusExtensionMismatch = 1010
maxControlFramePayloadLength = 125
)
var (
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
ErrBadPongMessage = &ProtocolError{"bad pong message"}
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
handshakeHeader = map[string]bool{
"Host": true,
"Upgrade": true,
"Connection": true,
"Sec-Websocket-Key": true,
"Sec-Websocket-Origin": true,
"Sec-Websocket-Version": true,
"Sec-Websocket-Protocol": true,
"Sec-Websocket-Accept": true,
}
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
type hybiFrameHeader struct {
Fin bool
Rsv [3]bool
OpCode byte
Length int64
MaskingKey []byte
data *bytes.Buffer
}
// A hybiFrameReader is a reader for hybi frame.
type hybiFrameReader struct {
reader io.Reader
header hybiFrameHeader
pos int64
length int
}
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
n, err = frame.reader.Read(msg)
if err != nil {
return 0, err
}
if frame.header.MaskingKey != nil {
for i := 0; i < n; i++ {
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
frame.pos++
}
}
return n, err
}
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
func (frame *hybiFrameReader) HeaderReader() io.Reader {
if frame.header.data == nil {
return nil
}
if frame.header.data.Len() == 0 {
return nil
}
return frame.header.data
}
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
type hybiFrameReaderFactory struct {
*bufio.Reader
}
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
// See Section 5.2 Base Framing protocol for detail.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
hybiFrame := new(hybiFrameReader)
frame = hybiFrame
var header []byte
var b byte
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
for i := 0; i < 3; i++ {
j := uint(6 - i)
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
}
hybiFrame.header.OpCode = header[0] & 0x0f
// Second byte. Mask/Payload len(7bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
mask := (b & 0x80) != 0
b &= 0x7f
lengthFields := 0
switch {
case b <= 125: // Payload length 7bits.
hybiFrame.header.Length = int64(b)
case b == 126: // Payload length 7+16bits
lengthFields = 2
case b == 127: // Payload length 7+64bits
lengthFields = 8
}
for i := 0; i < lengthFields; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
}
if mask {
// Masking key. 4 bytes.
for i := 0; i < 4; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
}
}
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
hybiFrame.header.data = bytes.NewBuffer(header)
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
return
}
// A HybiFrameWriter is a writer for hybi frame.
type hybiFrameWriter struct {
writer *bufio.Writer
header *hybiFrameHeader
}
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
var header []byte
var b byte
if frame.header.Fin {
b |= 0x80
}
for i := 0; i < 3; i++ {
if frame.header.Rsv[i] {
j := uint(6 - i)
b |= 1 << j
}
}
b |= frame.header.OpCode
header = append(header, b)
if frame.header.MaskingKey != nil {
b = 0x80
} else {
b = 0
}
lengthFields := 0
length := len(msg)
switch {
case length <= 125:
b |= byte(length)
case length < 65536:
b |= 126
lengthFields = 2
default:
b |= 127
lengthFields = 8
}
header = append(header, b)
for i := 0; i < lengthFields; i++ {
j := uint((lengthFields - i - 1) * 8)
b = byte((length >> j) & 0xff)
header = append(header, b)
}
if frame.header.MaskingKey != nil {
if len(frame.header.MaskingKey) != 4 {
return 0, ErrBadMaskingKey
}
header = append(header, frame.header.MaskingKey...)
frame.writer.Write(header)
data := make([]byte, length)
for i := range data {
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
}
frame.writer.Write(data)
err = frame.writer.Flush()
return length, err
}
frame.writer.Write(header)
frame.writer.Write(msg)
err = frame.writer.Flush()
return length, err
}
func (frame *hybiFrameWriter) Close() error { return nil }
type hybiFrameWriterFactory struct {
*bufio.Writer
needMaskingKey bool
}
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
if buf.needMaskingKey {
frameHeader.MaskingKey, err = generateMaskingKey()
if err != nil {
return nil, err
}
}
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
}
type hybiFrameHandler struct {
conn *Conn
payloadType byte
}
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) {
if handler.conn.IsServerConn() {
// The client MUST mask all frames sent to the server.
if frame.(*hybiFrameReader).header.MaskingKey == nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
} else {
// The server MUST NOT mask all frames.
if frame.(*hybiFrameReader).header.MaskingKey != nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
}
if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
case TextFrame, BinaryFrame:
handler.payloadType = frame.PayloadType()
case CloseFrame:
return nil, io.EOF
case PingFrame:
pingMsg := make([]byte, maxControlFramePayloadLength)
n, err := io.ReadFull(frame, pingMsg)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
n, err = handler.WritePong(pingMsg[:n])
if err != nil {
return nil, err
}
return nil, nil
case PongFrame:
return nil, ErrNotImplemented
}
return frame, nil
}
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
if err != nil {
return err
}
msg := make([]byte, 2)
binary.BigEndian.PutUint16(msg, uint16(status))
_, err = w.Write(msg)
w.Close()
return err
}
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
if buf == nil {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
frameWriterFactory: hybiFrameWriterFactory{
buf.Writer, request == nil},
PayloadType: TextFrame,
defaultCloseStatus: closeStatusNormal}
ws.frameHandler = &hybiFrameHandler{conn: ws}
return ws
}
// generateMaskingKey generates a masking key for a frame.
func generateMaskingKey() (maskingKey []byte, err error) {
maskingKey = make([]byte, 4)
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
return
}
return
}
// generateNonce generates a nonce consisting of a randomly selected 16-byte
// value that has been base64-encoded.
func generateNonce() (nonce []byte) {
key := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
}
nonce = make([]byte, 24)
base64.StdEncoding.Encode(nonce, key)
return
}
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
func getNonceAccept(nonce []byte) (expected []byte, err error) {
h := sha1.New()
if _, err = h.Write(nonce); err != nil {
return
}
if _, err = h.Write([]byte(websocketGUID)); err != nil {
return
}
expected = make([]byte, 28)
base64.StdEncoding.Encode(expected, h.Sum(nil))
return
}
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
bw.WriteString("Host: " + config.Location.Host + "\r\n")
bw.WriteString("Upgrade: websocket\r\n")
bw.WriteString("Connection: Upgrade\r\n")
nonce := generateNonce()
if config.handshakeData != nil {
nonce = []byte(config.handshakeData["key"])
}
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
if config.Version != ProtocolVersionHybi13 {
return ErrBadProtocolVersion
}
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
err = config.Header.WriteSubset(bw, handshakeHeader)
if err != nil {
return err
}
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
return err
}
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
if err != nil {
return err
}
if resp.StatusCode != 101 {
return ErrBadStatus
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade
}
expectedAccept, err := getNonceAccept(nonce)
if err != nil {
return err
}
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
return ErrChallengeResponse
}
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
return ErrUnsupportedExtensions
}
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
if offeredProtocol != "" {
protocolMatched := false
for i := 0; i < len(config.Protocol); i++ {
if config.Protocol[i] == offeredProtocol {
protocolMatched = true
break
}
}
if !protocolMatched {
return ErrBadWebSocketProtocol
}
config.Protocol = []string{offeredProtocol}
}
return nil
}
// newHybiClientConn creates a client WebSocket connection after handshake.
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
return newHybiConn(config, buf, rwc, nil)
}
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
type hybiServerHandshaker struct {
*Config
accept []byte
}
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
c.Version = ProtocolVersionHybi13
if req.Method != "GET" {
return http.StatusMethodNotAllowed, ErrBadRequestMethod
}
// HTTP version can be safely ignored.
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
return http.StatusBadRequest, ErrNotWebSocket
}
key := req.Header.Get("Sec-Websocket-Key")
if key == "" {
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
switch version {
case "13":
c.Version = ProtocolVersionHybi13
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
var scheme string
if req.TLS != nil {
scheme = "wss"
} else {
scheme = "ws"
}
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
if err != nil {
return http.StatusBadRequest, err
}
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if protocol != "" {
protocols := strings.Split(protocol, ",")
for i := 0; i < len(protocols); i++ {
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
}
}
c.accept, err = getNonceAccept([]byte(key))
if err != nil {
return http.StatusInternalServerError, err
}
return http.StatusSwitchingProtocols, nil
}
// Origin parses Origin header in "req".
// If origin is "null", returns (nil, nil).
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
}
if origin == "null" {
return nil, nil
}
return url.ParseRequestURI(origin)
}
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
// You need choose a Protocol in Handshake func in Server.
return ErrBadWebSocketProtocol
}
}
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
if c.Header != nil {
err := c.Header.WriteSubset(buf, handshakeHeader)
if err != nil {
return err
}
}
buf.WriteString("\r\n")
return buf.Flush()
}
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiServerConn(c.Config, buf, rwc, request)
}
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiConn(config, buf, rwc, request)
}

View File

@ -0,0 +1,590 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
)
// Test the getNonceAccept function with values in
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
func TestSecWebSocketAccept(t *testing.T) {
nonce := []byte("dGhlIHNhbXBsZSBub25jZQ==")
expected := []byte("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")
accept, err := getNonceAccept(nonce)
if err != nil {
t.Errorf("getNonceAccept: returned error %v", err)
return
}
if !bytes.Equal(expected, accept) {
t.Errorf("getNonceAccept: expected %q got %q", expected, accept)
}
}
func TestHybiClientHandshake(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiClientHandshakeWithHeader(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.Header = http.Header(make(map[string][]string))
config.Header.Add("User-Agent", "test")
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
"User-Agent": "test",
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiServerHandshake(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
expectedProtocols := []string{"chat", "superchat"}
if fmt.Sprintf("%v", config.Protocol) != fmt.Sprintf("%v", expectedProtocols) {
t.Errorf("protocol expected %q but got %q", expectedProtocols, config.Protocol)
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
config.Protocol = config.Protocol[:1]
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"Sec-WebSocket-Protocol: chat",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}
func TestHybiServerHandshakeNoSubProtocol(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
if len(config.Protocol) != 0 {
t.Errorf("len(config.Protocol) expected 0, but got %q", len(config.Protocol))
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}
func TestHybiServerHandshakeHybiBadVersion(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 9
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != ErrBadWebSocketVersion {
t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err)
}
if code != http.StatusBadRequest {
t.Errorf("status expected %q but got %q", http.StatusBadRequest, code)
}
}
func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []byte, frameHeader *hybiFrameHeader) {
b := bytes.NewBuffer([]byte{})
frameWriterFactory := &hybiFrameWriterFactory{bufio.NewWriter(b), false}
w, _ := frameWriterFactory.NewFrameWriter(TextFrame)
w.(*hybiFrameWriter).header = frameHeader
_, err := w.Write(testPayload)
w.Close()
if err != nil {
t.Errorf("Write error %q", err)
}
var expectedFrame []byte
expectedFrame = append(expectedFrame, testHeader...)
expectedFrame = append(expectedFrame, testMaskedPayload...)
if !bytes.Equal(expectedFrame, b.Bytes()) {
t.Errorf("frame expected %q got %q", expectedFrame, b.Bytes())
}
frameReaderFactory := &hybiFrameReaderFactory{bufio.NewReader(b)}
r, err := frameReaderFactory.NewFrameReader()
if err != nil {
t.Errorf("Read error %q", err)
}
if header := r.HeaderReader(); header == nil {
t.Errorf("no header")
} else {
actualHeader := make([]byte, r.Len())
n, err := header.Read(actualHeader)
if err != nil {
t.Errorf("Read header error %q", err)
} else {
if n < len(testHeader) {
t.Errorf("header too short %q got %q", testHeader, actualHeader[:n])
}
if !bytes.Equal(testHeader, actualHeader[:n]) {
t.Errorf("header expected %q got %q", testHeader, actualHeader[:n])
}
}
}
if trailer := r.TrailerReader(); trailer != nil {
t.Errorf("unexpected trailer %q", trailer)
}
frame := r.(*hybiFrameReader)
if frameHeader.Fin != frame.header.Fin ||
frameHeader.OpCode != frame.header.OpCode ||
len(testPayload) != int(frame.header.Length) {
t.Errorf("mismatch %v (%d) vs %v", frameHeader, len(testPayload), frame)
}
payload := make([]byte, len(testPayload))
_, err = r.Read(payload)
if err != nil {
t.Errorf("read %v", err)
}
if !bytes.Equal(testPayload, payload) {
t.Errorf("payload %q vs %q", testPayload, payload)
}
}
func TestHybiShortTextFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x81, 0x05}, payload, payload, frameHeader)
payload = make([]byte, 125)
testHybiFrame(t, []byte{0x81, 125}, payload, payload, frameHeader)
}
func TestHybiShortMaskedTextFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame,
MaskingKey: []byte{0xcc, 0x55, 0x80, 0x20}}
payload := []byte("hello")
maskedPayload := []byte{0xa4, 0x30, 0xec, 0x4c, 0xa3}
header := []byte{0x81, 0x85}
header = append(header, frameHeader.MaskingKey...)
testHybiFrame(t, header, payload, maskedPayload, frameHeader)
}
func TestHybiShortBinaryFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: BinaryFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x82, 0x05}, payload, payload, frameHeader)
payload = make([]byte, 125)
testHybiFrame(t, []byte{0x82, 125}, payload, payload, frameHeader)
}
func TestHybiControlFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader)
frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader)
frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame}
payload = []byte{0x03, 0xe8} // 1000
testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader)
}
func TestHybiLongFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame}
payload := make([]byte, 126)
testHybiFrame(t, []byte{0x81, 126, 0x00, 126}, payload, payload, frameHeader)
payload = make([]byte, 65535)
testHybiFrame(t, []byte{0x81, 126, 0xff, 0xff}, payload, payload, frameHeader)
payload = make([]byte, 65536)
testHybiFrame(t, []byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, payload, payload, frameHeader)
}
func TestHybiClientRead(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o',
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
msg := make([]byte, 512)
n, err := conn.Read(msg)
if err != nil {
t.Errorf("read 1st frame, error %q", err)
}
if n != 5 {
t.Errorf("read 1st frame, expect 5, got %d", n)
}
if !bytes.Equal(wireData[2:7], msg[:n]) {
t.Errorf("read 1st frame %v, got %v", wireData[2:7], msg[:n])
}
n, err = conn.Read(msg)
if err != nil {
t.Errorf("read 2nd frame, error %q", err)
}
if n != 5 {
t.Errorf("read 2nd frame, expect 5, got %d", n)
}
if !bytes.Equal(wireData[16:21], msg[:n]) {
t.Errorf("read 2nd frame %v, got %v", wireData[16:21], msg[:n])
}
n, err = conn.Read(msg)
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
}
func TestHybiShortRead(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o',
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
step := 0
pos := 0
expectedPos := []int{2, 5, 16, 19}
expectedLen := []int{3, 2, 3, 2}
for {
msg := make([]byte, 3)
n, err := conn.Read(msg)
if step >= len(expectedPos) {
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
return
}
pos = expectedPos[step]
endPos := pos + expectedLen[step]
if err != nil {
t.Errorf("read from %d, got error %q", pos, err)
return
}
if n != endPos-pos {
t.Errorf("read from %d, expect %d, got %d", pos, endPos-pos, n)
}
if !bytes.Equal(wireData[pos:endPos], msg[:n]) {
t.Errorf("read from %d, frame %v, got %v", pos, wireData[pos:endPos], msg[:n])
}
step++
}
}
func TestHybiServerRead(t *testing.T) {
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
0x89, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // ping: hello
0x81, 0x85, 0xed, 0x83, 0xb4, 0x24,
0x9a, 0xec, 0xc6, 0x48, 0x89, // world
}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request))
expected := [][]byte{[]byte("hello"), []byte("world")}
msg := make([]byte, 512)
n, err := conn.Read(msg)
if err != nil {
t.Errorf("read 1st frame, error %q", err)
}
if n != 5 {
t.Errorf("read 1st frame, expect 5, got %d", n)
}
if !bytes.Equal(expected[0], msg[:n]) {
t.Errorf("read 1st frame %q, got %q", expected[0], msg[:n])
}
n, err = conn.Read(msg)
if err != nil {
t.Errorf("read 2nd frame, error %q", err)
}
if n != 5 {
t.Errorf("read 2nd frame, expect 5, got %d", n)
}
if !bytes.Equal(expected[1], msg[:n]) {
t.Errorf("read 2nd frame %q, got %q", expected[1], msg[:n])
}
n, err = conn.Read(msg)
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
}
func TestHybiServerReadWithoutMasking(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request))
// server MUST close the connection upon receiving a non-masked frame.
msg := make([]byte, 512)
_, err := conn.Read(msg)
if err != io.EOF {
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err)
}
}
func TestHybiClientReadWithMasking(t *testing.T) {
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
// client MUST close the connection upon receiving a masked frame.
msg := make([]byte, 512)
_, err := conn.Read(msg)
if err != io.EOF {
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err)
}
}
// Test the hybiServerHandshaker supports firefox implementation and
// checks Connection request header include (but it's not necessary
// equal to) "upgrade"
func TestHybiServerFirefoxHandshake(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: keep-alive, upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
config.Protocol = []string{"chat"}
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"Sec-WebSocket-Protocol: chat",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}

View File

@ -0,0 +1,114 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"fmt"
"io"
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if err != nil {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
conn = hs.NewServerConn(buf, rwc, req)
return
}
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
return
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser client, which doesn't send Origin header, you could use Server
//. that doesn't check origin in its Handshake.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}

View File

@ -0,0 +1,411 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements a client and server for the WebSocket protocol
// as specified in RFC 6455.
package websocket
import (
"bufio"
"crypto/tls"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
const (
ProtocolVersionHybi13 = 13
ProtocolVersionHybi = ProtocolVersionHybi13
SupportedProtocolVersion = "13"
ContinuationFrame = 0
TextFrame = 1
BinaryFrame = 2
CloseFrame = 8
PingFrame = 9
PongFrame = 10
UnknownFrame = 255
)
// ProtocolError represents WebSocket protocol errors.
type ProtocolError struct {
ErrorString string
}
func (err *ProtocolError) Error() string { return err.ErrorString }
var (
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
ErrBadScheme = &ProtocolError{"bad scheme"}
ErrBadStatus = &ProtocolError{"bad status"}
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
ErrBadFrame = &ProtocolError{"bad frame"}
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
ErrBadRequestMethod = &ProtocolError{"bad method"}
ErrNotSupported = &ProtocolError{"not supported"}
)
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
}
// Network returns the network type for a WebSocket, "websocket".
func (addr *Addr) Network() string { return "websocket" }
// Config is a WebSocket configuration
type Config struct {
// A WebSocket server address.
Location *url.URL
// A Websocket client origin.
Origin *url.URL
// WebSocket subprotocols.
Protocol []string
// WebSocket protocol version.
Version int
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
handshakeData map[string]string
}
// serverHandshaker is an interface to handle WebSocket server side handshake.
type serverHandshaker interface {
// ReadHandshake reads handshake request message from client.
// Returns http response code and error if any.
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
// AcceptHandshake accepts the client handshake request and sends
// handshake response back to client.
AcceptHandshake(buf *bufio.Writer) (err error)
// NewServerConn creates a new WebSocket connection.
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
// frameReader is an interface to read a WebSocket frame.
type frameReader interface {
// Reader is to read payload of the frame.
io.Reader
// PayloadType returns payload type.
PayloadType() byte
// HeaderReader returns a reader to read header of the frame.
HeaderReader() io.Reader
// TrailerReader returns a reader to read trailer of the frame.
// If it returns nil, there is no trailer in the frame.
TrailerReader() io.Reader
// Len returns total length of the frame, including header and trailer.
Len() int
}
// frameReaderFactory is an interface to creates new frame reader.
type frameReaderFactory interface {
NewFrameReader() (r frameReader, err error)
}
// frameWriter is an interface to write a WebSocket frame.
type frameWriter interface {
// Writer is to write payload of the frame.
io.WriteCloser
}
// frameWriterFactory is an interface to create new frame writer.
type frameWriterFactory interface {
NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
type frameHandler interface {
HandleFrame(frame frameReader) (r frameReader, err error)
WriteClose(status int) (err error)
}
// Conn represents a WebSocket connection.
type Conn struct {
config *Config
request *http.Request
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
rio sync.Mutex
frameReaderFactory
frameReader
wio sync.Mutex
frameWriterFactory
frameHandler
PayloadType byte
defaultCloseStatus int
}
// Read implements the io.Reader interface:
// it reads data of a frame from the WebSocket connection.
// if msg is not large enough for the frame data, it fills the msg and next Read
// will read the rest of the frame data.
// it reads Text frame or Binary frame.
func (ws *Conn) Read(msg []byte) (n int, err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
again:
if ws.frameReader == nil {
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return 0, err
}
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return 0, err
}
if ws.frameReader == nil {
goto again
}
}
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer)
}
ws.frameReader = nil
goto again
}
return n, err
}
// Write implements the io.Writer interface:
// it writes data as a frame to the WebSocket connection.
func (ws *Conn) Write(msg []byte) (n int, err error) {
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
if err != nil {
return n, err
}
return n, err
}
// Close implements the io.Closer interface.
func (ws *Conn) Close() error {
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
if err != nil {
return err
}
return ws.rwc.Close()
}
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
// LocalAddr returns the WebSocket Origin for the connection for client, or
// the WebSocket location for server.
func (ws *Conn) LocalAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Origin}
}
return &Addr{ws.config.Location}
}
// RemoteAddr returns the WebSocket location for the connection for client, or
// the Websocket Origin for server.
func (ws *Conn) RemoteAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Location}
}
return &Addr{ws.config.Origin}
}
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
// SetDeadline sets the connection's network read & write deadlines.
func (ws *Conn) SetDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetDeadline(t)
}
return errSetDeadline
}
// SetReadDeadline sets the connection's network read deadline.
func (ws *Conn) SetReadDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetReadDeadline(t)
}
return errSetDeadline
}
// SetWriteDeadline sets the connection's network write deadline.
func (ws *Conn) SetWriteDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetWriteDeadline(t)
}
return errSetDeadline
}
// Config returns the WebSocket config.
func (ws *Conn) Config() *Config { return ws.config }
// Request returns the http request upgraded to the WebSocket.
// It is nil for client side.
func (ws *Conn) Request() *http.Request { return ws.request }
// Codec represents a symmetric pair of functions that implement a codec.
type Codec struct {
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
// Send sends v marshaled by cd.Marshal as single frame to ws.
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
data, payloadType, err := cd.Marshal(v)
if err != nil {
return err
}
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
if err != nil {
return err
}
_, err = w.Write(data)
w.Close()
return err
}
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader)
if err != nil {
return err
}
ws.frameReader = nil
}
again:
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return err
}
frame, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return err
}
if frame == nil {
goto again
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
if err != nil {
return err
}
return cd.Unmarshal(data, payloadType, v)
}
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
switch data := v.(type) {
case string:
return []byte(data), TextFrame, nil
case []byte:
return data, BinaryFrame, nil
}
return nil, UnknownFrame, ErrNotSupported
}
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
switch data := v.(type) {
case *string:
*data = string(msg)
return nil
case *[]byte:
*data = msg
return nil
}
return ErrNotSupported
}
/*
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
To send/receive text frame, use string type.
To send/receive binary frame, use []byte type.
Trivial usage:
import "websocket"
// receive text frame
var message string
websocket.Message.Receive(ws, &message)
// send text frame
message = "hello"
websocket.Message.Send(ws, message)
// receive binary frame
var data []byte
websocket.Message.Receive(ws, &data)
// send binary frame
data = []byte{0, 1, 2}
websocket.Message.Send(ws, data)
*/
var Message = Codec{marshal, unmarshal}
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
msg, err = json.Marshal(v)
return msg, TextFrame, err
}
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
return json.Unmarshal(msg, v)
}
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Trivial usage:
import "websocket"
type T struct {
Msg string
Count int
}
// receive JSON type T
var data T
websocket.JSON.Receive(ws, &data)
// send JSON type T
websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}

View File

@ -0,0 +1,414 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
)
var serverAddr string
var once sync.Once
func echoServer(ws *Conn) { io.Copy(ws, ws) }
type Count struct {
S string
N int
}
func countServer(ws *Conn) {
for {
var count Count
err := JSON.Receive(ws, &count)
if err != nil {
return
}
count.N++
count.S = strings.Repeat(count.S, count.N)
err = JSON.Send(ws, count)
if err != nil {
return
}
}
}
func subProtocolHandshake(config *Config, req *http.Request) error {
for _, proto := range config.Protocol {
if proto == "chat" {
config.Protocol = []string{proto}
return nil
}
}
return ErrBadWebSocketProtocol
}
func subProtoServer(ws *Conn) {
for _, proto := range ws.Config().Protocol {
io.WriteString(ws, proto)
}
}
func startServer() {
http.Handle("/echo", Handler(echoServer))
http.Handle("/count", Handler(countServer))
subproto := Server{
Handshake: subProtocolHandshake,
Handler: Handler(subProtoServer),
}
http.Handle("/subproto", subproto)
server := httptest.NewServer(nil)
serverAddr = server.Listener.Addr().String()
log.Print("Test WebSocket server listening on ", serverAddr)
}
func newConfig(t *testing.T, path string) *Config {
config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
return config
}
func TestEcho(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := conn.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var actual_msg = make([]byte, 512)
n, err := conn.Read(actual_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
actual_msg = actual_msg[0:n]
if !bytes.Equal(msg, actual_msg) {
t.Errorf("Echo: expected %q got %q", msg, actual_msg)
}
conn.Close()
}
func TestAddr(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
ra := conn.RemoteAddr().String()
if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
t.Errorf("Bad remote addr: %v", ra)
}
la := conn.LocalAddr().String()
if !strings.HasPrefix(la, "http://") {
t.Errorf("Bad local addr: %v", la)
}
conn.Close()
}
func TestCount(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/count"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
var count Count
count.S = "hello"
if err := JSON.Send(conn, count); err != nil {
t.Errorf("Write: %v", err)
}
if err := JSON.Receive(conn, &count); err != nil {
t.Errorf("Read: %v", err)
}
if count.N != 1 {
t.Errorf("count: expected %d got %d", 1, count.N)
}
if count.S != "hello" {
t.Errorf("count: expected %q got %q", "hello", count.S)
}
if err := JSON.Send(conn, count); err != nil {
t.Errorf("Write: %v", err)
}
if err := JSON.Receive(conn, &count); err != nil {
t.Errorf("Read: %v", err)
}
if count.N != 2 {
t.Errorf("count: expected %d got %d", 2, count.N)
}
if count.S != "hellohello" {
t.Errorf("count: expected %q got %q", "hellohello", count.S)
}
conn.Close()
}
func TestWithQuery(t *testing.T) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
config := newConfig(t, "/echo")
config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
if err != nil {
t.Fatal("location url", err)
}
ws, err := NewClient(config, client)
if err != nil {
t.Errorf("WebSocket handshake: %v", err)
return
}
ws.Close()
}
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
config := newConfig(t, "/subproto")
config.Protocol = subproto
ws, err := NewClient(config, client)
if err != nil {
return "", err
}
msg := make([]byte, 16)
n, err := ws.Read(msg)
if err != nil {
return "", err
}
ws.Close()
return string(msg[:n]), nil
}
func TestWithProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithTwoProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"test", "chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithBadProtocol(t *testing.T) {
_, err := testWithProtocol(t, []string{"test"})
if err != ErrBadStatus {
t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
}
}
func TestHTTP(t *testing.T) {
once.Do(startServer)
// If the client did not send a handshake that matches the protocol
// specification, the server MUST return an HTTP response with an
// appropriate error code (such as 400 Bad Request)
resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
if err != nil {
t.Errorf("Get: error %#v", err)
return
}
if resp == nil {
t.Error("Get: resp is null")
return
}
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
}
}
func TestTrailingSpaces(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=955
// The last runs of this create keys with trailing spaces that should not be
// generated by the client.
once.Do(startServer)
config := newConfig(t, "/echo")
for i := 0; i < 30; i++ {
// body
ws, err := DialConfig(config)
if err != nil {
t.Errorf("Dial #%d failed: %v", i, err)
break
}
ws.Close()
}
}
func TestDialConfigBadVersion(t *testing.T) {
once.Do(startServer)
config := newConfig(t, "/echo")
config.Version = 1234
_, err := DialConfig(config)
if dialerr, ok := err.(*DialError); ok {
if dialerr.Err != ErrBadProtocolVersion {
t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
}
}
}
func TestSmallBuffer(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=1145
// Read should be able to handle reading a fragment of a frame.
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := conn.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var small_msg = make([]byte, 8)
n, err := conn.Read(small_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
}
var second_msg = make([]byte, len(msg))
n, err = conn.Read(second_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
second_msg = second_msg[0:n]
if !bytes.Equal(msg[len(small_msg):], second_msg) {
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
}
conn.Close()
}
var parseAuthorityTests = []struct {
in *url.URL
out string
}{
{
&url.URL{
Scheme: "ws",
Host: "www.google.com",
},
"www.google.com:80",
},
{
&url.URL{
Scheme: "wss",
Host: "www.google.com",
},
"www.google.com:443",
},
{
&url.URL{
Scheme: "ws",
Host: "www.google.com:80",
},
"www.google.com:80",
},
{
&url.URL{
Scheme: "wss",
Host: "www.google.com:443",
},
"www.google.com:443",
},
// some invalid ones for parseAuthority. parseAuthority doesn't
// concern itself with the scheme unless it actually knows about it
{
&url.URL{
Scheme: "http",
Host: "www.google.com",
},
"www.google.com",
},
{
&url.URL{
Scheme: "http",
Host: "www.google.com:80",
},
"www.google.com:80",
},
{
&url.URL{
Scheme: "asdf",
Host: "127.0.0.1",
},
"127.0.0.1",
},
{
&url.URL{
Scheme: "asdf",
Host: "www.google.com",
},
"www.google.com",
},
}
func TestParseAuthority(t *testing.T) {
for _, tt := range parseAuthorityTests {
out := parseAuthority(tt.in)
if out != tt.out {
t.Errorf("got %v; want %v", out, tt.out)
}
}
}