| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | // Copyright 2015 The go-ethereum Authors | 
					
						
							|  |  |  | // This file is part of the go-ethereum library. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The go-ethereum library is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Lesser General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The go-ethereum library is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | 
					
						
							|  |  |  | // GNU Lesser General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Lesser General Public License | 
					
						
							|  |  |  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | package rpc | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2018-03-13 13:23:44 +02:00
										 |  |  | 	"bytes" | 
					
						
							| 
									
										
										
										
											2017-03-22 18:20:33 +01:00
										 |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	"crypto/tls" | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	"encoding/base64" | 
					
						
							| 
									
										
										
										
											2018-03-13 13:23:44 +02:00
										 |  |  | 	"encoding/json" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	"fmt" | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	"net" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	"net/http" | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	"net/url" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	"os" | 
					
						
							| 
									
										
										
										
											2016-02-05 15:08:48 +02:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2017-03-22 18:20:33 +01:00
										 |  |  | 	"time" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	mapset "github.com/deckarep/golang-set" | 
					
						
							| 
									
										
										
										
											2017-02-22 14:10:07 +02:00
										 |  |  | 	"github.com/ethereum/go-ethereum/log" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	"golang.org/x/net/websocket" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-03-13 13:23:44 +02:00
										 |  |  | // websocketJSONCodec is a custom JSON codec with payload size enforcement and | 
					
						
							|  |  |  | // special number parsing. | 
					
						
							|  |  |  | var websocketJSONCodec = websocket.Codec{ | 
					
						
							|  |  |  | 	// Marshal is the stock JSON marshaller used by the websocket library too. | 
					
						
							|  |  |  | 	Marshal: func(v interface{}) ([]byte, byte, error) { | 
					
						
							|  |  |  | 		msg, err := json.Marshal(v) | 
					
						
							|  |  |  | 		return msg, websocket.TextFrame, err | 
					
						
							|  |  |  | 	}, | 
					
						
							|  |  |  | 	// Unmarshal is a specialized unmarshaller to properly convert numbers. | 
					
						
							|  |  |  | 	Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { | 
					
						
							|  |  |  | 		dec := json.NewDecoder(bytes.NewReader(msg)) | 
					
						
							|  |  |  | 		dec.UseNumber() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		return dec.Decode(v) | 
					
						
							|  |  |  | 	}, | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // allowedOrigins should be a comma-separated list of allowed origin URLs. | 
					
						
							|  |  |  | // To allow connections with any origin, pass "*". | 
					
						
							| 
									
										
										
										
											2017-04-12 23:04:14 +02:00
										 |  |  | func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	return websocket.Server{ | 
					
						
							| 
									
										
										
										
											2017-04-12 23:04:14 +02:00
										 |  |  | 		Handshake: wsHandshakeValidator(allowedOrigins), | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 		Handler: func(conn *websocket.Conn) { | 
					
						
							| 
									
										
										
										
											2018-03-13 13:23:44 +02:00
										 |  |  | 			// Create a custom encode/decode pair to enforce payload size and number encoding | 
					
						
							|  |  |  | 			conn.MaxPayloadBytes = maxRequestContentLength | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 			encoder := func(v interface{}) error { | 
					
						
							|  |  |  | 				return websocketJSONCodec.Send(conn, v) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			decoder := func(v interface{}) error { | 
					
						
							|  |  |  | 				return websocketJSONCodec.Receive(conn, v) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 		}, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | // NewWSServer creates a new websocket RPC server around an API provider. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // Deprecated: use Server.WebsocketHandler | 
					
						
							| 
									
										
										
										
											2017-04-12 23:04:14 +02:00
										 |  |  | func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // wsHandshakeValidator returns a handler that verifies the origin during the | 
					
						
							|  |  |  | // websocket upgrade process. When a '*' is specified as an allowed origins all | 
					
						
							|  |  |  | // connections are accepted. | 
					
						
							|  |  |  | func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	origins := mapset.NewSet() | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	allowAllOrigins := false | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, origin := range allowedOrigins { | 
					
						
							|  |  |  | 		if origin == "*" { | 
					
						
							|  |  |  | 			allowAllOrigins = true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if origin != "" { | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 			origins.Add(strings.ToLower(origin)) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 	// allow localhost if no allowedOrigins are specified. | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	if len(origins.ToSlice()) == 0 { | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		origins.Add("http://localhost") | 
					
						
							|  |  |  | 		if hostname, err := os.Hostname(); err == nil { | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 			origins.Add("http://" + strings.ToLower(hostname)) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	f := func(cfg *websocket.Config, req *http.Request) error { | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 		origin := strings.ToLower(req.Header.Get("Origin")) | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 		if allowAllOrigins || origins.Contains(origin) { | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 			return nil | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2017-11-10 01:22:06 -08:00
										 |  |  | 		log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		return fmt.Errorf("origin %s not allowed", origin) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return f | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	if origin == "" { | 
					
						
							|  |  |  | 		var err error | 
					
						
							|  |  |  | 		if origin, err = os.Hostname(); err != nil { | 
					
						
							|  |  |  | 			return nil, err | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if strings.HasPrefix(endpoint, "wss") { | 
					
						
							|  |  |  | 			origin = "https://" + strings.ToLower(origin) | 
					
						
							|  |  |  | 		} else { | 
					
						
							|  |  |  | 			origin = "http://" + strings.ToLower(origin) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	config, err := websocket.NewConfig(endpoint, origin) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	if config.Location.User != nil { | 
					
						
							|  |  |  | 		b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) | 
					
						
							|  |  |  | 		config.Header.Add("Authorization", "Basic "+b64auth) | 
					
						
							|  |  |  | 		config.Location.User = nil | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return config, nil | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server | 
					
						
							|  |  |  | // that is listening on the given endpoint. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The context is used for the initial connection establishment. It does not | 
					
						
							|  |  |  | // affect subsequent interactions with the client. | 
					
						
							|  |  |  | func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { | 
					
						
							|  |  |  | 	config, err := wsGetConfig(endpoint, origin) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	return newClient(ctx, func(ctx context.Context) (net.Conn, error) { | 
					
						
							|  |  |  | 		return wsDialContext(ctx, config) | 
					
						
							|  |  |  | 	}) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { | 
					
						
							|  |  |  | 	var conn net.Conn | 
					
						
							|  |  |  | 	var err error | 
					
						
							|  |  |  | 	switch config.Location.Scheme { | 
					
						
							|  |  |  | 	case "ws": | 
					
						
							|  |  |  | 		conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) | 
					
						
							|  |  |  | 	case "wss": | 
					
						
							|  |  |  | 		dialer := contextDialer(ctx) | 
					
						
							|  |  |  | 		conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) | 
					
						
							|  |  |  | 	default: | 
					
						
							|  |  |  | 		err = websocket.ErrBadScheme | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	ws, err := websocket.NewClient(config, conn) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		conn.Close() | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	return ws, err | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | var wsPortMap = map[string]string{"ws": "80", "wss": "443"} | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | func wsDialAddress(location *url.URL) string { | 
					
						
							|  |  |  | 	if _, ok := wsPortMap[location.Scheme]; ok { | 
					
						
							|  |  |  | 		if _, _, err := net.SplitHostPort(location.Host); err != nil { | 
					
						
							|  |  |  | 			return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	return location.Host | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2017-03-22 18:20:33 +01:00
										 |  |  | 
 | 
					
						
							|  |  |  | func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { | 
					
						
							|  |  |  | 	d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} | 
					
						
							|  |  |  | 	return d.DialContext(ctx, network, addr) | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func contextDialer(ctx context.Context) *net.Dialer { | 
					
						
							|  |  |  | 	dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} | 
					
						
							|  |  |  | 	if deadline, ok := ctx.Deadline(); ok { | 
					
						
							|  |  |  | 		dialer.Deadline = deadline | 
					
						
							|  |  |  | 	} else { | 
					
						
							|  |  |  | 		dialer.Deadline = time.Now().Add(defaultDialTimeout) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return dialer | 
					
						
							|  |  |  | } |