| 
									
										
										
										
											2019-07-22 12:17:27 +03:00
										 |  |  | // Copyright 2018 The go-ethereum Authors | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | // This file is part of the go-ethereum library. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The go-ethereum library is free software: you can redistribute it and/or modify | 
					
						
							|  |  |  | // it under the terms of the GNU Lesser General Public License as published by | 
					
						
							|  |  |  | // the Free Software Foundation, either version 3 of the License, or | 
					
						
							|  |  |  | // (at your option) any later version. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // The go-ethereum library is distributed in the hope that it will be useful, | 
					
						
							|  |  |  | // but WITHOUT ANY WARRANTY; without even the implied warranty of | 
					
						
							|  |  |  | // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | 
					
						
							|  |  |  | // GNU Lesser General Public License for more details. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // You should have received a copy of the GNU Lesser General Public License | 
					
						
							|  |  |  | // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | package rpc | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | import ( | 
					
						
							|  |  |  | 	"context" | 
					
						
							|  |  |  | 	"net" | 
					
						
							|  |  |  | 	"net/http" | 
					
						
							|  |  |  | 	"net/http/httptest" | 
					
						
							|  |  |  | 	"reflect" | 
					
						
							|  |  |  | 	"strings" | 
					
						
							|  |  |  | 	"testing" | 
					
						
							|  |  |  | 	"time" | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	"github.com/gorilla/websocket" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func TestWebsocketClientHeaders(t *testing.T) { | 
					
						
							|  |  |  | 	t.Parallel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	endpoint, header, err := wsClientHeaders("wss://testuser:test-PASS_01@example.com:1234", "https://example.com") | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2019-07-18 14:21:24 +02:00
										 |  |  | 		t.Fatalf("wsGetConfig failed: %s", err) | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	if endpoint != "wss://example.com:1234" { | 
					
						
							|  |  |  | 		t.Fatal("User should have been stripped from the URL") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if header.Get("authorization") != "Basic dGVzdHVzZXI6dGVzdC1QQVNTXzAx" { | 
					
						
							|  |  |  | 		t.Fatal("Basic auth header is incorrect") | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	if header.Get("origin") != "https://example.com" { | 
					
						
							|  |  |  | 		t.Fatal("Origin not set") | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | // This test checks that the server rejects connections from disallowed origins. | 
					
						
							|  |  |  | func TestWebsocketOriginCheck(t *testing.T) { | 
					
						
							|  |  |  | 	t.Parallel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		srv     = newTestServer() | 
					
						
							|  |  |  | 		httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"http://example.com"})) | 
					
						
							|  |  |  | 		wsURL   = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 	defer srv.Stop() | 
					
						
							|  |  |  | 	defer httpsrv.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	client, err := DialWebsocket(context.Background(), wsURL, "http://ekzample.com") | 
					
						
							|  |  |  | 	if err == nil { | 
					
						
							|  |  |  | 		client.Close() | 
					
						
							|  |  |  | 		t.Fatal("no error for wrong origin") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	wantErr := wsHandshakeError{websocket.ErrBadHandshake, "403 Forbidden"} | 
					
						
							|  |  |  | 	if !reflect.DeepEqual(err, wantErr) { | 
					
						
							|  |  |  | 		t.Fatalf("wrong error for wrong origin: %q", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Connections without origin header should work. | 
					
						
							|  |  |  | 	client, err = DialWebsocket(context.Background(), wsURL, "") | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	if err != nil { | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 		t.Fatal("error for empty origin") | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	client.Close() | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // This test checks whether calls exceeding the request size limit are rejected. | 
					
						
							|  |  |  | func TestWebsocketLargeCall(t *testing.T) { | 
					
						
							|  |  |  | 	t.Parallel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		srv     = newTestServer() | 
					
						
							|  |  |  | 		httpsrv = httptest.NewServer(srv.WebsocketHandler([]string{"*"})) | 
					
						
							|  |  |  | 		wsURL   = "ws:" + strings.TrimPrefix(httpsrv.URL, "http:") | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 	defer srv.Stop() | 
					
						
							|  |  |  | 	defer httpsrv.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	client, err := DialWebsocket(context.Background(), wsURL, "") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("can't dial: %v", err) | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	defer client.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// This call sends slightly less than the limit and should work. | 
					
						
							| 
									
										
										
										
											2019-11-20 09:06:21 +01:00
										 |  |  | 	var result echoResult | 
					
						
							| 
									
										
										
										
											2019-07-22 12:22:39 +02:00
										 |  |  | 	arg := strings.Repeat("x", maxRequestContentLength-200) | 
					
						
							|  |  |  | 	if err := client.Call(&result, "test_echo", arg, 1); err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("valid call didn't work: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if result.String != arg { | 
					
						
							|  |  |  | 		t.Fatal("wrong string echoed") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// This call sends twice the allowed size and shouldn't work. | 
					
						
							|  |  |  | 	arg = strings.Repeat("x", maxRequestContentLength*2) | 
					
						
							|  |  |  | 	err = client.Call(&result, "test_echo", arg) | 
					
						
							|  |  |  | 	if err == nil { | 
					
						
							|  |  |  | 		t.Fatal("no error for too large call") | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // This test checks that client handles WebSocket ping frames correctly. | 
					
						
							|  |  |  | func TestClientWebsocketPing(t *testing.T) { | 
					
						
							|  |  |  | 	t.Parallel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		sendPing    = make(chan struct{}) | 
					
						
							|  |  |  | 		server      = wsPingTestServer(t, sendPing) | 
					
						
							|  |  |  | 		ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 	defer cancel() | 
					
						
							|  |  |  | 	defer server.Shutdown(ctx) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	client, err := DialContext(ctx, "ws://"+server.Addr) | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("client dial error: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	resultChan := make(chan int) | 
					
						
							|  |  |  | 	sub, err := client.EthSubscribe(ctx, resultChan, "foo") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatalf("client subscribe error: %v", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Wait for the context's deadline to be reached before proceeding. | 
					
						
							|  |  |  | 	// This is important for reproducing https://github.com/ethereum/go-ethereum/issues/19798 | 
					
						
							|  |  |  | 	<-ctx.Done() | 
					
						
							|  |  |  | 	close(sendPing) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Wait for the subscription result. | 
					
						
							|  |  |  | 	timeout := time.NewTimer(5 * time.Second) | 
					
						
							|  |  |  | 	for { | 
					
						
							|  |  |  | 		select { | 
					
						
							|  |  |  | 		case err := <-sub.Err(): | 
					
						
							|  |  |  | 			t.Error("client subscription error:", err) | 
					
						
							|  |  |  | 		case result := <-resultChan: | 
					
						
							|  |  |  | 			t.Log("client got result:", result) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		case <-timeout.C: | 
					
						
							|  |  |  | 			t.Error("didn't get any result within the test timeout") | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // wsPingTestServer runs a WebSocket server which accepts a single subscription request. | 
					
						
							|  |  |  | // When a value arrives on sendPing, the server sends a ping frame, waits for a matching | 
					
						
							|  |  |  | // pong and finally delivers a single subscription result. | 
					
						
							|  |  |  | func wsPingTestServer(t *testing.T, sendPing <-chan struct{}) *http.Server { | 
					
						
							|  |  |  | 	var srv http.Server | 
					
						
							|  |  |  | 	shutdown := make(chan struct{}) | 
					
						
							|  |  |  | 	srv.RegisterOnShutdown(func() { | 
					
						
							|  |  |  | 		close(shutdown) | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	srv.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 
					
						
							|  |  |  | 		// Upgrade to WebSocket. | 
					
						
							|  |  |  | 		upgrader := websocket.Upgrader{ | 
					
						
							|  |  |  | 			CheckOrigin: func(r *http.Request) bool { return true }, | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		conn, err := upgrader.Upgrade(w, r, nil) | 
					
						
							|  |  |  | 		if err != nil { | 
					
						
							|  |  |  | 			t.Errorf("server WS upgrade error: %v", err) | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 		defer conn.Close() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 		// Handle the connection. | 
					
						
							|  |  |  | 		wsPingTestHandler(t, conn, shutdown, sendPing) | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Start the server. | 
					
						
							|  |  |  | 	listener, err := net.Listen("tcp", "127.0.0.1:0") | 
					
						
							|  |  |  | 	if err != nil { | 
					
						
							|  |  |  | 		t.Fatal("can't listen:", err) | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	srv.Addr = listener.Addr().String() | 
					
						
							|  |  |  | 	go srv.Serve(listener) | 
					
						
							|  |  |  | 	return &srv | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | func wsPingTestHandler(t *testing.T, conn *websocket.Conn, shutdown, sendPing <-chan struct{}) { | 
					
						
							|  |  |  | 	// Canned responses for the eth_subscribe call in TestClientWebsocketPing. | 
					
						
							|  |  |  | 	const ( | 
					
						
							|  |  |  | 		subResp   = `{"jsonrpc":"2.0","id":1,"result":"0x00"}` | 
					
						
							|  |  |  | 		subNotify = `{"jsonrpc":"2.0","method":"eth_subscription","params":{"subscription":"0x00","result":1}}` | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Handle subscribe request. | 
					
						
							|  |  |  | 	if _, _, err := conn.ReadMessage(); err != nil { | 
					
						
							|  |  |  | 		t.Errorf("server read error: %v", err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 	if err := conn.WriteMessage(websocket.TextMessage, []byte(subResp)); err != nil { | 
					
						
							|  |  |  | 		t.Errorf("server write error: %v", err) | 
					
						
							|  |  |  | 		return | 
					
						
							|  |  |  | 	} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Read from the connection to process control messages. | 
					
						
							|  |  |  | 	var pongCh = make(chan string) | 
					
						
							|  |  |  | 	conn.SetPongHandler(func(d string) error { | 
					
						
							|  |  |  | 		t.Logf("server got pong: %q", d) | 
					
						
							|  |  |  | 		pongCh <- d | 
					
						
							|  |  |  | 		return nil | 
					
						
							|  |  |  | 	}) | 
					
						
							|  |  |  | 	go func() { | 
					
						
							|  |  |  | 		for { | 
					
						
							|  |  |  | 			typ, msg, err := conn.ReadMessage() | 
					
						
							|  |  |  | 			if err != nil { | 
					
						
							|  |  |  | 				return | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			t.Logf("server got message (%d): %q", typ, msg) | 
					
						
							|  |  |  | 		} | 
					
						
							|  |  |  | 	}() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Write messages. | 
					
						
							|  |  |  | 	var ( | 
					
						
							|  |  |  | 		sendResponse <-chan time.Time | 
					
						
							|  |  |  | 		wantPong     string | 
					
						
							|  |  |  | 	) | 
					
						
							|  |  |  | 	for { | 
					
						
							|  |  |  | 		select { | 
					
						
							|  |  |  | 		case _, open := <-sendPing: | 
					
						
							|  |  |  | 			if !open { | 
					
						
							|  |  |  | 				sendPing = nil | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			t.Logf("server sending ping") | 
					
						
							|  |  |  | 			conn.WriteMessage(websocket.PingMessage, []byte("ping")) | 
					
						
							|  |  |  | 			wantPong = "ping" | 
					
						
							|  |  |  | 		case data := <-pongCh: | 
					
						
							|  |  |  | 			if wantPong == "" { | 
					
						
							|  |  |  | 				t.Errorf("unexpected pong") | 
					
						
							|  |  |  | 			} else if data != wantPong { | 
					
						
							|  |  |  | 				t.Errorf("got pong with wrong data %q", data) | 
					
						
							|  |  |  | 			} | 
					
						
							|  |  |  | 			wantPong = "" | 
					
						
							|  |  |  | 			sendResponse = time.NewTimer(200 * time.Millisecond).C | 
					
						
							|  |  |  | 		case <-sendResponse: | 
					
						
							|  |  |  | 			t.Logf("server sending response") | 
					
						
							|  |  |  | 			conn.WriteMessage(websocket.TextMessage, []byte(subNotify)) | 
					
						
							|  |  |  | 			sendResponse = nil | 
					
						
							|  |  |  | 		case <-shutdown: | 
					
						
							|  |  |  | 			conn.Close() | 
					
						
							|  |  |  | 			return | 
					
						
							|  |  |  | 		} | 
					
						
							| 
									
										
										
										
											2018-09-19 12:09:03 -04:00
										 |  |  | 	} | 
					
						
							|  |  |  | } |