node: allow websocket and HTTP on the same port (#20810)

This change makes it possible to run geth with JSON-RPC over HTTP and
WebSocket on the same TCP port. The default port for WebSocket
is still 8546. 

    geth --rpc --rpcport 8545 --ws --wsport 8545

This also removes a lot of deprecated API surface from package rpc.
The rpc package is now purely about serving JSON-RPC and no longer
provides a way to start an HTTP server.
This commit is contained in:
rene
2020-04-08 13:33:12 +02:00
committed by GitHub
parent 5065cdefff
commit 07d909ff32
13 changed files with 443 additions and 268 deletions

View File

@ -186,7 +186,7 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
}
}
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts); err != nil {
if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil {
return false, err
}
return true, nil

99
node/endpoints.go Normal file
View File

@ -0,0 +1,99 @@
// Copyright 2018 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"net"
"net/http"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rpc"
)
// StartHTTPEndpoint starts the HTTP RPC endpoint.
func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http.Handler) (net.Listener, error) {
// start the HTTP listener
var (
listener net.Listener
err error
)
if listener, err = net.Listen("tcp", endpoint); err != nil {
return nil, err
}
// make sure timeout values are meaningful
CheckTimeouts(&timeouts)
// Bundle and start the HTTP server
httpSrv := &http.Server{
Handler: handler,
ReadTimeout: timeouts.ReadTimeout,
WriteTimeout: timeouts.WriteTimeout,
IdleTimeout: timeouts.IdleTimeout,
}
go httpSrv.Serve(listener)
return listener, err
}
// startWSEndpoint starts a websocket endpoint.
func startWSEndpoint(endpoint string, handler http.Handler) (net.Listener, error) {
// start the HTTP listener
var (
listener net.Listener
err error
)
if listener, err = net.Listen("tcp", endpoint); err != nil {
return nil, err
}
wsSrv := &http.Server{Handler: handler}
go wsSrv.Serve(listener)
return listener, err
}
// checkModuleAvailability checks that all names given in modules are actually
// available API services. It assumes that the MetadataApi module ("rpc") is always available;
// the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints.
func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available []string) {
availableSet := make(map[string]struct{})
for _, api := range apis {
if _, ok := availableSet[api.Namespace]; !ok {
availableSet[api.Namespace] = struct{}{}
available = append(available, api.Namespace)
}
}
for _, name := range modules {
if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi {
bad = append(bad, name)
}
}
return bad, available
}
// CheckTimeouts ensures that timeout values are meaningful
func CheckTimeouts(timeouts *rpc.HTTPTimeouts) {
if timeouts.ReadTimeout < time.Second {
log.Warn("Sanitizing invalid HTTP read timeout", "provided", timeouts.ReadTimeout, "updated", rpc.DefaultHTTPTimeouts.ReadTimeout)
timeouts.ReadTimeout = rpc.DefaultHTTPTimeouts.ReadTimeout
}
if timeouts.WriteTimeout < time.Second {
log.Warn("Sanitizing invalid HTTP write timeout", "provided", timeouts.WriteTimeout, "updated", rpc.DefaultHTTPTimeouts.WriteTimeout)
timeouts.WriteTimeout = rpc.DefaultHTTPTimeouts.WriteTimeout
}
if timeouts.IdleTimeout < time.Second {
log.Warn("Sanitizing invalid HTTP idle timeout", "provided", timeouts.IdleTimeout, "updated", rpc.DefaultHTTPTimeouts.IdleTimeout)
timeouts.IdleTimeout = rpc.DefaultHTTPTimeouts.IdleTimeout
}
}

View File

@ -291,17 +291,21 @@ func (n *Node) startRPC(services map[reflect.Type]Service) error {
n.stopInProc()
return err
}
if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts); err != nil {
if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil {
n.stopIPC()
n.stopInProc()
return err
}
if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
n.stopHTTP()
n.stopIPC()
n.stopInProc()
return err
// if endpoints are not the same, start separate servers
if n.httpEndpoint != n.wsEndpoint {
if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil {
n.stopHTTP()
n.stopIPC()
n.stopInProc()
return err
}
}
// All API endpoints started successfully
n.rpcAPIs = apis
return nil
@ -359,22 +363,36 @@ func (n *Node) stopIPC() {
}
// startHTTP initializes and starts the HTTP RPC endpoint.
func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error {
func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error {
// Short circuit if the HTTP endpoint isn't being exposed
if endpoint == "" {
return nil
}
listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts)
// register apis and create handler stack
srv := rpc.NewServer()
err := RegisterApisFromWhitelist(apis, modules, srv, false)
if err != nil {
return err
}
handler := NewHTTPHandlerStack(srv, cors, vhosts)
// wrap handler in websocket handler only if websocket port is the same as http rpc
if n.httpEndpoint == n.wsEndpoint {
handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins))
}
listener, err := StartHTTPEndpoint(endpoint, timeouts, handler)
if err != nil {
return err
}
n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", listener.Addr()),
"cors", strings.Join(cors, ","),
"vhosts", strings.Join(vhosts, ","))
if n.httpEndpoint == n.wsEndpoint {
n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", listener.Addr()))
}
// All listeners booted successfully
n.httpEndpoint = endpoint
n.httpListener = listener
n.httpHandler = handler
n.httpHandler = srv
return nil
}
@ -399,7 +417,14 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
if endpoint == "" {
return nil
}
listener, handler, err := rpc.StartWSEndpoint(endpoint, apis, modules, wsOrigins, exposeAll)
srv := rpc.NewServer()
handler := srv.WebsocketHandler(wsOrigins)
err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll)
if err != nil {
return err
}
listener, err := startWSEndpoint(endpoint, handler)
if err != nil {
return err
}
@ -407,7 +432,7 @@ func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrig
// All listeners booted successfully
n.wsEndpoint = endpoint
n.wsListener = listener
n.wsHandler = handler
n.wsHandler = srv
return nil
}
@ -664,3 +689,25 @@ func (n *Node) apis() []rpc.API {
},
}
}
// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules,
// and then registers all of the APIs exposed by the services.
func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error {
if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 {
log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available)
}
// Generate the whitelist based on the allowed modules
whitelist := make(map[string]bool)
for _, module := range modules {
whitelist[module] = true
}
// Register all the APIs exposed by the services
for _, api := range apis {
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
if err := srv.RegisterName(api.Namespace, api.Service); err != nil {
return err
}
}
}
return nil
}

View File

@ -19,6 +19,7 @@ package node
import (
"errors"
"io/ioutil"
"net/http"
"os"
"reflect"
"testing"
@ -27,6 +28,8 @@ import (
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rpc"
"github.com/stretchr/testify/assert"
)
var (
@ -597,3 +600,58 @@ func TestAPIGather(t *testing.T) {
}
}
}
func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) {
node := startHTTP(t)
defer node.stopHTTP()
wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
if err != nil {
t.Error("could not issue new http request ", err)
}
wsReq.Header.Set("Connection", "upgrade")
wsReq.Header.Set("Upgrade", "websocket")
wsReq.Header.Set("Sec-WebSocket-Version", "13")
wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
resp := doHTTPRequest(t, wsReq)
assert.Equal(t, "websocket", resp.Header.Get("Upgrade"))
}
func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) {
node := startHTTP(t)
defer node.stopHTTP()
httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil)
if err != nil {
t.Error("could not issue new http request ", err)
}
httpReq.Header.Set("Accept-Encoding", "gzip")
resp := doHTTPRequest(t, httpReq)
assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))
}
func startHTTP(t *testing.T) *Node {
conf := &Config{HTTPPort: 7453, WSPort: 7453}
node, err := New(conf)
if err != nil {
t.Error("could not create a new node ", err)
}
err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{})
if err != nil {
t.Error("could not start http service on node ", err)
}
return node
}
func doHTTPRequest(t *testing.T, req *http.Request) *http.Response {
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Error("could not issue a GET request to the given endpoint", err)
}
return resp
}

159
node/rpcstack.go Normal file
View File

@ -0,0 +1,159 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package node
import (
"compress/gzip"
"io"
"io/ioutil"
"net"
"net/http"
"strings"
"sync"
"github.com/ethereum/go-ethereum/log"
"github.com/rs/cors"
)
// NewHTTPHandlerStack returns wrapped http-related handlers
func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler {
// Wrap the CORS-handler within a host-handler
handler := newCorsHandler(srv, cors)
handler = newVHostHandler(vhosts, handler)
return newGzipHandler(handler)
}
func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
// disable CORS support if user has not specified a custom CORS configuration
if len(allowedOrigins) == 0 {
return srv
}
c := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet},
MaxAge: 600,
AllowedHeaders: []string{"*"},
})
return c.Handler(srv)
}
// virtualHostHandler is a handler which validates the Host-header of incoming requests.
// Using virtual hosts can help prevent DNS rebinding attacks, where a 'random' domain name points to
// the service ip address (but without CORS headers). By verifying the targeted virtual host, we can
// ensure that it's a destination that the node operator has defined.
type virtualHostHandler struct {
vhosts map[string]struct{}
next http.Handler
}
func newVHostHandler(vhosts []string, next http.Handler) http.Handler {
vhostMap := make(map[string]struct{})
for _, allowedHost := range vhosts {
vhostMap[strings.ToLower(allowedHost)] = struct{}{}
}
return &virtualHostHandler{vhostMap, next}
}
// ServeHTTP serves JSON-RPC requests over HTTP, implements http.Handler
func (h *virtualHostHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// if r.Host is not set, we can continue serving since a browser would set the Host header
if r.Host == "" {
h.next.ServeHTTP(w, r)
return
}
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
// Either invalid (too many colons) or no port specified
host = r.Host
}
if ipAddr := net.ParseIP(host); ipAddr != nil {
// It's an IP address, we can serve that
h.next.ServeHTTP(w, r)
return
}
// Not an IP address, but a hostname. Need to validate
if _, exist := h.vhosts["*"]; exist {
h.next.ServeHTTP(w, r)
return
}
if _, exist := h.vhosts[host]; exist {
h.next.ServeHTTP(w, r)
return
}
http.Error(w, "invalid host specified", http.StatusForbidden)
}
var gzPool = sync.Pool{
New: func() interface{} {
w := gzip.NewWriter(ioutil.Discard)
return w
},
}
type gzipResponseWriter struct {
io.Writer
http.ResponseWriter
}
func (w *gzipResponseWriter) WriteHeader(status int) {
w.Header().Del("Content-Length")
w.ResponseWriter.WriteHeader(status)
}
func (w *gzipResponseWriter) Write(b []byte) (int, error) {
return w.Writer.Write(b)
}
func newGzipHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
next.ServeHTTP(w, r)
return
}
w.Header().Set("Content-Encoding", "gzip")
gz := gzPool.Get().(*gzip.Writer)
defer gzPool.Put(gz)
gz.Reset(w)
defer gz.Close()
next.ServeHTTP(&gzipResponseWriter{ResponseWriter: w, Writer: gz}, r)
})
}
// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade
// request to the websocket protocol. If not, serves the the request with the http handler.
func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if isWebsocket(r) {
ws.ServeHTTP(w, r)
log.Debug("serving websocket request")
return
}
h.ServeHTTP(w, r)
})
}
// isWebsocket checks the header of an http request for a websocket upgrade request.
func isWebsocket(r *http.Request) bool {
return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" &&
strings.ToLower(r.Header.Get("Connection")) == "upgrade"
}

38
node/rpcstack_test.go Normal file
View File

@ -0,0 +1,38 @@
package node
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/ethereum/go-ethereum/rpc"
"github.com/stretchr/testify/assert"
)
func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) {
srv := rpc.NewServer()
handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{}))
ts := httptest.NewServer(handler)
defer ts.Close()
responses := make(chan *http.Response)
go func(responses chan *http.Response) {
client := &http.Client{}
req, _ := http.NewRequest(http.MethodGet, ts.URL, nil)
req.Header.Set("Connection", "upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==")
resp, err := client.Do(req)
if err != nil {
t.Error("could not issue a GET request to the test http server", err)
}
responses <- resp
}(responses)
response := <-responses
assert.Equal(t, "websocket", response.Header.Get("Upgrade"))
}