| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | // Copyright 2015 The go-ethereum Authors | 
					
						
							|  |  |  | // This file is part of the go-ethereum library. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The go-ethereum library is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Lesser General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The go-ethereum library is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | 
					
						
							|  |  |  | // GNU Lesser General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Lesser General Public License | 
					
						
							|  |  |  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | package rpc | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import ( | 
					
						
							| 
									
										
										
										
											2017-03-22 18:20:33 +01:00
										 |  |  | 	"context" | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	"encoding/base64" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	"fmt" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	"net/url" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	"os" | 
					
						
							| 
									
										
										
										
											2016-02-05 15:08:48 +02:00
										 |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	"sync" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	mapset "github.com/deckarep/golang-set" | 
					
						
							| 
									
										
										
										
											2017-02-22 14:10:07 +02:00
										 |  |  | 	"github.com/ethereum/go-ethereum/log" | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	"github.com/gorilla/websocket" | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | const ( | 
					
						
							|  |  |  | 	wsReadBuffer  = 1024 | 
					
						
							|  |  |  | 	wsWriteBuffer = 1024 | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | var wsBufferPool = new(sync.Pool) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // NewWSServer creates a new websocket RPC server around an API provider. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // Deprecated: use Server.WebsocketHandler | 
					
						
							|  |  |  | func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { | 
					
						
							|  |  |  | 	return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} | 
					
						
							| 
									
										
										
										
											2018-03-13 13:23:44 +02:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // allowedOrigins should be a comma-separated list of allowed origin URLs. | 
					
						
							|  |  |  | // To allow connections with any origin, pass "*". | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	var upgrader = websocket.Upgrader{ | 
					
						
							|  |  |  | 		ReadBufferSize:  wsReadBuffer, | 
					
						
							|  |  |  | 		WriteBufferSize: wsWriteBuffer, | 
					
						
							|  |  |  | 		WriteBufferPool: wsBufferPool, | 
					
						
							|  |  |  | 		CheckOrigin:     wsHandshakeValidator(allowedOrigins), | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 
					
						
							|  |  |  | 		conn, err := upgrader.Upgrade(w, r, nil) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			log.Debug("WebSocket upgrade failed", "err", err) | 
					
						
							|  |  |  | 			return | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 		codec := newWebsocketCodec(conn) | 
					
						
							|  |  |  | 		s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) | 
					
						
							|  |  |  | 	}) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // wsHandshakeValidator returns a handler that verifies the origin during the | 
					
						
							|  |  |  | // websocket upgrade process. When a '*' is specified as an allowed origins all | 
					
						
							|  |  |  | // connections are accepted. | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	origins := mapset.NewSet() | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	allowAllOrigins := false | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	for _, origin := range allowedOrigins { | 
					
						
							|  |  |  | 		if origin == "*" { | 
					
						
							|  |  |  | 			allowAllOrigins = true | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		if origin != "" { | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 			origins.Add(strings.ToLower(origin)) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 	// allow localhost if no allowedOrigins are specified. | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 	if len(origins.ToSlice()) == 0 { | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		origins.Add("http://localhost") | 
					
						
							|  |  |  | 		if hostname, err := os.Hostname(); err == nil { | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 			origins.Add("http://" + strings.ToLower(hostname)) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	f := func(req *http.Request) bool { | 
					
						
							| 
									
										
										
										
											2019-02-19 11:49:43 +01:00
										 |  |  | 		// Skip origin verification if no Origin header is present. The origin check | 
					
						
							|  |  |  | 		// is supposed to protect against browser based attacks. Browsers always set | 
					
						
							|  |  |  | 		// Origin. Non-browser software can put anything in origin and checking it doesn't | 
					
						
							|  |  |  | 		// provide additional security. | 
					
						
							|  |  |  | 		if _, ok := req.Header["Origin"]; !ok { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 			return true | 
					
						
							| 
									
										
										
										
											2019-02-19 11:49:43 +01:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 		// Verify origin against whitelist. | 
					
						
							| 
									
										
										
										
											2016-05-10 18:01:58 +02:00
										 |  |  | 		origin := strings.ToLower(req.Header.Get("Origin")) | 
					
						
							| 
									
										
										
										
											2018-07-16 00:54:19 -07:00
										 |  |  | 		if allowAllOrigins || origins.Contains(origin) { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 			return true | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 		} | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 		log.Warn("Rejected WebSocket connection", "origin", origin) | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 		return false | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	return f | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | type wsHandshakeError struct { | 
					
						
							|  |  |  | 	err    error | 
					
						
							|  |  |  | 	status string | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | func (e wsHandshakeError) Error() string { | 
					
						
							|  |  |  | 	s := e.err.Error() | 
					
						
							|  |  |  | 	if e.status != "" { | 
					
						
							|  |  |  | 		s += " (HTTP status " + e.status + ")" | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	return s | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server | 
					
						
							|  |  |  | // that is listening on the given endpoint. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The context is used for the initial connection establishment. It does not | 
					
						
							|  |  |  | // affect subsequent interactions with the client. | 
					
						
							|  |  |  | func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	endpoint, header, err := wsClientHeaders(endpoint, origin) | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		return nil, err | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	dialer := websocket.Dialer{ | 
					
						
							|  |  |  | 		ReadBufferSize:  wsReadBuffer, | 
					
						
							|  |  |  | 		WriteBufferSize: wsWriteBuffer, | 
					
						
							|  |  |  | 		WriteBufferPool: wsBufferPool, | 
					
						
							|  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 		conn, resp, err := dialer.DialContext(ctx, endpoint, header) | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 			hErr := wsHandshakeError{err: err} | 
					
						
							|  |  |  | 			if resp != nil { | 
					
						
							|  |  |  | 				hErr.status = resp.Status | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			return nil, hErr | 
					
						
							| 
									
										
										
										
											2019-02-04 13:47:34 +01:00
										 |  |  | 		} | 
					
						
							|  |  |  | 		return newWebsocketCodec(conn), nil | 
					
						
							| 
									
										
										
										
											2016-07-12 17:47:15 +02:00
										 |  |  | 	}) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { | 
					
						
							|  |  |  | 	endpointURL, err := url.Parse(endpoint) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 		return endpoint, nil, err | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	header := make(http.Header) | 
					
						
							|  |  |  | 	if origin != "" { | 
					
						
							|  |  |  | 		header.Add("origin", origin) | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	if endpointURL.User != nil { | 
					
						
							|  |  |  | 		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) | 
					
						
							|  |  |  | 		header.Add("authorization", "Basic "+b64auth) | 
					
						
							|  |  |  | 		endpointURL.User = nil | 
					
						
							| 
									
										
										
										
											2015-12-16 10:58:01 +01:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	return endpointURL.String(), header, nil | 
					
						
							| 
									
										
										
										
											2017-03-22 18:20:33 +01:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | func newWebsocketCodec(conn *websocket.Conn) ServerCodec { | 
					
						
							|  |  |  | 	conn.SetReadLimit(maxRequestContentLength) | 
					
						
							|  |  |  | 	return newCodec(conn, conn.WriteJSON, conn.ReadJSON) | 
					
						
							| 
									
										
										
										
											2017-03-22 18:20:33 +01:00
										 |  |  | } |