Commit 316b155a authored by Joram Wilander's avatar Joram Wilander Committed by Christopher Speller

PLT-3562 Switch websocket over to post-connect authentication (#4327)

* Switch websocket over to post-connect authentication

* Add ability to specify token in websocket js driver, add unit tests

* Temporarily disable client websocket tests until issues are resolved

* Minor refactoring and fix status test

* Add isAuthenticated method to WebConn and minor status updates
parent ef363fd8
......@@ -22,6 +22,11 @@ func TestStatuses(t *testing.T) {
defer WebSocketClient.Close()
WebSocketClient.Listen()
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK {
t.Fatal("should have responded OK to authentication challenge")
}
team := model.Team{DisplayName: "Name", Name: "z-z-" + model.NewId() + "a", Email: "test@nowhere.com", Type: model.TEAM_OPEN}
rteam, _ := Client.CreateTeam(&team)
......@@ -75,7 +80,7 @@ func TestStatuses(t *testing.T) {
}
if status, ok := resp.Data[th.BasicUser2.Id]; !ok {
t.Log(len(resp.Data))
t.Log(resp.Data)
t.Fatal("should have had user status")
} else if status != model.STATUS_ONLINE {
t.Log(status)
......
......@@ -1794,6 +1794,11 @@ func TestUserTyping(t *testing.T) {
defer WebSocketClient.Close()
WebSocketClient.Listen()
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK {
t.Fatal("should have responded OK to authentication challenge")
}
WebSocketClient.UserTyping("", "")
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Error.Id != "api.websocket_handler.invalid_param.app_error" {
......
......@@ -15,9 +15,10 @@ import (
)
const (
WRITE_WAIT = 30 * time.Second
PONG_WAIT = 100 * time.Second
PING_PERIOD = (PONG_WAIT * 6) / 10
WRITE_WAIT = 30 * time.Second
PONG_WAIT = 100 * time.Second
PING_PERIOD = (PONG_WAIT * 6) / 10
AUTH_TIMEOUT = 5 * time.Second
)
type WebConn struct {
......@@ -32,7 +33,9 @@ type WebConn struct {
}
func NewWebConn(c *Context, ws *websocket.Conn) *WebConn {
go SetStatusOnline(c.Session.UserId, c.Session.Id, false)
if len(c.Session.UserId) > 0 {
go SetStatusOnline(c.Session.UserId, c.Session.Id, false)
}
return &WebConn{
Send: make(chan model.WebSocketMessage, 256),
......@@ -53,7 +56,9 @@ func (c *WebConn) readPump() {
c.WebSocket.SetReadDeadline(time.Now().Add(PONG_WAIT))
c.WebSocket.SetPongHandler(func(string) error {
c.WebSocket.SetReadDeadline(time.Now().Add(PONG_WAIT))
go SetStatusAwayIfNeeded(c.UserId, false)
if c.isAuthenticated() {
go SetStatusAwayIfNeeded(c.UserId, false)
}
return nil
})
......@@ -64,7 +69,7 @@ func (c *WebConn) readPump() {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
l4g.Debug(fmt.Sprintf("websocket.read: client side closed socket userId=%v", c.UserId))
} else {
l4g.Debug(fmt.Sprintf("websocket.read: cannot read, closing websocket for userId=%v error=%v", c.UserId, err.Error()))
l4g.Debug(fmt.Sprintf("websocket.read: closing websocket for userId=%v error=%v", c.UserId, err.Error()))
}
return
......@@ -76,9 +81,11 @@ func (c *WebConn) readPump() {
func (c *WebConn) writePump() {
ticker := time.NewTicker(PING_PERIOD)
authTicker := time.NewTicker(AUTH_TIMEOUT)
defer func() {
ticker.Stop()
authTicker.Stop()
c.WebSocket.Close()
}()
......@@ -97,7 +104,7 @@ func (c *WebConn) writePump() {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
l4g.Debug(fmt.Sprintf("websocket.send: client side closed socket userId=%v", c.UserId))
} else {
l4g.Debug(fmt.Sprintf("websocket.send: cannot send, closing websocket for userId=%v, error=%v", c.UserId, err.Error()))
l4g.Debug(fmt.Sprintf("websocket.send: closing websocket for userId=%v, error=%v", c.UserId, err.Error()))
}
return
......@@ -110,11 +117,18 @@ func (c *WebConn) writePump() {
if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived) {
l4g.Debug(fmt.Sprintf("websocket.ticker: client side closed socket userId=%v", c.UserId))
} else {
l4g.Debug(fmt.Sprintf("websocket.ticker: cannot read, closing websocket for userId=%v error=%v", c.UserId, err.Error()))
l4g.Debug(fmt.Sprintf("websocket.ticker: closing websocket for userId=%v error=%v", c.UserId, err.Error()))
}
return
}
case <-authTicker.C:
if c.SessionToken == "" {
l4g.Debug(fmt.Sprintf("websocket.authTicker: did not authenticate ip=%v", c.WebSocket.RemoteAddr()))
return
}
authTicker.Stop()
}
}
}
......@@ -122,10 +136,18 @@ func (c *WebConn) writePump() {
func (webCon *WebConn) InvalidateCache() {
webCon.AllChannelMembers = nil
webCon.LastAllChannelMembersTime = 0
}
func (webCon *WebConn) isAuthenticated() bool {
return webCon.SessionToken != ""
}
func (webCon *WebConn) ShouldSendEvent(msg *model.WebSocketEvent) bool {
// IMPORTANT: Do not send event if WebConn does not have a session
if !webCon.isAuthenticated() {
return false
}
// If the event is destined to a specific user
if len(msg.Broadcast.UserId) > 0 && webCon.UserId != msg.Broadcast.UserId {
return false
......
......@@ -156,6 +156,10 @@ func (h *Hub) Start() {
close(webCon.Send)
}
if len(userId) == 0 {
continue
}
found := false
for webCon := range h.connections {
if userId == webCon.UserId {
......
......@@ -17,7 +17,7 @@ const (
func InitWebSocket() {
l4g.Debug(utils.T("api.web_socket.init.debug"))
BaseRoutes.Users.Handle("/websocket", ApiUserRequiredTrustRequester(connect)).Methods("GET")
BaseRoutes.Users.Handle("/websocket", ApiAppHandlerTrustRequester(connect)).Methods("GET")
HubStart()
}
......
......@@ -37,6 +37,37 @@ func (wr *WebSocketRouter) ServeWebSocket(conn *WebConn, r *model.WebSocketReque
return
}
if r.Action == model.WEBSOCKET_AUTHENTICATION_CHALLENGE {
token, ok := r.Data["token"].(string)
if !ok {
conn.WebSocket.Close()
return
}
session := GetSession(token)
if session == nil || session.IsExpired() {
conn.WebSocket.Close()
} else {
go SetStatusOnline(session.UserId, session.Id, false)
conn.SessionToken = session.Token
conn.UserId = session.UserId
resp := model.NewWebSocketResponse(model.STATUS_OK, r.Seq, nil)
resp.DoPreComputeJson()
conn.Send <- resp
}
return
}
if conn.SessionToken == "" {
err := model.NewLocAppError("ServeWebSocket", "api.web_socket_router.not_authenticated.app_error", nil, "")
wr.ReturnWebSocketError(conn, r, err)
return
}
var handler *webSocketHandler
if h, ok := wr.handlers[r.Action]; !ok {
err := model.NewLocAppError("ServeWebSocket", "api.web_socket_router.bad_action.app_error", nil, "")
......
......@@ -4,12 +4,116 @@
package api
import (
"encoding/json"
"net/http"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/mattermost/platform/model"
)
func TestWebSocketAuthentication(t *testing.T) {
th := Setup().InitBasic()
WebSocketClient, err := th.CreateWebSocketClient()
if err != nil {
t.Fatal(err)
}
WebSocketClient.Listen()
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK {
t.Fatal("should have responded OK to authentication challenge")
}
WebSocketClient.SendMessage("ping", nil)
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Data["text"].(string) != "pong" {
t.Fatal("wrong response")
}
WebSocketClient.Close()
authToken := WebSocketClient.AuthToken
WebSocketClient.AuthToken = "junk"
if err := WebSocketClient.Connect(); err != nil {
t.Fatal(err)
}
WebSocketClient.Listen()
if resp := <-WebSocketClient.ResponseChannel; resp != nil {
t.Fatal("should have closed")
}
WebSocketClient.Close()
if conn, _, err := websocket.DefaultDialer.Dial(WebSocketClient.ApiUrl+"/users/websocket", nil); err != nil {
t.Fatal("should have connected")
} else {
req := &model.WebSocketRequest{}
req.Seq = 1
req.Action = "ping"
conn.WriteJSON(req)
closedAutomatically := false
hitNotAuthedError := false
go func() {
time.Sleep(10 * time.Second)
conn.Close()
if !closedAutomatically {
t.Fatal("should have closed automatically in 5 seconds")
}
}()
for {
if _, rawMsg, err := conn.ReadMessage(); err != nil {
closedAutomatically = true
conn.Close()
break
} else {
var response model.WebSocketResponse
if err := json.Unmarshal(rawMsg, &response); err != nil && !response.IsValid() {
t.Fatal("should not have failed")
} else {
if response.Error == nil || response.Error.Id != "api.web_socket_router.not_authenticated.app_error" {
t.Log(response.Error.Id)
t.Fatal("wrong error")
continue
}
hitNotAuthedError = true
}
}
}
if !hitNotAuthedError {
t.Fatal("should have received a not authenticated response")
}
}
header := http.Header{}
header.Set(model.HEADER_AUTH, "BEARER "+authToken)
if conn, _, err := websocket.DefaultDialer.Dial(WebSocketClient.ApiUrl+"/users/websocket", header); err != nil {
t.Fatal("should have connected")
} else {
if _, rawMsg, err := conn.ReadMessage(); err != nil {
t.Fatal("should not have closed automatically")
} else {
var event model.WebSocketEvent
if err := json.Unmarshal(rawMsg, &event); err != nil && !event.IsValid() {
t.Fatal("should not have failed")
} else if event.Event != model.WEBSOCKET_EVENT_HELLO {
t.Log(event.ToJson())
t.Fatal("should have helloed")
}
}
conn.Close()
}
}
func TestWebSocket(t *testing.T) {
th := Setup().InitBasic()
WebSocketClient, err := th.CreateWebSocketClient()
......@@ -29,6 +133,9 @@ func TestWebSocket(t *testing.T) {
WebSocketClient.Listen()
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK {
t.Fatal("should have responded OK to authentication challenge")
}
WebSocketClient.SendMessage("ping", nil)
time.Sleep(300 * time.Millisecond)
......@@ -78,6 +185,11 @@ func TestWebSocketEvent(t *testing.T) {
WebSocketClient.Listen()
time.Sleep(300 * time.Millisecond)
if resp := <-WebSocketClient.ResponseChannel; resp.Status != model.STATUS_OK {
t.Fatal("should have responded OK to authentication challenge")
}
omitUser := make(map[string]bool, 1)
omitUser["somerandomid"] = true
evt1 := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_TYPING, "", th.BasicChannel.Id, "", omitUser)
......
......@@ -6,7 +6,6 @@ package model
import (
"encoding/json"
"github.com/gorilla/websocket"
"net/http"
)
type WebSocketClient struct {
......@@ -23,14 +22,12 @@ type WebSocketClient struct {
// NewWebSocketClient constructs a new WebSocket client with convienence
// methods for talking to the server.
func NewWebSocketClient(url, authToken string) (*WebSocketClient, *AppError) {
header := http.Header{}
header.Set(HEADER_AUTH, "BEARER "+authToken)
conn, _, err := websocket.DefaultDialer.Dial(url+API_URL_SUFFIX+"/users/websocket", header)
conn, _, err := websocket.DefaultDialer.Dial(url+API_URL_SUFFIX+"/users/websocket", nil)
if err != nil {
return nil, NewLocAppError("NewWebSocketClient", "model.websocket_client.connect_fail.app_error", nil, err.Error())
}
return &WebSocketClient{
client := &WebSocketClient{
url,
url + API_URL_SUFFIX,
conn,
......@@ -39,19 +36,25 @@ func NewWebSocketClient(url, authToken string) (*WebSocketClient, *AppError) {
make(chan *WebSocketEvent, 100),
make(chan *WebSocketResponse, 100),
nil,
}, nil
}
client.SendMessage(WEBSOCKET_AUTHENTICATION_CHALLENGE, map[string]interface{}{"token": authToken})
return client, nil
}
func (wsc *WebSocketClient) Connect() *AppError {
header := http.Header{}
header.Set(HEADER_AUTH, "BEARER "+wsc.AuthToken)
var err error
wsc.Conn, _, err = websocket.DefaultDialer.Dial(wsc.ApiUrl+"/users/websocket", header)
wsc.Conn, _, err = websocket.DefaultDialer.Dial(wsc.ApiUrl+"/users/websocket", nil)
if err != nil {
return NewLocAppError("NewWebSocketClient", "model.websocket_client.connect_fail.app_error", nil, err.Error())
}
wsc.EventChannel = make(chan *WebSocketEvent, 100)
wsc.ResponseChannel = make(chan *WebSocketResponse, 100)
wsc.SendMessage(WEBSOCKET_AUTHENTICATION_CHALLENGE, map[string]interface{}{"token": wsc.AuthToken})
return nil
}
......@@ -89,6 +92,7 @@ func (wsc *WebSocketClient) Listen() {
wsc.ResponseChannel <- &response
continue
}
}
}()
}
......
......@@ -26,6 +26,7 @@ const (
WEBSOCKET_EVENT_STATUS_CHANGE = "status_change"
WEBSOCKET_EVENT_HELLO = "hello"
WEBSOCKET_EVENT_WEBRTC = "webrtc"
WEBSOCKET_AUTHENTICATION_CHALLENGE = "authentication_challenge"
)
type WebSocketMessage interface {
......
......@@ -18,7 +18,7 @@ export default class WebSocketClient {
this.closeCallback = null;
}
initialize(connectionUrl) {
initialize(connectionUrl, token) {
if (this.conn) {
return;
}
......@@ -30,6 +30,10 @@ export default class WebSocketClient {
this.conn = new WebSocket(connectionUrl);
this.conn.onopen = () => {
if (token) {
this.sendMessage('authentication_challenge', {token});
}
if (this.connectFailCount > 0) {
console.log('websocket re-established connection'); //eslint-disable-line no-console
if (this.reconnectCallback) {
......@@ -68,7 +72,7 @@ export default class WebSocketClient {
setTimeout(
() => {
this.initialize(connectionUrl);
this.initialize(connectionUrl, token);
},
retryTime
);
......@@ -152,12 +156,12 @@ export default class WebSocketClient {
}
}
userTyping(channelId, parentId) {
userTyping(channelId, parentId, callback) {
const data = {};
data.channel_id = channelId;
data.parent_id = parentId;
this.sendMessage('user_typing', data);
this.sendMessage('user_typing', data, callback);
}
getStatuses(callback) {
......
// Copyright (c) 2016 Mattermost, Inc. All Rights Reserved.
// See License.txt for license information.
/*
var assert = require('assert');
import TestHelper from './test_helper.jsx';
describe('Client.WebSocket', function() {
this.timeout(10000);
it('WebSocket.getStatusesByIds', function(done) {
TestHelper.initBasic(() => {
TestHelper.basicWebSocketClient().getStatusesByIds(
[TestHelper.basicUser().id],
function(resp) {
TestHelper.basicWebSocketClient().close();
assert.equal(resp.data[TestHelper.basicUser().id], 'online');
done();
}
);
}, true);
});
it('WebSocket.getStatuses', function(done) {
TestHelper.initBasic(() => {
TestHelper.basicWebSocketClient().getStatuses(
function(resp) {
TestHelper.basicWebSocketClient().close();
assert.equal(resp.data != null, true);
done();
}
);
}, true);
});
it('WebSocket.userTyping', function(done) {
TestHelper.initBasic(() => {
TestHelper.basicWebSocketClient().userTyping(
TestHelper.basicChannel().id,
'',
function(resp) {
TestHelper.basicWebSocketClient().close();
assert.equal(resp.status, 'OK');
done();
}
);
}, true);
});
});*/
......@@ -2,13 +2,20 @@
// See License.txt for license information.
import Client from 'client/client.jsx';
import WebSocketClient from 'client/websocket_client.jsx';
import jqd from 'jquery-deferred';
var HEADER_TOKEN = 'token';
class TestHelperClass {
basicClient = () => {
return this.basicc;
}
basicWebSocketClient = () => {
return this.basicwsc;
}
basicTeam = () => {
return this.basict;
}
......@@ -53,6 +60,12 @@ class TestHelperClass {
return c;
}
createWebSocketClient(token) {
var ws = new WebSocketClient();
ws.initialize('http://localhost:8065/api/v3/users/websocket', token);
return ws;
}
fakeEmail = () => {
return 'success' + this.generateId() + '@simulator.amazonses.com';
}
......@@ -90,7 +103,7 @@ class TestHelperClass {
return post;
}
initBasic = (callback) => {
initBasic = (callback, connectWS) => {
this.basicc = this.createClient();
var d1 = jqd.Deferred();
......@@ -122,7 +135,10 @@ class TestHelperClass {
rteamSignup.user.email,
password,
null,
function() {
function(data, res) {
if (connectWS) {
outer.basicwsc = outer.createWebSocketClient(res.header[HEADER_TOKEN]);
}
outer.basicClient().useHeaderToken();
var channel = outer.fakeChannel();
channel.team_id = outer.basicTeam().id;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment