rpc: add SetHeader method to Client (#21392)
Resolves #20163 Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
		@@ -85,7 +85,7 @@ type Client struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	// writeConn is used for writing to the connection on the caller's goroutine. It should
 | 
						// writeConn is used for writing to the connection on the caller's goroutine. It should
 | 
				
			||||||
	// only be accessed outside of dispatch, with the write lock held. The write lock is
 | 
						// only be accessed outside of dispatch, with the write lock held. The write lock is
 | 
				
			||||||
	// taken by sending on requestOp and released by sending on sendDone.
 | 
						// taken by sending on reqInit and released by sending on reqSent.
 | 
				
			||||||
	writeConn jsonWriter
 | 
						writeConn jsonWriter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// for dispatch
 | 
						// for dispatch
 | 
				
			||||||
@@ -260,6 +260,19 @@ func (c *Client) Close() {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetHeader adds a custom HTTP header to the client's requests.
 | 
				
			||||||
 | 
					// This method only works for clients using HTTP, it doesn't have
 | 
				
			||||||
 | 
					// any effect for clients using another transport.
 | 
				
			||||||
 | 
					func (c *Client) SetHeader(key, value string) {
 | 
				
			||||||
 | 
						if !c.isHTTP {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						conn := c.writeConn.(*httpConn)
 | 
				
			||||||
 | 
						conn.mu.Lock()
 | 
				
			||||||
 | 
						conn.headers.Set(key, value)
 | 
				
			||||||
 | 
						conn.mu.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Call performs a JSON-RPC call with the given arguments and unmarshals into
 | 
					// Call performs a JSON-RPC call with the given arguments and unmarshals into
 | 
				
			||||||
// result if no error occurred.
 | 
					// result if no error occurred.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -26,6 +26,7 @@ import (
 | 
				
			|||||||
	"os"
 | 
						"os"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
@@ -429,6 +430,42 @@ func TestClientNotificationStorm(t *testing.T) {
 | 
				
			|||||||
	doTest(23000, true)
 | 
						doTest(23000, true)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestClientSetHeader(t *testing.T) {
 | 
				
			||||||
 | 
						var gotHeader bool
 | 
				
			||||||
 | 
						srv := newTestServer()
 | 
				
			||||||
 | 
						httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							if r.Header.Get("test") == "ok" {
 | 
				
			||||||
 | 
								gotHeader = true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							srv.ServeHTTP(w, r)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
						defer httpsrv.Close()
 | 
				
			||||||
 | 
						defer srv.Stop()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client, err := Dial(httpsrv.URL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer client.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						client.SetHeader("test", "ok")
 | 
				
			||||||
 | 
						if _, err := client.SupportedModules(); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !gotHeader {
 | 
				
			||||||
 | 
							t.Fatal("client did not set custom header")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Check that Content-Type can be replaced.
 | 
				
			||||||
 | 
						client.SetHeader("content-type", "application/x-garbage")
 | 
				
			||||||
 | 
						_, err = client.SupportedModules()
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal("no error for invalid content-type header")
 | 
				
			||||||
 | 
						} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
 | 
				
			||||||
 | 
							t.Fatalf("error is not related to content-type: %q", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestClientHTTP(t *testing.T) {
 | 
					func TestClientHTTP(t *testing.T) {
 | 
				
			||||||
	server := newTestServer()
 | 
						server := newTestServer()
 | 
				
			||||||
	defer server.Stop()
 | 
						defer server.Stop()
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										37
									
								
								rpc/http.go
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								rpc/http.go
									
									
									
									
									
								
							@@ -26,6 +26,7 @@ import (
 | 
				
			|||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"mime"
 | 
						"mime"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type httpConn struct {
 | 
					type httpConn struct {
 | 
				
			||||||
	client    *http.Client
 | 
						client    *http.Client
 | 
				
			||||||
	req       *http.Request
 | 
						url       string
 | 
				
			||||||
	closeOnce sync.Once
 | 
						closeOnce sync.Once
 | 
				
			||||||
	closeCh   chan interface{}
 | 
						closeCh   chan interface{}
 | 
				
			||||||
 | 
						mu        sync.Mutex // protects headers
 | 
				
			||||||
 | 
						headers   http.Header
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// httpConn is treated specially by Client.
 | 
					// httpConn is treated specially by Client.
 | 
				
			||||||
@@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hc *httpConn) remoteAddr() string {
 | 
					func (hc *httpConn) remoteAddr() string {
 | 
				
			||||||
	return hc.req.URL.String()
 | 
						return hc.url
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
 | 
					func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
 | 
				
			||||||
@@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
 | 
				
			|||||||
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
 | 
					// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
 | 
				
			||||||
// using the provided HTTP Client.
 | 
					// using the provided HTTP Client.
 | 
				
			||||||
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
 | 
					func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
 | 
				
			||||||
	req, err := http.NewRequest(http.MethodPost, endpoint, nil)
 | 
						// Sanity check URL so we don't end up with a client that will fail every request.
 | 
				
			||||||
 | 
						_, err := url.Parse(endpoint)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req.Header.Set("Content-Type", contentType)
 | 
					 | 
				
			||||||
	req.Header.Set("Accept", contentType)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	initctx := context.Background()
 | 
						initctx := context.Background()
 | 
				
			||||||
 | 
						headers := make(http.Header, 2)
 | 
				
			||||||
 | 
						headers.Set("accept", contentType)
 | 
				
			||||||
 | 
						headers.Set("content-type", contentType)
 | 
				
			||||||
	return newClient(initctx, func(context.Context) (ServerCodec, error) {
 | 
						return newClient(initctx, func(context.Context) (ServerCodec, error) {
 | 
				
			||||||
		return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
 | 
							hc := &httpConn{
 | 
				
			||||||
 | 
								client:  client,
 | 
				
			||||||
 | 
								headers: headers,
 | 
				
			||||||
 | 
								url:     endpoint,
 | 
				
			||||||
 | 
								closeCh: make(chan interface{}),
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return hc, nil
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
 | 
				
			|||||||
		if respBody != nil {
 | 
							if respBody != nil {
 | 
				
			||||||
			buf := new(bytes.Buffer)
 | 
								buf := new(bytes.Buffer)
 | 
				
			||||||
			if _, err2 := buf.ReadFrom(respBody); err2 == nil {
 | 
								if _, err2 := buf.ReadFrom(respBody); err2 == nil {
 | 
				
			||||||
				return fmt.Errorf("%v %v", err, buf.String())
 | 
									return fmt.Errorf("%v: %v", err, buf.String())
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
@@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	req := hc.req.WithContext(ctx)
 | 
						req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
 | 
				
			||||||
	req.Body = ioutil.NopCloser(bytes.NewReader(body))
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	req.ContentLength = int64(len(body))
 | 
						req.ContentLength = int64(len(body))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// set headers
 | 
				
			||||||
 | 
						hc.mu.Lock()
 | 
				
			||||||
 | 
						req.Header = hc.headers.Clone()
 | 
				
			||||||
 | 
						hc.mu.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// do request
 | 
				
			||||||
	resp, err := hc.client.Do(req)
 | 
						resp, err := hc.client.Do(req)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user