rpc: implement full bi-directional communication (#18471)
New APIs added: client.RegisterName(namespace, service) // makes service available to server client.Notify(ctx, method, args...) // sends a notification ClientFromContext(ctx) // to get a client in handler method This is essentially a rewrite of the server-side code. JSON-RPC processing code is now the same on both server and client side. Many minor issues were fixed in the process and there is a new test suite for JSON-RPC spec compliance (and non-compliance in some cases). List of behavior changes: - Method handlers are now called with a per-request context instead of a per-connection context. The context is canceled right after the method returns. - Subscription error channels are always closed when the connection ends. There is no need to also wait on the Notifier's Closed channel to detect whether the subscription has ended. - Client now omits "params" instead of sending "params": null when there are no arguments to a call. The previous behavior was not compliant with the spec. The server still accepts "params": null. - Floating point numbers are allowed as "id". The spec doesn't allow them, but we handle request "id" as json.RawMessage and guarantee that the same number will be sent back. - Logging is improved significantly. There is now a message at DEBUG level for each RPC call served.
This commit is contained in:
@ -22,6 +22,7 @@ import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -56,24 +57,39 @@ var websocketJSONCodec = websocket.Codec{
|
||||
//
|
||||
// allowedOrigins should be a comma-separated list of allowed origin URLs.
|
||||
// To allow connections with any origin, pass "*".
|
||||
func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
|
||||
func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
|
||||
return websocket.Server{
|
||||
Handshake: wsHandshakeValidator(allowedOrigins),
|
||||
Handler: func(conn *websocket.Conn) {
|
||||
// Create a custom encode/decode pair to enforce payload size and number encoding
|
||||
conn.MaxPayloadBytes = maxRequestContentLength
|
||||
|
||||
encoder := func(v interface{}) error {
|
||||
return websocketJSONCodec.Send(conn, v)
|
||||
}
|
||||
decoder := func(v interface{}) error {
|
||||
return websocketJSONCodec.Receive(conn, v)
|
||||
}
|
||||
srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions)
|
||||
codec := newWebsocketCodec(conn)
|
||||
s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
|
||||
// Create a custom encode/decode pair to enforce payload size and number encoding
|
||||
conn.MaxPayloadBytes = maxRequestContentLength
|
||||
encoder := func(v interface{}) error {
|
||||
return websocketJSONCodec.Send(conn, v)
|
||||
}
|
||||
decoder := func(v interface{}) error {
|
||||
return websocketJSONCodec.Receive(conn, v)
|
||||
}
|
||||
rpcconn := Conn(conn)
|
||||
if conn.IsServerConn() {
|
||||
// Override remote address with the actual socket address because
|
||||
// package websocket crashes if there is no request origin.
|
||||
addr := conn.Request().RemoteAddr
|
||||
if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil {
|
||||
// Add origin if present.
|
||||
addr += "(" + wsaddr.URL.String() + ")"
|
||||
}
|
||||
rpcconn = connWithRemoteAddr{conn, addr}
|
||||
}
|
||||
return NewCodec(rpcconn, encoder, decoder)
|
||||
}
|
||||
|
||||
// NewWSServer creates a new websocket RPC server around an API provider.
|
||||
//
|
||||
// Deprecated: use Server.WebsocketHandler
|
||||
@ -105,15 +121,16 @@ func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice()))
|
||||
log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
|
||||
|
||||
f := func(cfg *websocket.Config, req *http.Request) error {
|
||||
// Verify origin against whitelist.
|
||||
origin := strings.ToLower(req.Header.Get("Origin"))
|
||||
if allowAllOrigins || origins.Contains(origin) {
|
||||
return nil
|
||||
}
|
||||
log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
|
||||
return fmt.Errorf("origin %s not allowed", origin)
|
||||
log.Warn("Rejected WebSocket connection", "origin", origin)
|
||||
return errors.New("origin not allowed")
|
||||
}
|
||||
|
||||
return f
|
||||
@ -155,8 +172,12 @@ func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newClient(ctx, func(ctx context.Context) (net.Conn, error) {
|
||||
return wsDialContext(ctx, config)
|
||||
return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
|
||||
conn, err := wsDialContext(ctx, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newWebsocketCodec(conn), nil
|
||||
})
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user