rpc: add SetHeader method to Client (#21392)

Resolves #20163

Co-authored-by: Felix Lange <fjl@twurst.com>
This commit is contained in:
rene
2020-08-03 14:08:42 +02:00
committed by GitHub
parent 9c2ac6fbd5
commit 290d6bd903
3 changed files with 79 additions and 10 deletions

View File

@ -26,6 +26,7 @@ import (
"io/ioutil"
"mime"
"net/http"
"net/url"
"sync"
"time"
)
@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
type httpConn struct {
client *http.Client
req *http.Request
url string
closeOnce sync.Once
closeCh chan interface{}
mu sync.Mutex // protects headers
headers http.Header
}
// httpConn is treated specially by Client.
@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
}
func (hc *httpConn) remoteAddr() string {
return hc.req.URL.String()
return hc.url
}
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
// using the provided HTTP Client.
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
req, err := http.NewRequest(http.MethodPost, endpoint, nil)
// Sanity check URL so we don't end up with a client that will fail every request.
_, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", contentType)
req.Header.Set("Accept", contentType)
initctx := context.Background()
headers := make(http.Header, 2)
headers.Set("accept", contentType)
headers.Set("content-type", contentType)
return newClient(initctx, func(context.Context) (ServerCodec, error) {
return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
hc := &httpConn{
client: client,
headers: headers,
url: endpoint,
closeCh: make(chan interface{}),
}
return hc, nil
})
}
@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
if respBody != nil {
buf := new(bytes.Buffer)
if _, err2 := buf.ReadFrom(respBody); err2 == nil {
return fmt.Errorf("%v %v", err, buf.String())
return fmt.Errorf("%v: %v", err, buf.String())
}
}
return err
@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
if err != nil {
return nil, err
}
req := hc.req.WithContext(ctx)
req.Body = ioutil.NopCloser(bytes.NewReader(body))
req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
if err != nil {
return nil, err
}
req.ContentLength = int64(len(body))
// set headers
hc.mu.Lock()
req.Header = hc.headers.Clone()
hc.mu.Unlock()
// do request
resp, err := hc.client.Do(req)
if err != nil {
return nil, err