rpc: implement full bi-directional communication (#18471)

New APIs added:

    client.RegisterName(namespace, service) // makes service available to server
    client.Notify(ctx, method, args...)     // sends a notification
    ClientFromContext(ctx)                  // to get a client in handler method

This is essentially a rewrite of the server-side code. JSON-RPC
processing code is now the same on both server and client side. Many
minor issues were fixed in the process and there is a new test suite for
JSON-RPC spec compliance (and non-compliance in some cases).

List of behavior changes:

- Method handlers are now called with a per-request context instead of a
  per-connection context. The context is canceled right after the method
  returns.
- Subscription error channels are always closed when the connection
  ends. There is no need to also wait on the Notifier's Closed channel
  to detect whether the subscription has ended.
- Client now omits "params" instead of sending "params": null when there
  are no arguments to a call. The previous behavior was not compliant
  with the spec. The server still accepts "params": null.
- Floating point numbers are allowed as "id". The spec doesn't allow
  them, but we handle request "id" as json.RawMessage and guarantee that
  the same number will be sent back.
- Logging is improved significantly. There is now a message at DEBUG
  level for each RPC call served.
This commit is contained in:
Felix Lange
2019-02-04 13:47:34 +01:00
committed by GitHub
parent ec3432bccb
commit 245f3146c2
36 changed files with 2211 additions and 2169 deletions

View File

@ -22,6 +22,7 @@ import (
"crypto/tls"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
@ -56,24 +57,39 @@ var websocketJSONCodec = websocket.Codec{
//
// allowedOrigins should be a comma-separated list of allowed origin URLs.
// To allow connections with any origin, pass "*".
func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
return websocket.Server{
Handshake: wsHandshakeValidator(allowedOrigins),
Handler: func(conn *websocket.Conn) {
// Create a custom encode/decode pair to enforce payload size and number encoding
conn.MaxPayloadBytes = maxRequestContentLength
encoder := func(v interface{}) error {
return websocketJSONCodec.Send(conn, v)
}
decoder := func(v interface{}) error {
return websocketJSONCodec.Receive(conn, v)
}
srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
codec := newWebsocketCodec(conn)
s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
},
}
}
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
// Create a custom encode/decode pair to enforce payload size and number encoding
conn.MaxPayloadBytes = maxRequestContentLength
encoder := func(v interface{}) error {
return websocketJSONCodec.Send(conn, v)
}
decoder := func(v interface{}) error {
return websocketJSONCodec.Receive(conn, v)
}
rpcconn := Conn(conn)
if conn.IsServerConn() {
// Override remote address with the actual socket address because
// package websocket crashes if there is no request origin.
addr := conn.Request().RemoteAddr
if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil {
// Add origin if present.
addr += "(" + wsaddr.URL.String() + ")"
}
rpcconn = connWithRemoteAddr{conn, addr}
}
return NewCodec(rpcconn, encoder, decoder)
}
// NewWSServer creates a new websocket RPC server around an API provider.
//
// Deprecated: use Server.WebsocketHandler
@ -105,15 +121,16 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http
}
}
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice()))
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
f := func(cfg *websocket.Config, req *http.Request) error {
// Verify origin against whitelist.
origin := strings.ToLower(req.Header.Get("Origin"))
if allowAllOrigins || origins.Contains(origin) {
return nil
}
log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
return fmt.Errorf("origin %s not allowed", origin)
log.Warn("Rejected WebSocket connection", "origin", origin)
return errors.New("origin not allowed")
}
return f
@ -155,8 +172,12 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
return nil, err
}
return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
return wsDialContext(ctx, config)
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
conn, err := wsDialContext(ctx, config)
if err != nil {
return nil, err
}
return newWebsocketCodec(conn), nil
})
}