From d03557acd2026803eb6fd101c085ae77ee15087b Mon Sep 17 00:00:00 2001 From: Camille Meulien Date: Sat, 2 May 2020 22:34:33 +0300 Subject: [PATCH 1/5] Adapt and refactore pkg/msg --- pkg/message/msg_test.go | 48 -------------------- pkg/message/parser.go | 51 ---------------------- pkg/msg/msg.go | 58 ++++++++++++++++++++++++ pkg/msg/msg_test.go | 18 ++++++++ pkg/msg/parser.go | 97 +++++++++++++++++++++++++++++++++++++++++ pkg/msg/parser_test.go | 95 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 268 insertions(+), 99 deletions(-) delete mode 100644 pkg/message/msg_test.go delete mode 100644 pkg/message/parser.go create mode 100644 pkg/msg/msg.go create mode 100644 pkg/msg/msg_test.go create mode 100644 pkg/msg/parser.go create mode 100644 pkg/msg/parser_test.go diff --git a/pkg/message/msg_test.go b/pkg/message/msg_test.go deleted file mode 100644 index 5249dd1..0000000 --- a/pkg/message/msg_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package message - -import ( - "errors" - "fmt" - "testing" -) - -func TestMsg_Response(t *testing.T) { - t.Run("no error", func(t *testing.T) { - res, err := PackOutgoingResponse(nil, "ok") - fmt.Println(string(res)) - - if err != nil { - t.Fatal("Should not return error") - } - - if string(res) != `{"success":"ok"}` { - t.Fatal("Response invalid") - } - }) - - t.Run("Some error", func(t *testing.T) { - res, err := PackOutgoingResponse(errors.New("Some Error"), "ok") - fmt.Println(string(res)) - - if err != nil { - t.Fatal("Should not return error") - } - - if string(res) != `{"error":"Some Error"}` { - t.Fatal("Response invalid") - } - }) -} - -func TestMsg_Event(t *testing.T) { - res, err := PackOutgoingEvent("someMethod", "Hello") - fmt.Println(string(res)) - - if err != nil { - t.Fatal("Should not return error") - } - - if string(res) != `{"someMethod":"Hello"}` { - t.Fatal("Event invalid") - } -} diff --git a/pkg/message/parser.go b/pkg/message/parser.go deleted file mode 100644 index b0ae436..0000000 --- a/pkg/message/parser.go +++ /dev/null @@ -1,51 +0,0 @@ -package message - -import ( - "encoding/json" - "errors" - "fmt" - "reflect" -) - -func ParseRequest(msg []byte) (Request, error) { - request, err := Parse(msg) - if err != nil { - return request, err - } - - return request, nil -} - -func Parse(msg []byte) (Request, error) { - var v map[string]interface{} - var parsed Request - - if err := json.Unmarshal(msg, &v); err != nil { - return parsed, fmt.Errorf("Could not parse message: %w", err) - } - - switch v["event"] { - case "subscribe": - parsed.Method = "subscribe" - switch reflect.TypeOf(v["streams"]).Kind() { - case reflect.Slice: - streams := reflect.ValueOf(v["streams"]) - for i := 0; i < streams.Len(); i++ { - parsed.Streams = append(parsed.Streams, streams.Index(i).Interface().(string)) - } - } - case "unsubscribe": - parsed.Method = "unsubscribe" - switch reflect.TypeOf(v["streams"]).Kind() { - case reflect.Slice: - streams := reflect.ValueOf(v["streams"]) - for i := 0; i < streams.Len(); i++ { - parsed.Streams = append(parsed.Streams, streams.Index(i).Interface().(string)) - } - } - default: - return parsed, errors.New("Could not parse Type: Invalid event") - } - - return parsed, nil -} diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go new file mode 100644 index 0000000..27a280a --- /dev/null +++ b/pkg/msg/msg.go @@ -0,0 +1,58 @@ +package msg + +import ( + "encoding/json" + + "github.com/rs/zerolog/log" +) + +// Request identifier +const Request = 0 + +// Response identifier +const Response = 1 + +// Event identifier +const Event = 2 + +// Msg represent websocket messages, it could be either a request, a response or an event +type Msg struct { + Type uint8 + ReqID uint64 + Method string + Args []interface{} +} + +// NewResponse build a response object +func NewResponse(req *Msg, method string, args []interface{}) *Msg { + return &Msg{ + Type: Response, + ReqID: req.ReqID, + Method: method, + Args: args, + } +} + +// Encode msg into json +func (m *Msg) Encode() []byte { + s, err := json.Marshal([]interface{}{ + m.Type, + m.ReqID, + m.Method, + m.Args, + }) + if err != nil { + log.Error().Msgf("Fail to encode Msg %v, %s", m, err.Error()) + return []byte{} + } + return s +} + +// Convss2is converts a string slice to interface slice more details: https://golang.org/doc/faq#convert_slice_of_interface) +func Convss2is(a []string) []interface{} { + s := make([]interface{}, len(a)) + for i, v := range a { + s[i] = v + } + return s +} diff --git a/pkg/msg/msg_test.go b/pkg/msg/msg_test.go new file mode 100644 index 0000000..97a2ec2 --- /dev/null +++ b/pkg/msg/msg_test.go @@ -0,0 +1,18 @@ +package msg + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEncoding(t *testing.T) { + msg := Msg{ + Type: Request, + ReqID: 42, + Method: "test", + Args: []interface{}{"hello", "there"}, + } + + assert.Equal(t, `[0,42,"test",["hello","there"]]`, string(msg.Encode())) +} diff --git a/pkg/msg/parser.go b/pkg/msg/parser.go new file mode 100644 index 0000000..a56d719 --- /dev/null +++ b/pkg/msg/parser.go @@ -0,0 +1,97 @@ +package msg + +import ( + "encoding/json" + "errors" + "fmt" + "math" + "reflect" +) + +func ParseUint64(t interface{}) (uint64, error) { + vf, ok := t.(float64) + + if !ok { + return 0, errors.New("expected uint64 got: " + reflect.TypeOf(t).String()) + } + + vu := uint64(vf) + if float64(vu) != vf { + return 0, errors.New("expected unsigned integer got: float") + } + return vu, nil +} + +func ParseUint8(t interface{}) (uint8, error) { + vf, ok := t.(float64) + + if !ok { + return 0, errors.New("expected uint8 got: " + reflect.TypeOf(t).String()) + } + + if math.Trunc(vf) != vf { + return 0, errors.New("expected unsigned integer got: float") + } + return uint8(vf), nil +} + +func ParseString(t interface{}) (string, error) { + s, ok := t.(string) + + if !ok { + return "", errors.New("expected string got: " + reflect.TypeOf(t).String()) + } + return s, nil +} + +func ParseSlice(t interface{}) ([]interface{}, error) { + s, ok := t.([]interface{}) + + if !ok { + return nil, errors.New("expected array got: " + reflect.TypeOf(t).String()) + } + return s, nil +} + +func Parse(msg []byte) (*Msg, error) { + req := Msg{} + + var v []interface{} + if err := json.Unmarshal(msg, &v); err != nil { + return nil, fmt.Errorf("Could not parse message: %w", err) + } + + if len(v) != 4 { + return nil, errors.New("message must contains 4 elements") + } + + t, err := ParseUint8(v[0]) + if err != nil { + return nil, fmt.Errorf("failed to parse type: %w", err) + } + if t != Request && t != Response && t != Event { + return nil, errors.New("message type must be 0, 1 or 2") + } + + reqID, err := ParseUint64(v[1]) + if err != nil { + return nil, fmt.Errorf("failed to parse request ID: %w", err) + } + + method, err := ParseString(v[2]) + if err != nil { + return nil, fmt.Errorf("failed to parse method: %w", err) + } + + args, err := ParseSlice(v[3]) + if err != nil { + return nil, fmt.Errorf("failed to parse arguments: %w", err) + } + + req.Type = t + req.ReqID = reqID + req.Method = method + req.Args = args + + return &req, nil +} diff --git a/pkg/msg/parser_test.go b/pkg/msg/parser_test.go new file mode 100644 index 0000000..56ab0a3 --- /dev/null +++ b/pkg/msg/parser_test.go @@ -0,0 +1,95 @@ +package msg + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParserSuccess(t *testing.T) { + msg, err := Parse([]byte(`[0,42,"ping",[]]`)) + assert.NoError(t, err) + assert.Equal(t, + &Msg{ + Type: Request, + ReqID: 42, + Method: "ping", + Args: []interface{}{}, + }, msg) + + msg, err = Parse([]byte(`[1,42,"pong",[]]`)) + assert.NoError(t, err) + assert.Equal(t, + &Msg{ + Type: Response, + ReqID: 42, + Method: "pong", + Args: []interface{}{}, + }, msg) + + msg, err = Parse([]byte(`[2,42,"temperature",[28.7]]`)) + assert.NoError(t, err) + assert.Equal(t, + &Msg{ + Type: Event, + ReqID: 42, + Method: "temperature", + Args: []interface{}{28.7}, + }, msg) +} + +func TestParserErrorsMessageLength(t *testing.T) { + msg, err := Parse([]byte(`[0,42,"ping"]`)) + assert.EqualError(t, err, "message must contains 4 elements") + assert.Nil(t, msg) +} + +func TestParserErrorsBadJSON(t *testing.T) { + msg, err := Parse([]byte(`[0,42,"ping",[]`)) + assert.EqualError(t, err, "Could not parse message: unexpected end of JSON input") + assert.Nil(t, msg) +} + +func TestParserErrorsType(t *testing.T) { + msg, err := Parse([]byte(`[3,42,"ping",[]]`)) + assert.EqualError(t, err, "message type must be 0, 1 or 2") + assert.Nil(t, msg) + + msg, err = Parse([]byte(`[0.1,42,"pong",[]]`)) + assert.EqualError(t, err, "failed to parse type: expected unsigned integer got: float") + assert.Nil(t, msg) + + msg, err = Parse([]byte(`["0",42,"pong",[]]`)) + assert.EqualError(t, err, "failed to parse type: expected uint8 got: string") + assert.Nil(t, msg) +} + +func TestParserErrorsRequestID(t *testing.T) { + msg, err := Parse([]byte(`[0,"42","ping",[]]`)) + assert.EqualError(t, err, "failed to parse request ID: expected uint64 got: string") + assert.Nil(t, msg) + + msg, err = Parse([]byte(`[0,42.1,"ping",[]]`)) + assert.EqualError(t, err, "failed to parse request ID: expected unsigned integer got: float") + assert.Nil(t, msg) +} + +func TestParserErrorsMethod(t *testing.T) { + msg, err := Parse([]byte(`[0,42,51,[]]`)) + assert.EqualError(t, err, "failed to parse method: expected string got: float64") + assert.Nil(t, msg) + + msg, err = Parse([]byte(`[0,42,true,[]]`)) + assert.EqualError(t, err, "failed to parse method: expected string got: bool") + assert.Nil(t, msg) +} + +func TestParserErrorsArgs(t *testing.T) { + msg, err := Parse([]byte(`[0,42,"ping",true]`)) + assert.EqualError(t, err, "failed to parse arguments: expected array got: bool") + assert.Nil(t, msg) + + msg, err = Parse([]byte(`[0,42,"ping","hello"]`)) + assert.EqualError(t, err, "failed to parse arguments: expected array got: string") + assert.Nil(t, msg) +} From a4ea75f32123e6128e705a6efe1248725fcbd2da Mon Sep 17 00:00:00 2001 From: Camille Meulien Date: Sun, 3 May 2020 22:55:45 +0300 Subject: [PATCH 2/5] Adapt the hub for the new protocol --- README.md | 48 ++++++- pkg/msg/parser.go | 20 +++ pkg/routing/client.go | 34 ++--- pkg/routing/client_test.go | 8 -- pkg/routing/hub.go | 254 +++++++++++++++++++++---------------- pkg/routing/hub_test.go | 44 ++++--- pkg/routing/topic.go | 10 -- 7 files changed, 247 insertions(+), 171 deletions(-) diff --git a/README.md b/README.md index 4aa52ea..6eee5b4 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,8 @@ # RANGO -Rango is a general purpose websocket server which dispatch public and private messages. +Rango is a general purpose websocket server which dispatches public and private messages. It's using AMQP (RabbitMQ) as source of messages. -Rango is made as a drop-in replacement of ranger built in ruby. - ## Build ```bash @@ -29,16 +27,52 @@ wscat --connect localhost:8080/public wscat --connect localhost:8080/private --header "Authorization: Bearer $(go run ./tools/jwt)" ``` -## Messages +## RPC Methods + +Every request and responses are formated in a array the following way: + +```json +[0, 42, "method", ["any", "arguments", 51]] +``` + +- The first argument is 0 for requests and 1 for responses. + +- The second argument is the request ID, the client can set the value he wants and the server will include it in the response. This helps to keep track of which response stands for which request. We use often 42 in our examples, you need to change this value by your own according to your implementation. + +- Third is the method name + +- Last is a list of arguments for the method + +### Subscribe to public streams + +Request: -### Subscribe to a stream list +``` +[0,42,"subscribe",["public",["eurusd.trades","eurusd.ob-inc"]]] +[0,43,"subscribe",["private",["trades","orders"]]] +``` + +Response: ``` -{"event":"subscribe","streams":["eurusd.trades","eurusd.ob-inc"]} +[1,42,"subscribed",["public",["eurusd.trades","eurusd.ob-inc"]]] +[1,43,"subscribed",["private",["trades","orders"]]] ``` ### Unsubscribe to one or several streams +Request: + +``` +[0,42,"unsubscribe",["public",["eurusd.trades","eurusd.ob-inc"]]] +``` + +Response: + ``` -{"event":"unsubscribe","streams":["eurusd.trades"]} +[1,42,"unsubscribe",["public",["eurusd.trades","eurusd.ob-inc"]]] ``` + +## RPC Responses + +### Authentication notification diff --git a/pkg/msg/parser.go b/pkg/msg/parser.go index a56d719..e8c0a37 100644 --- a/pkg/msg/parser.go +++ b/pkg/msg/parser.go @@ -53,6 +53,26 @@ func ParseSlice(t interface{}) ([]interface{}, error) { return s, nil } +func ParseSliceOfStrings(t interface{}) ([]string, error) { + s, err := ParseSlice(t) + if err != nil { + return nil, err + } + + a := make([]string, len(s)) + + for i, istr := range s { + str, ok := istr.(string) + a[i] = str + + if !ok { + return nil, errors.New("expected array of string, got unexpected " + reflect.TypeOf(istr).String()) + } + } + + return a, nil +} + func Parse(msg []byte) (*Msg, error) { req := Msg{} diff --git a/pkg/routing/client.go b/pkg/routing/client.go index a391ecd..1eab490 100644 --- a/pkg/routing/client.go +++ b/pkg/routing/client.go @@ -3,12 +3,11 @@ package routing import ( "bytes" "net/http" - "strings" "time" "github.com/gorilla/websocket" - msg "github.com/openware/rango/pkg/message" "github.com/openware/rango/pkg/metrics" + "github.com/openware/rango/pkg/msg" "github.com/rs/zerolog/log" ) @@ -154,23 +153,6 @@ func (c *Client) UnsubscribePrivate(s string) { c.privSub = l } -func parseStreamsFromURI(uri string) []string { - streams := make([]string, 0) - path := strings.Split(uri, "?") - if len(path) != 2 { - return streams - } - for _, up := range strings.Split(path[1], "&") { - p := strings.Split(up, "=") - if len(p) != 2 || p[0] != "stream" { - continue - } - streams = append(streams, strings.Split(p[1], ",")...) - - } - return streams -} - // read pumps messages from the websocket connection to the hub. // // The application runs read in a per-connection goroutine. The application @@ -211,12 +193,14 @@ func (c *Client) read() { continue } - req, err := msg.ParseRequest(message) + req, err := msg.Parse(message) if err != nil { - c.send <- []byte(responseMust(err, nil)) + log.Error().Msgf("fail to parse message: %s", err.Error()) + c.send <- msg.NewResponse(&msg.Msg{ReqID: 0}, "error", []interface{}{err.Error()}).Encode() continue } + log.Debug().Msgf("Pushing request to hub: %v", req) c.hub.Requests <- Request{c, req} } } @@ -249,6 +233,14 @@ func (c *Client) write() { return } w.Write(message) + + // Add queued chat messages to the current websocket message. + n := len(c.send) + for i := 0; i < n; i++ { + w.Write(newline) + w.Write(<-c.send) + } + if err := w.Close(); err != nil { return } diff --git a/pkg/routing/client_test.go b/pkg/routing/client_test.go index 458c52a..4b49a82 100644 --- a/pkg/routing/client_test.go +++ b/pkg/routing/client_test.go @@ -59,11 +59,3 @@ func TestClient(t *testing.T) { assert.Equal(t, []string{}, client.pubSub) assert.Equal(t, []string{}, client.privSub) } - -func TestParseStreamsFromURI(t *testing.T) { - assert.Equal(t, []string{}, parseStreamsFromURI("/?")) - assert.Equal(t, []string{}, parseStreamsFromURI("")) - assert.Equal(t, []string{"aaa", "bbb"}, parseStreamsFromURI("/?stream=aaa&stream=bbb")) - assert.Equal(t, []string{"aaa", "bbb"}, parseStreamsFromURI("/?stream=aaa,bbb")) - assert.Equal(t, []string{"aaa", "bbb"}, parseStreamsFromURI("/public/?stream=aaa,bbb")) -} diff --git a/pkg/routing/hub.go b/pkg/routing/hub.go index cec660d..dc513ce 100644 --- a/pkg/routing/hub.go +++ b/pkg/routing/hub.go @@ -6,18 +6,13 @@ import ( "fmt" "strings" - msg "github.com/openware/rango/pkg/message" "github.com/openware/rango/pkg/metrics" + "github.com/openware/rango/pkg/msg" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/streadway/amqp" ) -type Request struct { - client IClient - msg.Request -} - // Hub maintains the set of active clients and broadcasts messages to the // clients. type Hub struct { @@ -37,6 +32,13 @@ type Hub struct { IncrementalObjects map[string]*IncrementalObject } +// Request is a container for client message and pointer to the client +type Request struct { + client IClient + *msg.Msg +} + +// Event contains an event received through AMQP type Event struct { Scope string // global, public, private Stream string // channel routing key @@ -45,11 +47,13 @@ type Event struct { Body interface{} // event json body } +// IncrementalObject stores an incremental object built from a snapshot and increments type IncrementalObject struct { Snapshot string Increments []string } +// NewHub creates a Hub func NewHub() *Hub { return &Hub{ Requests: make(chan Request), @@ -90,7 +94,8 @@ func (h *Hub) ListenWebsocketEvents() { for { select { case req := <-h.Requests: - h.handleRequest(&req) + resp := h.handleRequest(&req) + req.client.Send(string(resp.Encode())) case client := <-h.Unregister: log.Info().Msgf("Unregistering client %s", client.GetUID()) @@ -273,132 +278,169 @@ func (h *Hub) unsubscribeAll(client IClient) { } } -func responseMust(e error, r interface{}) string { - res, err := msg.PackOutgoingResponse(e, r) - if err != nil { - log.Panic().Msg("responseMust failed:" + err.Error()) - panic(err.Error()) - } +func (h *Hub) handleRequest(req *Request) (resp *msg.Msg) { + var err error - return string(res) -} - -func isPrivateStream(s string) bool { - return strings.Count(s, ".") == 0 -} - -func (h *Hub) handleRequest(req *Request) { switch req.Method { case "subscribe": - h.handleSubscribe(req) + err, resp = h.handleSubscribe(req) case "unsubscribe": - h.handleUnsubscribe(req) + err, resp = h.handleUnsubscribe(req) default: - req.client.Send(responseMust(errors.New("unsupported method"), nil)) + return msg.NewResponse(req.Msg, "error", []interface{}{"Unknown method " + req.Method}) + } + + if err != nil { + return msg.NewResponse(req.Msg, "error", []interface{}{err.Error()}) } + return resp } -func (h *Hub) handleSubscribe(req *Request) { - for _, t := range req.Streams { - if isPrivateStream(t) { - uid := req.client.GetUID() - if uid == "" { - log.Error().Msgf("Anonymous user tried to subscribe to private stream %s", t) - continue - } +func (h *Hub) handleSubscribePublic(client IClient, streams []string) { + for _, t := range streams { + topic, ok := h.PublicTopics[t] + if !ok { + topic = NewTopic(h) + h.PublicTopics[t] = topic + } - uTopics, ok := h.PrivateTopics[uid] - if !ok { - uTopics = make(map[string]*Topic, 3) - h.PrivateTopics[uid] = uTopics + topic.subscribe(client) + client.SubscribePublic(t) + if isIncrementObject(t) { + o, ok := h.IncrementalObjects[t] + if ok && o.Snapshot != "" { + client.Send(o.Snapshot) + for _, inc := range o.Increments { + client.Send(inc) + } } + } + } +} - topic, ok := uTopics[t] - if !ok { - topic = NewTopic(h) - uTopics[t] = topic - } +func (h *Hub) handleSubscribePrivate(client IClient, uid string, streams []string) { + for _, t := range streams { + uTopics, ok := h.PrivateTopics[uid] + if !ok { + uTopics = make(map[string]*Topic, 3) + h.PrivateTopics[uid] = uTopics + } - if topic.subscribe(req.client) { - metrics.RecordHubSubscription("private", t) - req.client.SubscribePrivate(t) - } - } else { - topic, ok := h.PublicTopics[t] - if !ok { - topic = NewTopic(h) - h.PublicTopics[t] = topic - } + topic, ok := uTopics[t] + if !ok { + topic = NewTopic(h) + uTopics[t] = topic + } - if topic.subscribe(req.client) { - metrics.RecordHubSubscription("public", t) - req.client.SubscribePublic(t) - } + topic.subscribe(client) + client.SubscribePrivate(t) + } +} - if isIncrementObject(t) { - o, ok := h.IncrementalObjects[t] - if ok && o.Snapshot != "" { - req.client.Send(o.Snapshot) - for _, inc := range o.Increments { - req.client.Send(inc) - } - } - } +func (h *Hub) handleSubscribe(req *Request) (error, *msg.Msg) { + args := req.Msg.Args + if len(args) != 2 { + return errors.New("method expects exactly 2 arguments"), nil + } + + scope, err := msg.ParseString(args[0]) + if err != nil { + return errors.New("first argument must be a string"), nil + } + + streams, err := msg.ParseSliceOfStrings(args[1]) + if err != nil { + return errors.New("second argument must be a list of strings"), nil + } + + switch scope { + case "public": + h.handleSubscribePublic(req.client, streams) + + case "private": + uid := req.client.GetUID() + if uid != "" { + h.handleSubscribePrivate(req.client, uid, streams) + } else { + return errors.New("unauthorized"), nil } + + default: + return errors.New("Unexpected scope " + scope), nil } - req.client.Send(responseMust(nil, map[string]interface{}{ - "message": "subscribed", - "streams": req.client.GetSubscriptions(), - })) + return nil, msg.NewResponse(req.Msg, "subscribed", msg.Convss2is(req.client.GetSubscriptions())) } -func (h *Hub) handleUnsubscribe(req *Request) { - for _, t := range req.Streams { - if isPrivateStream(t) { - uid := req.client.GetUID() - if uid == "" { - continue - } - uTopics, ok := h.PrivateTopics[uid] - if !ok { - continue - } +func (h *Hub) handleUnsubscribe(req *Request) (error, *msg.Msg) { + args := req.Msg.Args + if len(args) != 2 { + return errors.New("method expects exactly 2 arguments"), nil + } - topic, ok := uTopics[t] - if ok { - if topic.unsubscribe(req.client) { - metrics.RecordHubUnsubscription("private", t) - req.client.UnsubscribePrivate(t) - } + scope, ok := args[0].(string) + if !ok { + return errors.New("first argument must be a string"), nil + } - if topic.len() == 0 { - delete(uTopics, t) - } - } + streams, ok := args[1].([]string) + if !ok { + return errors.New("second argument must be a list of strings"), nil + } + switch scope { + case "public": + h.handleUnsubscribePublic(req.client, streams) - uTopics, ok = h.PrivateTopics[uid] - if ok && len(uTopics) == 0 { - delete(h.PrivateTopics, uid) + case "private": + uid := req.client.GetUID() + if uid != "" { + h.handleUnsubscribePrivate(req.client, uid, streams) + } else { + return errors.New("unauthorized"), nil + } + + default: + return errors.New("Unexpected scope " + scope), nil + } + + return nil, msg.NewResponse(req.Msg, "unsubscribed", msg.Convss2is(req.client.GetSubscriptions())) +} + +func (h *Hub) handleUnsubscribePublic(client IClient, streams []string) { + for _, t := range streams { + topic, ok := h.PublicTopics[t] + if ok { + if topic.unsubscribe(client) { + client.UnsubscribePublic(t) + metrics.RecordHubUnsubscription("public", t) } + if topic.len() == 0 { + delete(h.PublicTopics, t) + } + } + } +} - } else { - topic, ok := h.PublicTopics[t] - if ok { - if topic.unsubscribe(req.client) { - metrics.RecordHubUnsubscription("public", t) - req.client.UnsubscribePublic(t) - } +func (h *Hub) handleUnsubscribePrivate(client IClient, uid string, streams []string) { + uTopics, ok := h.PrivateTopics[uid] + if !ok { + return + } - if topic.len() == 0 { - delete(h.PublicTopics, t) - } + for _, t := range streams { + topic := uTopics[t] + if ok { + if topic.unsubscribe(client) { + client.UnsubscribePrivate(t) + metrics.RecordHubUnsubscription("private", t) + } + if topic.len() == 0 { + delete(uTopics, t) } } } - req.client.Send(responseMust(nil, map[string]interface{}{ - "message": "unsubscribed", - "streams": req.client.GetSubscriptions(), - })) + if len(uTopics) == 0 { + delete(h.PrivateTopics, uid) + } } diff --git a/pkg/routing/hub_test.go b/pkg/routing/hub_test.go index c376fb9..42ed9d7 100644 --- a/pkg/routing/hub_test.go +++ b/pkg/routing/hub_test.go @@ -3,7 +3,7 @@ package routing import ( "testing" - "github.com/openware/rango/pkg/message" + "github.com/openware/rango/pkg/msg" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -46,22 +46,28 @@ func (c *MockedClient) UnsubscribePrivate(s string) { c.Called(s) } -func setup(c *MockedClient, streams []string) *Hub { +func setup(c *MockedClient, streams []interface{}) *Hub { h := NewHub() h.handleSubscribe(&Request{ client: c, - Request: message.Request{ - Streams: streams, + Msg: &msg.Msg{ + Type: msg.Request, + ReqID: 41, + Method: "subscribe", + Args: []interface{}{"public", streams}, }, }) return h } -func teardown(h *Hub, c *MockedClient, streams []string) { +func teardown(h *Hub, c *MockedClient, streams []interface{}) { h.handleUnsubscribe(&Request{ client: c, - Request: message.Request{ - Streams: streams, + Msg: &msg.Msg{ + Type: msg.Request, + ReqID: 42, + Method: "unsubscribe", + Args: []interface{}{"public", streams}, }, }) } @@ -70,14 +76,14 @@ func TestAnonymous(t *testing.T) { t.Run("subscribe to a public single stream", func(t *testing.T) { c := &MockedClient{} - streams := []string{ + streams := []interface{}{ "eurusd.trades", } c.On("GetUID").Return("") c.On("GetSubscriptions").Return(streams).Once() c.On("SubscribePublic", streams[0]).Return().Once() - c.On("Send", `{"success":{"message":"subscribed","streams":["`+streams[0]+`"]}}`).Return() + c.On("Send", `[1,41,"subscribed",["eurusd.trades"]]`).Return() h := setup(c, streams) assert.Equal(t, 1, len(h.PublicTopics)) @@ -85,7 +91,7 @@ func TestAnonymous(t *testing.T) { c.On("UnsubscribePublic", streams[0]).Return() c.On("GetSubscriptions").Return([]string{}).Once() - c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() + c.On("Send", `[1,42,"unsubscribed",[]]`).Return() teardown(h, c, streams) assert.Equal(t, 0, len(h.PublicTopics)) @@ -94,7 +100,7 @@ func TestAnonymous(t *testing.T) { t.Run("subscribe to multiple public streams", func(t *testing.T) { c := &MockedClient{} - streams := []string{ + streams := []interface{}{ "eurusd.trades", "eurusd.updates", } @@ -105,7 +111,7 @@ func TestAnonymous(t *testing.T) { c.On("SubscribePublic", "eurusd.updates").Return() c.On("Send", `{"success":{"message":"subscribed","streams":["eurusd.trades","eurusd.updates"]}}`).Return() - h := setup(c, []string{ + h := setup(c, []interface{}{ "eurusd.trades", "eurusd.updates", }) @@ -132,7 +138,7 @@ func TestAnonymous(t *testing.T) { c.On("SubscribePrivate", "trades").Return() c.On("Send", `{"success":{"message":"subscribed","streams":[]}}`).Return() - h := setup(&c, []string{ + h := setup(&c, []interface{}{ "trades", }) @@ -149,7 +155,7 @@ func TestAuthenticated(t *testing.T) { c.On("SubscribePrivate", "trades").Return() c.On("Send", `{"success":{"message":"subscribed","streams":["trades"]}}`).Return() - h := setup(c, []string{ + h := setup(c, []interface{}{ "trades", }) assert.Equal(t, 0, len(h.PublicTopics)) @@ -159,7 +165,7 @@ func TestAuthenticated(t *testing.T) { c.On("GetSubscriptions").Return([]string{}).Once() c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() - teardown(h, c, []string{"trades"}) + teardown(h, c, []interface{}{"trades"}) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) @@ -173,7 +179,7 @@ func TestAuthenticated(t *testing.T) { c.On("SubscribePrivate", "orders").Return() c.On("Send", `{"success":{"message":"subscribed","streams":["trades","orders"]}}`).Return() - h := setup(c, []string{"trades", "orders"}) + h := setup(c, []interface{}{"trades", "orders"}) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 1, len(h.PrivateTopics)) @@ -186,7 +192,7 @@ func TestAuthenticated(t *testing.T) { c.On("GetSubscriptions").Return([]string{}).Once() c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() - teardown(h, c, []string{"trades", "orders"}) + teardown(h, c, []interface{}{"trades", "orders"}) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) @@ -202,7 +208,7 @@ func TestAuthenticated(t *testing.T) { c.On("SubscribePublic", "eurusd.updates").Return() c.On("Send", `{"success":{"message":"subscribed","streams":["trades","orders","eurusd.updates"]}}`).Return() - h := setup(c, []string{"trades", "orders", "eurusd.updates"}) + h := setup(c, []interface{}{"trades", "orders", "eurusd.updates"}) assert.Equal(t, 1, len(h.PublicTopics)) assert.Equal(t, 1, len(h.PrivateTopics)) @@ -216,7 +222,7 @@ func TestAuthenticated(t *testing.T) { c.On("GetSubscriptions").Return([]string{}).Once() c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() - teardown(h, c, []string{"trades", "orders", "eurusd.updates"}) + teardown(h, c, []interface{}{"trades", "orders", "eurusd.updates"}) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) diff --git a/pkg/routing/topic.go b/pkg/routing/topic.go index 24d5558..5ea5a6e 100644 --- a/pkg/routing/topic.go +++ b/pkg/routing/topic.go @@ -3,7 +3,6 @@ package routing import ( "encoding/json" - msg "github.com/openware/rango/pkg/message" "github.com/rs/zerolog/log" ) @@ -19,15 +18,6 @@ func NewTopic(h *Hub) *Topic { } } -func eventMust(method string, data interface{}) []byte { - ev, err := msg.PackOutgoingEvent(method, data) - if err != nil { - log.Panic().Msg(err.Error()) - } - - return ev -} - func contains(list []string, el string) bool { for _, l := range list { if l == el { From e777adc20003852184a5e8f7f3773f969dd8377d Mon Sep 17 00:00:00 2001 From: Camille Meulien Date: Mon, 4 May 2020 19:26:13 +0300 Subject: [PATCH 3/5] Fix tests --- pkg/msg/msg.go | 9 ++ pkg/msg/parser.go | 2 - pkg/routing/client.go | 27 +++-- pkg/routing/client_test.go | 55 +++++----- pkg/routing/hub.go | 46 ++++---- pkg/routing/hub_test.go | 209 +++++++++++++++++++++++++++---------- 6 files changed, 226 insertions(+), 122 deletions(-) diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index 27a280a..1f43601 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -56,3 +56,12 @@ func Convss2is(a []string) []interface{} { } return s } + +func Contains(haystack []interface{}, niddle interface{}) bool { + for _, el := range haystack { + if el == niddle { + return true + } + } + return false +} diff --git a/pkg/msg/parser.go b/pkg/msg/parser.go index e8c0a37..df3c424 100644 --- a/pkg/msg/parser.go +++ b/pkg/msg/parser.go @@ -60,7 +60,6 @@ func ParseSliceOfStrings(t interface{}) ([]string, error) { } a := make([]string, len(s)) - for i, istr := range s { str, ok := istr.(string) a[i] = str @@ -69,7 +68,6 @@ func ParseSliceOfStrings(t interface{}) ([]string, error) { return nil, errors.New("expected array of string, got unexpected " + reflect.TypeOf(istr).String()) } } - return a, nil } diff --git a/pkg/routing/client.go b/pkg/routing/client.go index 1eab490..b4578a5 100644 --- a/pkg/routing/client.go +++ b/pkg/routing/client.go @@ -40,7 +40,8 @@ type IClient interface { Send(string) Close() GetUID() string - GetSubscriptions() []string + GetPublicSubscriptions() []interface{} + GetPrivateSubscriptions() []interface{} SubscribePublic(string) SubscribePrivate(string) UnsubscribePublic(string) @@ -54,8 +55,8 @@ type Client struct { // User ID if authorized UID string - pubSub []string - privSub []string + pubSub []interface{} + privSub []interface{} // The websocket connection. conn *websocket.Conn @@ -76,8 +77,8 @@ func NewClient(hub *Hub, w http.ResponseWriter, r *http.Request) { conn: conn, send: make(chan []byte, 256), UID: r.Header.Get("JwtUID"), - pubSub: []string{}, - privSub: []string{}, + pubSub: []interface{}{}, + privSub: []interface{}{}, } if client.UID == "" { @@ -113,24 +114,28 @@ func (c *Client) GetUID() string { return c.UID } -func (c *Client) GetSubscriptions() []string { - return append(c.pubSub, c.privSub...) +func (c *Client) GetPublicSubscriptions() []interface{} { + return c.pubSub +} + +func (c *Client) GetPrivateSubscriptions() []interface{} { + return c.privSub } func (c *Client) SubscribePublic(s string) { - if !contains(c.pubSub, s) { + if !msg.Contains(c.pubSub, s) { c.pubSub = append(c.pubSub, s) } } func (c *Client) SubscribePrivate(s string) { - if !contains(c.privSub, s) { + if !msg.Contains(c.privSub, s) { c.privSub = append(c.privSub, s) } } func (c *Client) UnsubscribePublic(s string) { - l := make([]string, len(c.pubSub)-1) + l := make([]interface{}, len(c.pubSub)-1) i := 0 for _, el := range c.pubSub { if s != el { @@ -142,7 +147,7 @@ func (c *Client) UnsubscribePublic(s string) { } func (c *Client) UnsubscribePrivate(s string) { - l := make([]string, len(c.privSub)-1) + l := make([]interface{}, len(c.privSub)-1) i := 0 for _, el := range c.privSub { if s != el { diff --git a/pkg/routing/client_test.go b/pkg/routing/client_test.go index 4b49a82..a23078c 100644 --- a/pkg/routing/client_test.go +++ b/pkg/routing/client_test.go @@ -12,50 +12,51 @@ func TestClient(t *testing.T) { hub: hub, send: make(chan []byte, 256), UID: "UIDABC001", - pubSub: []string{}, - privSub: []string{}, + pubSub: []interface{}{}, + privSub: []interface{}{}, } assert.Equal(t, "UIDABC001", client.GetUID()) - assert.Equal(t, []string{}, client.GetSubscriptions()) + assert.Equal(t, []interface{}{}, client.GetPublicSubscriptions()) + assert.Equal(t, []interface{}{}, client.GetPrivateSubscriptions()) client.SubscribePublic("a.x") - assert.Equal(t, []string{"a.x"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x"}, client.pubSub) - assert.Equal(t, []string{}, client.privSub) + assert.Equal(t, []interface{}{"a.x"}, client.GetPublicSubscriptions()) + assert.Equal(t, []interface{}{"a.x"}, client.pubSub) + assert.Equal(t, []interface{}{}, client.privSub) client.SubscribePublic("a.y") - assert.Equal(t, []string{"a.x", "a.y"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x", "a.y"}, client.pubSub) - assert.Equal(t, []string{}, client.privSub) + assert.Equal(t, []interface{}{"a.x", "a.y"}, client.GetPublicSubscriptions()) + assert.Equal(t, []interface{}{"a.x", "a.y"}, client.pubSub) + assert.Equal(t, []interface{}{}, client.privSub) client.UnsubscribePublic("a.y") - assert.Equal(t, []string{"a.x"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x"}, client.pubSub) - assert.Equal(t, []string{}, client.privSub) + assert.Equal(t, []interface{}{"a.x"}, client.GetPublicSubscriptions()) + assert.Equal(t, []interface{}{"a.x"}, client.pubSub) + assert.Equal(t, []interface{}{}, client.privSub) client.SubscribePrivate("b") - assert.Equal(t, []string{"a.x", "b"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x"}, client.pubSub) - assert.Equal(t, []string{"b"}, client.privSub) + assert.Equal(t, []interface{}{"b"}, client.GetPrivateSubscriptions()) + assert.Equal(t, []interface{}{"a.x"}, client.pubSub) + assert.Equal(t, []interface{}{"b"}, client.privSub) client.SubscribePrivate("c") - assert.Equal(t, []string{"a.x", "b", "c"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x"}, client.pubSub) - assert.Equal(t, []string{"b", "c"}, client.privSub) + assert.Equal(t, []interface{}{"b", "c"}, client.GetPrivateSubscriptions()) + assert.Equal(t, []interface{}{"a.x"}, client.pubSub) + assert.Equal(t, []interface{}{"b", "c"}, client.privSub) client.UnsubscribePrivate("b") - assert.Equal(t, []string{"a.x", "c"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x"}, client.pubSub) - assert.Equal(t, []string{"c"}, client.privSub) + assert.Equal(t, []interface{}{"c"}, client.GetPrivateSubscriptions()) + assert.Equal(t, []interface{}{"a.x"}, client.pubSub) + assert.Equal(t, []interface{}{"c"}, client.privSub) client.UnsubscribePrivate("c") - assert.Equal(t, []string{"a.x"}, client.GetSubscriptions()) - assert.Equal(t, []string{"a.x"}, client.pubSub) - assert.Equal(t, []string{}, client.privSub) + assert.Equal(t, []interface{}{}, client.GetPrivateSubscriptions()) + assert.Equal(t, []interface{}{"a.x"}, client.pubSub) + assert.Equal(t, []interface{}{}, client.privSub) client.UnsubscribePublic("a.x") - assert.Equal(t, []string{}, client.GetSubscriptions()) - assert.Equal(t, []string{}, client.pubSub) - assert.Equal(t, []string{}, client.privSub) + assert.Equal(t, []interface{}{}, client.GetPublicSubscriptions()) + assert.Equal(t, []interface{}{}, client.pubSub) + assert.Equal(t, []interface{}{}, client.privSub) } diff --git a/pkg/routing/hub.go b/pkg/routing/hub.go index dc513ce..4db031c 100644 --- a/pkg/routing/hub.go +++ b/pkg/routing/hub.go @@ -283,9 +283,9 @@ func (h *Hub) handleRequest(req *Request) (resp *msg.Msg) { switch req.Method { case "subscribe": - err, resp = h.handleSubscribe(req) + resp, err = h.handleSubscribe(req) case "unsubscribe": - err, resp = h.handleUnsubscribe(req) + resp, err = h.handleUnsubscribe(req) default: return msg.NewResponse(req.Msg, "error", []interface{}{"Unknown method " + req.Method}) } @@ -337,73 +337,71 @@ func (h *Hub) handleSubscribePrivate(client IClient, uid string, streams []strin } } -func (h *Hub) handleSubscribe(req *Request) (error, *msg.Msg) { +func (h *Hub) handleSubscribe(req *Request) (*msg.Msg, error) { args := req.Msg.Args if len(args) != 2 { - return errors.New("method expects exactly 2 arguments"), nil + return nil, errors.New("method expects exactly 2 arguments") } scope, err := msg.ParseString(args[0]) if err != nil { - return errors.New("first argument must be a string"), nil + return nil, errors.New("first argument must be a string") } streams, err := msg.ParseSliceOfStrings(args[1]) if err != nil { - return errors.New("second argument must be a list of strings"), nil + log.Error().Msgf("in subscribe failed to parse argument: %s", err.Error()) + return nil, errors.New("second argument must be a list of strings") } switch scope { case "public": h.handleSubscribePublic(req.client, streams) + return msg.NewResponse(req.Msg, "subscribed", []interface{}{"public", req.client.GetPublicSubscriptions()}), nil case "private": uid := req.client.GetUID() if uid != "" { h.handleSubscribePrivate(req.client, uid, streams) - } else { - return errors.New("unauthorized"), nil + return msg.NewResponse(req.Msg, "subscribed", []interface{}{"private", req.client.GetPrivateSubscriptions()}), nil } + return nil, errors.New("unauthorized") - default: - return errors.New("Unexpected scope " + scope), nil } - - return nil, msg.NewResponse(req.Msg, "subscribed", msg.Convss2is(req.client.GetSubscriptions())) + return nil, errors.New("Unexpected scope " + scope) } -func (h *Hub) handleUnsubscribe(req *Request) (error, *msg.Msg) { +func (h *Hub) handleUnsubscribe(req *Request) (*msg.Msg, error) { args := req.Msg.Args if len(args) != 2 { - return errors.New("method expects exactly 2 arguments"), nil + return nil, errors.New("method expects exactly 2 arguments") } scope, ok := args[0].(string) if !ok { - return errors.New("first argument must be a string"), nil + return nil, errors.New("first argument must be a string") } - streams, ok := args[1].([]string) - if !ok { - return errors.New("second argument must be a list of strings"), nil + streams, err := msg.ParseSliceOfStrings(args[1]) + if err != nil { + log.Error().Msgf("in subscribe failed to parse argument: %s", err.Error()) + return nil, errors.New("second argument must be a list of strings") } switch scope { case "public": h.handleUnsubscribePublic(req.client, streams) + return msg.NewResponse(req.Msg, "unsubscribed", []interface{}{"public", req.client.GetPublicSubscriptions()}), nil case "private": uid := req.client.GetUID() if uid != "" { h.handleUnsubscribePrivate(req.client, uid, streams) - } else { - return errors.New("unauthorized"), nil + return msg.NewResponse(req.Msg, "unsubscribed", []interface{}{"private", req.client.GetPrivateSubscriptions()}), nil } + return nil, errors.New("unauthorized") - default: - return errors.New("Unexpected scope " + scope), nil } - - return nil, msg.NewResponse(req.Msg, "unsubscribed", msg.Convss2is(req.client.GetSubscriptions())) + return nil, errors.New("Unexpected scope " + scope) } func (h *Hub) handleUnsubscribePublic(client IClient, streams []string) { diff --git a/pkg/routing/hub_test.go b/pkg/routing/hub_test.go index 42ed9d7..b124a68 100644 --- a/pkg/routing/hub_test.go +++ b/pkg/routing/hub_test.go @@ -25,9 +25,14 @@ func (c *MockedClient) GetUID() string { return args.String(0) } -func (c *MockedClient) GetSubscriptions() []string { +func (c *MockedClient) GetPublicSubscriptions() []interface{} { args := c.Called() - return args.Get(0).([]string) + return args.Get(0).([]interface{}) +} + +func (c *MockedClient) GetPrivateSubscriptions() []interface{} { + args := c.Called() + return args.Get(0).([]interface{}) } func (c *MockedClient) SubscribePublic(s string) { @@ -46,28 +51,26 @@ func (c *MockedClient) UnsubscribePrivate(s string) { c.Called(s) } -func setup(c *MockedClient, streams []interface{}) *Hub { - h := NewHub() - h.handleSubscribe(&Request{ +func subscribe(h *Hub, c *MockedClient, reqID uint64, args []interface{}) (*msg.Msg, error) { + return h.handleSubscribe(&Request{ client: c, Msg: &msg.Msg{ Type: msg.Request, - ReqID: 41, + ReqID: reqID, Method: "subscribe", - Args: []interface{}{"public", streams}, + Args: args, }, }) - return h } -func teardown(h *Hub, c *MockedClient, streams []interface{}) { - h.handleUnsubscribe(&Request{ +func unsubscribe(h *Hub, c *MockedClient, reqID uint64, args []interface{}) (*msg.Msg, error) { + return h.handleUnsubscribe(&Request{ client: c, Msg: &msg.Msg{ Type: msg.Request, - ReqID: 42, + ReqID: reqID, Method: "unsubscribe", - Args: []interface{}{"public", streams}, + Args: args, }, }) } @@ -75,73 +78,99 @@ func teardown(h *Hub, c *MockedClient, streams []interface{}) { func TestAnonymous(t *testing.T) { t.Run("subscribe to a public single stream", func(t *testing.T) { c := &MockedClient{} + h := NewHub() - streams := []interface{}{ - "eurusd.trades", - } + streams := []interface{}{"eurusd.trades"} c.On("GetUID").Return("") - c.On("GetSubscriptions").Return(streams).Once() + c.On("GetPublicSubscriptions").Return(streams).Once() c.On("SubscribePublic", streams[0]).Return().Once() - c.On("Send", `[1,41,"subscribed",["eurusd.trades"]]`).Return() - h := setup(c, streams) + r, err := subscribe(h, c, 41, []interface{}{"public", streams}) + assert.NoError(t, err) + + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 41, + Method: "subscribed", + Args: []interface{}{"public", streams}, + }, r) + assert.Equal(t, 1, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) c.On("UnsubscribePublic", streams[0]).Return() - c.On("GetSubscriptions").Return([]string{}).Once() - c.On("Send", `[1,42,"unsubscribed",[]]`).Return() + c.On("GetPublicSubscriptions").Return([]interface{}{}).Once() + + r, err = unsubscribe(h, c, 42, []interface{}{"public", streams}) + assert.NoError(t, err) - teardown(h, c, streams) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 42, + Method: "unsubscribed", + Args: []interface{}{"public", []interface{}{}}, + }, r) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) t.Run("subscribe to multiple public streams", func(t *testing.T) { c := &MockedClient{} + h := NewHub() streams := []interface{}{ "eurusd.trades", "eurusd.updates", } c.On("GetUID").Return("") - c.On("GetSubscriptions").Return(streams).Once() + c.On("GetPublicSubscriptions").Return(streams).Once() c.On("SubscribePublic", "eurusd.trades").Return() c.On("SubscribePublic", "eurusd.updates").Return() - c.On("Send", `{"success":{"message":"subscribed","streams":["eurusd.trades","eurusd.updates"]}}`).Return() - h := setup(c, []interface{}{ - "eurusd.trades", - "eurusd.updates", - }) + r, err := subscribe(h, c, 41, []interface{}{"public", streams}) + assert.NoError(t, err) + + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 41, + Method: "subscribed", + Args: []interface{}{"public", streams}, + }, r) assert.Equal(t, 2, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) c.On("UnsubscribePublic", streams[0]).Return().Once() c.On("UnsubscribePublic", streams[1]).Return().Once() - c.On("GetSubscriptions").Return([]string{}).Once() - c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() + c.On("GetPublicSubscriptions").Return([]interface{}{}).Once() + + r, err = unsubscribe(h, c, 42, []interface{}{"public", streams}) + assert.NoError(t, err) + + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 42, + Method: "unsubscribed", + Args: []interface{}{"public", []interface{}{}}, + }, r) - teardown(h, c, streams) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) t.Run("subscribe to a private single stream", func(t *testing.T) { - c := MockedClient{} + c := &MockedClient{} + h := NewHub() c.On("GetUID").Return("") - c.On("GetSubscriptions").Return([]string{}) + c.On("GetPrivateSubscriptions").Return([]interface{}{}) c.On("SubscribePrivate", "trades").Return() - c.On("Send", `{"success":{"message":"subscribed","streams":[]}}`).Return() - - h := setup(&c, []interface{}{ - "trades", - }) + r, err := subscribe(h, c, 41, []interface{}{"private", []interface{}{"trades"}}) + assert.EqualError(t, err, "unauthorized") + assert.Nil(t, r) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) @@ -149,37 +178,57 @@ func TestAnonymous(t *testing.T) { func TestAuthenticated(t *testing.T) { t.Run("subscribe to a private single stream", func(t *testing.T) { c := &MockedClient{} + h := NewHub() c.On("GetUID").Return("UIDABC00001") - c.On("GetSubscriptions").Return([]string{"trades"}).Once() + c.On("GetPrivateSubscriptions").Return([]interface{}{"trades"}).Once() c.On("SubscribePrivate", "trades").Return() - c.On("Send", `{"success":{"message":"subscribed","streams":["trades"]}}`).Return() - h := setup(c, []interface{}{ - "trades", - }) + r, err := subscribe(h, c, 41, []interface{}{"private", []interface{}{"trades"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 41, + Method: "subscribed", + Args: []interface{}{"private", []interface{}{"trades"}}, + }, r) + assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 1, len(h.PrivateTopics)) c.On("UnsubscribePrivate", "trades").Return().Once() - c.On("GetSubscriptions").Return([]string{}).Once() - c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() + c.On("GetPrivateSubscriptions").Return([]interface{}{}).Once() - teardown(h, c, []interface{}{"trades"}) + r, err = unsubscribe(h, c, 42, []interface{}{"private", []interface{}{"trades"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 42, + Method: "unsubscribed", + Args: []interface{}{"private", []interface{}{}}, + }, r) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) t.Run("subscribe to multiple private streams", func(t *testing.T) { c := &MockedClient{} + h := NewHub() - c.On("GetSubscriptions").Return([]string{"trades", "orders"}).Once() + c.On("GetPrivateSubscriptions").Return([]interface{}{"trades", "orders"}).Once() c.On("GetUID").Return("UIDABC00001") c.On("SubscribePrivate", "trades").Return() c.On("SubscribePrivate", "orders").Return() - c.On("Send", `{"success":{"message":"subscribed","streams":["trades","orders"]}}`).Return() + c.On("Send", `[1,41,"subscribed",["trades","orders"]]`).Return() - h := setup(c, []interface{}{"trades", "orders"}) + r, err := subscribe(h, c, 41, []interface{}{"private", []interface{}{"trades", "orders"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 41, + Method: "subscribed", + Args: []interface{}{"private", []interface{}{"trades", "orders"}}, + }, r) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 1, len(h.PrivateTopics)) @@ -189,10 +238,16 @@ func TestAuthenticated(t *testing.T) { c.On("UnsubscribePrivate", "trades").Return().Once() c.On("UnsubscribePrivate", "orders").Return().Once() - c.On("GetSubscriptions").Return([]string{}).Once() - c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() + c.On("GetPrivateSubscriptions").Return([]interface{}{}).Once() - teardown(h, c, []interface{}{"trades", "orders"}) + r, err = unsubscribe(h, c, 42, []interface{}{"private", []interface{}{"trades", "orders"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 42, + Method: "unsubscribed", + Args: []interface{}{"private", []interface{}{}}, + }, r) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) @@ -200,15 +255,35 @@ func TestAuthenticated(t *testing.T) { t.Run("subscribe to multiple private and public streams", func(t *testing.T) { c := &MockedClient{} + h := NewHub() - c.On("GetSubscriptions").Return([]string{"trades", "orders", "eurusd.updates"}).Once() + c.On("GetPublicSubscriptions").Return([]interface{}{"eurusd.updates"}).Once() + c.On("GetPrivateSubscriptions").Return([]interface{}{"trades", "orders"}).Once() c.On("GetUID").Return("UIDABC00001") c.On("SubscribePrivate", "trades").Return() c.On("SubscribePrivate", "orders").Return() c.On("SubscribePublic", "eurusd.updates").Return() - c.On("Send", `{"success":{"message":"subscribed","streams":["trades","orders","eurusd.updates"]}}`).Return() + c.On("Send", `[1,41,"subscribed",["public",["eurusd.updates"]]]`).Return().Once() + c.On("Send", `[1,42,"subscribed",["private",["trades","orders"]]`).Return().Once() + + r, err := subscribe(h, c, 41, []interface{}{"public", []interface{}{"eurusd.updates"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 41, + Method: "subscribed", + Args: []interface{}{"public", []interface{}{"eurusd.updates"}}, + }, r) + + r, err = subscribe(h, c, 42, []interface{}{"private", []interface{}{"trades", "orders"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 42, + Method: "subscribed", + Args: []interface{}{"private", []interface{}{"trades", "orders"}}, + }, r) - h := setup(c, []interface{}{"trades", "orders", "eurusd.updates"}) assert.Equal(t, 1, len(h.PublicTopics)) assert.Equal(t, 1, len(h.PrivateTopics)) @@ -219,10 +294,28 @@ func TestAuthenticated(t *testing.T) { c.On("UnsubscribePrivate", "trades").Return().Once() c.On("UnsubscribePrivate", "orders").Return().Once() c.On("UnsubscribePublic", "eurusd.updates").Return().Once() - c.On("GetSubscriptions").Return([]string{}).Once() - c.On("Send", `{"success":{"message":"unsubscribed","streams":[]}}`).Return() - - teardown(h, c, []interface{}{"trades", "orders", "eurusd.updates"}) + c.On("GetPublicSubscriptions").Return([]interface{}{}).Once() + c.On("GetPrivateSubscriptions").Return([]interface{}{}).Once() + c.On("Send", `[1,42,"unsubscribed",["public",[]]]`).Return().Once() + c.On("Send", `[1,42,"unsubscribed",["private",[]]]`).Return().Once() + + r, err = unsubscribe(h, c, 43, []interface{}{"public", []interface{}{"eurusd.updates"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 43, + Method: "unsubscribed", + Args: []interface{}{"public", []interface{}{}}, + }, r) + + r, err = unsubscribe(h, c, 44, []interface{}{"private", []interface{}{"trades", "orders"}}) + assert.NoError(t, err) + assert.Equal(t, &msg.Msg{ + Type: msg.Response, + ReqID: 44, + Method: "unsubscribed", + Args: []interface{}{"private", []interface{}{}}, + }, r) assert.Equal(t, 0, len(h.PublicTopics)) assert.Equal(t, 0, len(h.PrivateTopics)) }) From 2fe6048e16617004b442ada1bab12513d9c2cf57 Mon Sep 17 00:00:00 2001 From: Camille Meulien Date: Tue, 5 May 2020 14:24:34 +0300 Subject: [PATCH 4/5] Update message type codes according to spec changes --- pkg/msg/msg.go | 14 ++++++++------ pkg/msg/msg_test.go | 2 +- pkg/msg/parser.go | 2 +- pkg/msg/parser_test.go | 30 +++++++++++++++--------------- 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index 1f43601..2bdbacc 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -6,14 +6,16 @@ import ( "github.com/rs/zerolog/log" ) -// Request identifier -const Request = 0 +const ( + // Request type code + Request = 1 -// Response identifier -const Response = 1 + // Response type code + Response = 2 -// Event identifier -const Event = 2 + // Event type code + Event = 3 +) // Msg represent websocket messages, it could be either a request, a response or an event type Msg struct { diff --git a/pkg/msg/msg_test.go b/pkg/msg/msg_test.go index 97a2ec2..1c3a62f 100644 --- a/pkg/msg/msg_test.go +++ b/pkg/msg/msg_test.go @@ -14,5 +14,5 @@ func TestEncoding(t *testing.T) { Args: []interface{}{"hello", "there"}, } - assert.Equal(t, `[0,42,"test",["hello","there"]]`, string(msg.Encode())) + assert.Equal(t, `[1,42,"test",["hello","there"]]`, string(msg.Encode())) } diff --git a/pkg/msg/parser.go b/pkg/msg/parser.go index df3c424..5ef14f3 100644 --- a/pkg/msg/parser.go +++ b/pkg/msg/parser.go @@ -88,7 +88,7 @@ func Parse(msg []byte) (*Msg, error) { return nil, fmt.Errorf("failed to parse type: %w", err) } if t != Request && t != Response && t != Event { - return nil, errors.New("message type must be 0, 1 or 2") + return nil, errors.New("message type must be 1, 2 or 3") } reqID, err := ParseUint64(v[1]) diff --git a/pkg/msg/parser_test.go b/pkg/msg/parser_test.go index 56ab0a3..fb0bae6 100644 --- a/pkg/msg/parser_test.go +++ b/pkg/msg/parser_test.go @@ -7,7 +7,7 @@ import ( ) func TestParserSuccess(t *testing.T) { - msg, err := Parse([]byte(`[0,42,"ping",[]]`)) + msg, err := Parse([]byte(`[1,42,"ping",[]]`)) assert.NoError(t, err) assert.Equal(t, &Msg{ @@ -17,7 +17,7 @@ func TestParserSuccess(t *testing.T) { Args: []interface{}{}, }, msg) - msg, err = Parse([]byte(`[1,42,"pong",[]]`)) + msg, err = Parse([]byte(`[2,42,"pong",[]]`)) assert.NoError(t, err) assert.Equal(t, &Msg{ @@ -27,7 +27,7 @@ func TestParserSuccess(t *testing.T) { Args: []interface{}{}, }, msg) - msg, err = Parse([]byte(`[2,42,"temperature",[28.7]]`)) + msg, err = Parse([]byte(`[3,42,"temperature",[28.7]]`)) assert.NoError(t, err) assert.Equal(t, &Msg{ @@ -39,57 +39,57 @@ func TestParserSuccess(t *testing.T) { } func TestParserErrorsMessageLength(t *testing.T) { - msg, err := Parse([]byte(`[0,42,"ping"]`)) + msg, err := Parse([]byte(`[1,42,"ping"]`)) assert.EqualError(t, err, "message must contains 4 elements") assert.Nil(t, msg) } func TestParserErrorsBadJSON(t *testing.T) { - msg, err := Parse([]byte(`[0,42,"ping",[]`)) + msg, err := Parse([]byte(`[1,42,"ping",[]`)) assert.EqualError(t, err, "Could not parse message: unexpected end of JSON input") assert.Nil(t, msg) } func TestParserErrorsType(t *testing.T) { - msg, err := Parse([]byte(`[3,42,"ping",[]]`)) - assert.EqualError(t, err, "message type must be 0, 1 or 2") + msg, err := Parse([]byte(`[4,42,"ping",[]]`)) + assert.EqualError(t, err, "message type must be 1, 2 or 3") assert.Nil(t, msg) - msg, err = Parse([]byte(`[0.1,42,"pong",[]]`)) + msg, err = Parse([]byte(`[1.1,42,"pong",[]]`)) assert.EqualError(t, err, "failed to parse type: expected unsigned integer got: float") assert.Nil(t, msg) - msg, err = Parse([]byte(`["0",42,"pong",[]]`)) + msg, err = Parse([]byte(`["1",42,"pong",[]]`)) assert.EqualError(t, err, "failed to parse type: expected uint8 got: string") assert.Nil(t, msg) } func TestParserErrorsRequestID(t *testing.T) { - msg, err := Parse([]byte(`[0,"42","ping",[]]`)) + msg, err := Parse([]byte(`[1,"42","ping",[]]`)) assert.EqualError(t, err, "failed to parse request ID: expected uint64 got: string") assert.Nil(t, msg) - msg, err = Parse([]byte(`[0,42.1,"ping",[]]`)) + msg, err = Parse([]byte(`[1,42.1,"ping",[]]`)) assert.EqualError(t, err, "failed to parse request ID: expected unsigned integer got: float") assert.Nil(t, msg) } func TestParserErrorsMethod(t *testing.T) { - msg, err := Parse([]byte(`[0,42,51,[]]`)) + msg, err := Parse([]byte(`[1,42,51,[]]`)) assert.EqualError(t, err, "failed to parse method: expected string got: float64") assert.Nil(t, msg) - msg, err = Parse([]byte(`[0,42,true,[]]`)) + msg, err = Parse([]byte(`[1,42,true,[]]`)) assert.EqualError(t, err, "failed to parse method: expected string got: bool") assert.Nil(t, msg) } func TestParserErrorsArgs(t *testing.T) { - msg, err := Parse([]byte(`[0,42,"ping",true]`)) + msg, err := Parse([]byte(`[1,42,"ping",true]`)) assert.EqualError(t, err, "failed to parse arguments: expected array got: bool") assert.Nil(t, msg) - msg, err = Parse([]byte(`[0,42,"ping","hello"]`)) + msg, err = Parse([]byte(`[1,42,"ping","hello"]`)) assert.EqualError(t, err, "failed to parse arguments: expected array got: string") assert.Nil(t, msg) } From f309330fce71082b11d391a88e5ec4380e597b86 Mon Sep 17 00:00:00 2001 From: denisfd Date: Wed, 6 May 2020 15:46:12 +0300 Subject: [PATCH 5/5] Finilize rebase --- pkg/message/msg.go | 26 -------------------------- pkg/msg/msg.go | 3 --- pkg/routing/client.go | 7 ------- pkg/routing/hub.go | 13 +++++++++---- 4 files changed, 9 insertions(+), 40 deletions(-) delete mode 100644 pkg/message/msg.go diff --git a/pkg/message/msg.go b/pkg/message/msg.go deleted file mode 100644 index dd4dfae..0000000 --- a/pkg/message/msg.go +++ /dev/null @@ -1,26 +0,0 @@ -package message - -import ( - "encoding/json" -) - -type Request struct { - Method string - Streams []string -} - -func PackOutgoingResponse(err error, message interface{}) ([]byte, error) { - res := make(map[string]interface{}, 1) - if err != nil { - res["error"] = err.Error() - } else { - res["success"] = message - } - return json.Marshal(res) -} - -func PackOutgoingEvent(channel string, data interface{}) ([]byte, error) { - resp := make(map[string]interface{}, 1) - resp[channel] = data - return json.Marshal(resp) -} diff --git a/pkg/msg/msg.go b/pkg/msg/msg.go index 2bdbacc..369e0b9 100644 --- a/pkg/msg/msg.go +++ b/pkg/msg/msg.go @@ -2,8 +2,6 @@ package msg import ( "encoding/json" - - "github.com/rs/zerolog/log" ) const ( @@ -44,7 +42,6 @@ func (m *Msg) Encode() []byte { m.Args, }) if err != nil { - log.Error().Msgf("Fail to encode Msg %v, %s", m, err.Error()) return []byte{} } return s diff --git a/pkg/routing/client.go b/pkg/routing/client.go index b4578a5..d79f655 100644 --- a/pkg/routing/client.go +++ b/pkg/routing/client.go @@ -87,13 +87,6 @@ func NewClient(hub *Hub, w http.ResponseWriter, r *http.Request) { log.Info().Msgf("New authenticated connection: %s", client.UID) } - hub.handleSubscribe(&Request{ - client: client, - Request: msg.Request{ - Streams: parseStreamsFromURI(r.RequestURI), - }, - }) - metrics.RecordHubClientNew() // Allow collection of memory referenced by the caller by doing all work in diff --git a/pkg/routing/hub.go b/pkg/routing/hub.go index 4db031c..a086a1b 100644 --- a/pkg/routing/hub.go +++ b/pkg/routing/hub.go @@ -304,8 +304,11 @@ func (h *Hub) handleSubscribePublic(client IClient, streams []string) { h.PublicTopics[t] = topic } - topic.subscribe(client) - client.SubscribePublic(t) + if topic.subscribe(client) { + client.SubscribePublic(t) + metrics.RecordHubSubscription("public", t) + } + if isIncrementObject(t) { o, ok := h.IncrementalObjects[t] if ok && o.Snapshot != "" { @@ -332,8 +335,10 @@ func (h *Hub) handleSubscribePrivate(client IClient, uid string, streams []strin uTopics[t] = topic } - topic.subscribe(client) - client.SubscribePrivate(t) + if topic.subscribe(client) { + client.SubscribePrivate(t) + metrics.RecordHubSubscription("private", t) + } } }