Commit 15d64fb2 authored by Jesús Espino's avatar Jesús Espino Committed by GitHub

MM-7188: Cleaning push notification on every read, not only on channel switch (#9348)

* MM-7188: Cleaning push notification on every read, not only on channel switch

* Removed unnecesary goroutine

* Fixing tests

* Applying suggestion from PR
parent 37e00ef9
......@@ -1073,7 +1073,7 @@ func addChannelMember(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
cm, err := c.App.AddChannelMember(member.UserId, channel, c.Session.UserId, postRootId)
cm, err := c.App.AddChannelMember(member.UserId, channel, c.Session.UserId, postRootId, !c.Session.IsMobileApp())
if err != nil {
c.Err = err
return
......
......@@ -58,7 +58,7 @@ func createPost(c *Context, w http.ResponseWriter, r *http.Request) {
post.CreateAt = 0
}
rp, err := c.App.CreatePostAsUser(c.App.PostWithProxyRemovedFromImageURLs(post))
rp, err := c.App.CreatePostAsUser(c.App.PostWithProxyRemovedFromImageURLs(post), !c.Session.IsMobileApp())
if err != nil {
c.Err = err
return
......
......@@ -50,6 +50,8 @@ type App struct {
Hubs []*Hub
HubsStopCheckingForDeadlock chan bool
PushNotificationsHub PushNotificationsHub
Jobs *jobs.JobServer
AccountMigration einterfaces.AccountMigrationInterface
......@@ -128,6 +130,9 @@ func New(options ...Option) (outApp *App, outErr error) {
app.HTTPService = MakeHTTPService(app)
app.CreatePushNotificationsHub()
app.StartPushNotificationsHubWorkers()
defer func() {
if outErr != nil {
app.Shutdown()
......@@ -276,6 +281,7 @@ func (a *App) Shutdown() {
a.StopServer()
a.HubStop()
a.StopPushNotificationsHubWorkers()
a.ShutDownPlugins()
a.WaitForGoroutines()
......
......@@ -817,7 +817,7 @@ func (a *App) AddUserToChannel(user *model.User, channel *model.Channel) (*model
return newMember, nil
}
func (a *App) AddChannelMember(userId string, channel *model.Channel, userRequestorId string, postRootId string) (*model.ChannelMember, *model.AppError) {
func (a *App) AddChannelMember(userId string, channel *model.Channel, userRequestorId string, postRootId string, clearPushNotifications bool) (*model.ChannelMember, *model.AppError) {
if result := <-a.Srv.Store.Channel().GetMember(channel.Id, userId); result.Err != nil {
if result.Err.Id != store.MISSING_CHANNEL_MEMBER_ERROR {
return nil, result.Err
......@@ -864,7 +864,7 @@ func (a *App) AddChannelMember(userId string, channel *model.Channel, userReques
}
if userRequestor != nil {
a.UpdateChannelLastViewedAt([]string{channel.Id}, userRequestor.Id)
a.MarkChannelsAsViewed([]string{channel.Id}, userRequestor.Id, clearPushNotifications)
}
return cm, nil
......@@ -1559,6 +1559,55 @@ func (a *App) SearchChannelsUserNotIn(teamId string, userId string, term string)
return result.Data.(*model.ChannelList), nil
}
func (a *App) MarkChannelsAsViewed(channelIds []string, userId string, clearPushNotifications bool) (map[string]int64, *model.AppError) {
// I start looking for channels with notifications before I mark it as read, to clear the push notifications if needed
channelsToClearPushNotifications := []string{}
if *a.Config().EmailSettings.SendPushNotifications && clearPushNotifications {
for _, channelId := range channelIds {
if model.IsValidId(channelId) {
member := (<-a.Srv.Store.Channel().GetMember(channelId, userId)).Data.(*model.ChannelMember)
notify := member.NotifyProps[model.PUSH_NOTIFY_PROP]
if notify == model.CHANNEL_NOTIFY_DEFAULT {
user, _ := a.GetUser(userId)
notify = user.NotifyProps[model.PUSH_NOTIFY_PROP]
}
if notify == model.USER_NOTIFY_ALL {
if result := <-a.Srv.Store.User().GetAnyUnreadPostCountForChannel(userId, channelId); result.Err == nil {
if result.Data.(int64) > 0 {
channelsToClearPushNotifications = append(channelsToClearPushNotifications, channelId)
}
}
} else if notify == model.USER_NOTIFY_MENTION {
if result := <-a.Srv.Store.User().GetUnreadCountForChannel(userId, channelId); result.Err == nil {
if result.Data.(int64) > 0 {
channelsToClearPushNotifications = append(channelsToClearPushNotifications, channelId)
}
}
}
}
}
}
result := <-a.Srv.Store.Channel().UpdateLastViewedAt(channelIds, userId)
if result.Err != nil {
return nil, result.Err
}
times := result.Data.(map[string]int64)
if *a.Config().ServiceSettings.EnableChannelViewedMessages {
for _, channelId := range channelIds {
if model.IsValidId(channelId) {
message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_CHANNEL_VIEWED, "", "", userId, nil)
message.Add("channel_id", channelId)
a.Publish(message)
}
}
}
for _, channelId := range channelsToClearPushNotifications {
a.ClearPushNotification(userId, channelId)
}
return times, nil
}
func (a *App) ViewChannel(view *model.ChannelView, userId string, clearPushNotifications bool) (map[string]int64, *model.AppError) {
if err := a.SetActiveChannel(userId, view.ChannelId); err != nil {
return nil, err
......@@ -1570,45 +1619,15 @@ func (a *App) ViewChannel(view *model.ChannelView, userId string, clearPushNotif
channelIds = append(channelIds, view.ChannelId)
}
var pchan store.StoreChannel
if len(view.PrevChannelId) > 0 {
channelIds = append(channelIds, view.PrevChannelId)
if *a.Config().EmailSettings.SendPushNotifications && clearPushNotifications && len(view.ChannelId) > 0 {
pchan = a.Srv.Store.User().GetUnreadCountForChannel(userId, view.ChannelId)
}
}
if len(channelIds) == 0 {
return map[string]int64{}, nil
}
uchan := a.Srv.Store.Channel().UpdateLastViewedAt(channelIds, userId)
if pchan != nil {
result := <-pchan
if result.Err != nil {
return nil, result.Err
}
if result.Data.(int64) > 0 {
a.ClearPushNotification(userId, view.ChannelId)
}
}
var times map[string]int64
result := <-uchan
if result.Err != nil {
return nil, result.Err
}
times = result.Data.(map[string]int64)
if *a.Config().ServiceSettings.EnableChannelViewedMessages && model.IsValidId(view.ChannelId) {
message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_CHANNEL_VIEWED, "", "", userId, nil)
message.Add("channel_id", view.ChannelId)
a.Publish(message)
}
return times, nil
return a.MarkChannelsAsViewed(channelIds, userId, clearPushNotifications)
}
func (a *App) PermanentDeleteChannel(channel *model.Channel) *model.AppError {
......
......@@ -389,7 +389,7 @@ func TestAddChannelMemberNoUserRequestor(t *testing.T) {
channel := th.createChannel(th.BasicTeam, model.CHANNEL_OPEN)
userRequestorId := ""
postRootId := ""
if _, err := th.App.AddChannelMember(user.Id, channel, userRequestorId, postRootId); err != nil {
if _, err := th.App.AddChannelMember(user.Id, channel, userRequestorId, postRootId, false); err != nil {
t.Fatal("Failed to add user to channel. Error: " + err.Message)
}
......
......@@ -135,7 +135,7 @@ func (me *InviteProvider) DoCommand(a *App, args *model.CommandArgs, message str
}
}
if _, err := a.AddChannelMember(userProfile.Id, channelToJoin, args.Session.UserId, ""); err != nil {
if _, err := a.AddChannelMember(userProfile.Id, channelToJoin, args.Session.UserId, "", !args.Session.IsMobileApp()); err != nil {
return &model.CommandResponse{
Text: args.T("api.command_invite.fail.app_error"),
ResponseType: model.COMMAND_RESPONSE_TYPE_EPHEMERAL,
......
......@@ -5,9 +5,9 @@ package app
import (
"fmt"
"hash/fnv"
"net/http"
"strings"
"time"
"github.com/mattermost/mattermost-server/mlog"
"github.com/mattermost/mattermost-server/model"
......@@ -15,21 +15,42 @@ import (
"github.com/nicksnyder/go-i18n/i18n"
)
func (a *App) sendPushNotification(notification *postNotification, user *model.User, explicitMention, channelWideMention bool, replyToThreadType string) *model.AppError {
channel := notification.channel
post := notification.post
type NotificationType string
cfg := a.Config()
const NOTIFICATION_TYPE_CLEAR NotificationType = "clear"
const NOTIFICATION_TYPE_MESSAGE NotificationType = "message"
var nameFormat string
if result := <-a.Srv.Store.Preference().Get(user.Id, model.PREFERENCE_CATEGORY_DISPLAY_SETTINGS, model.PREFERENCE_NAME_NAME_FORMAT); result.Err != nil {
nameFormat = *a.Config().TeamSettings.TeammateNameDisplay
} else {
nameFormat = result.Data.(model.Preference).Value
}
const PUSH_NOTIFICATION_HUB_WORKERS = 1000
const PUSH_NOTIFICATIONS_HUB_BUFFER_PER_WORKER = 50
channelName := notification.GetChannelName(nameFormat, user.Id)
senderName := notification.GetSenderName(nameFormat, cfg.ServiceSettings.EnablePostUsernameOverride)
type PushNotificationsHub struct {
Channels []chan PushNotification
}
type PushNotification struct {
notificationType NotificationType
userId string
channelId string
post *model.Post
user *model.User
channel *model.Channel
senderName string
channelName string
explicitMention bool
channelWideMention bool
replyToThreadType string
}
func (hub *PushNotificationsHub) GetGoChannelFromUserId(userId string) chan PushNotification {
h := fnv.New32a()
h.Write([]byte(userId))
chanIdx := h.Sum32() % PUSH_NOTIFICATION_HUB_WORKERS
return hub.Channels[chanIdx]
}
func (a *App) sendPushNotificationSync(post *model.Post, user *model.User, channel *model.Channel, channelName string, senderName string,
explicitMention, channelWideMention bool, replyToThreadType string) *model.AppError {
cfg := a.Config()
sessions, err := a.getMobileAppSessions(user.Id)
if err != nil {
......@@ -86,11 +107,7 @@ func (a *App) sendPushNotification(notification *postNotification, user *model.U
mlog.Debug(fmt.Sprintf("Sending push notification to device %v for user %v with msg of '%v'", tmpMessage.DeviceId, user.Id, msg.Message), mlog.String("user_id", user.Id))
a.Go(func(session *model.Session) func() {
return func() {
a.sendToPushProxy(tmpMessage, session)
}
}(session))
a.sendToPushProxy(tmpMessage, session)
if a.Metrics != nil {
a.Metrics.IncrementPostSentPush()
......@@ -100,6 +117,35 @@ func (a *App) sendPushNotification(notification *postNotification, user *model.U
return nil
}
func (a *App) sendPushNotification(notification *postNotification, user *model.User, explicitMention, channelWideMention bool, replyToThreadType string) {
cfg := a.Config()
channel := notification.channel
post := notification.post
var nameFormat string
if result := <-a.Srv.Store.Preference().Get(user.Id, model.PREFERENCE_CATEGORY_DISPLAY_SETTINGS, model.PREFERENCE_NAME_NAME_FORMAT); result.Err != nil {
nameFormat = *a.Config().TeamSettings.TeammateNameDisplay
} else {
nameFormat = result.Data.(model.Preference).Value
}
channelName := notification.GetChannelName(nameFormat, user.Id)
senderName := notification.GetSenderName(nameFormat, cfg.ServiceSettings.EnablePostUsernameOverride)
c := a.PushNotificationsHub.GetGoChannelFromUserId(user.Id)
c <- PushNotification{
notificationType: NOTIFICATION_TYPE_MESSAGE,
post: post,
user: user,
channel: channel,
senderName: senderName,
channelName: channelName,
explicitMention: explicitMention,
channelWideMention: channelWideMention,
replyToThreadType: replyToThreadType,
}
}
func (a *App) getPushNotificationMessage(postMessage string, explicitMention, channelWideMention, hasFiles bool,
senderName, channelName, channelType, replyToThreadType string, userLocale i18n.TranslateFunc) string {
message := ""
......@@ -140,41 +186,85 @@ func (a *App) getPushNotificationMessage(postMessage string, explicitMention, ch
return message
}
func (a *App) ClearPushNotificationSync(userId string, channelId string) {
sessions, err := a.getMobileAppSessions(userId)
if err != nil {
mlog.Error(err.Error())
return
}
msg := model.PushNotification{}
msg.Type = model.PUSH_TYPE_CLEAR
msg.ChannelId = channelId
msg.ContentAvailable = 0
if badge := <-a.Srv.Store.User().GetUnreadCount(userId); badge.Err != nil {
msg.Badge = 0
mlog.Error(fmt.Sprint("We could not get the unread message count for the user", userId, badge.Err), mlog.String("user_id", userId))
} else {
msg.Badge = int(badge.Data.(int64))
}
mlog.Debug(fmt.Sprintf("Clearing push notification to %v with channel_id %v", msg.DeviceId, msg.ChannelId))
for _, session := range sessions {
tmpMessage := *model.PushNotificationFromJson(strings.NewReader(msg.ToJson()))
tmpMessage.SetDeviceIdAndPlatform(session.DeviceId)
a.sendToPushProxy(tmpMessage, session)
}
}
func (a *App) ClearPushNotification(userId string, channelId string) {
a.Go(func() {
// Sleep is to allow the read replicas a chance to fully sync
// the unread count for sending an accurate count.
// Delaying a little doesn't hurt anything and is cheaper than
// attempting to read from master.
time.Sleep(time.Second * 5)
sessions, err := a.getMobileAppSessions(userId)
if err != nil {
mlog.Error(err.Error())
return
}
channel := a.PushNotificationsHub.GetGoChannelFromUserId(userId)
channel <- PushNotification{
notificationType: NOTIFICATION_TYPE_CLEAR,
userId: userId,
channelId: channelId,
}
}
msg := model.PushNotification{}
msg.Type = model.PUSH_TYPE_CLEAR
msg.ChannelId = channelId
msg.ContentAvailable = 0
if badge := <-a.Srv.Store.User().GetUnreadCount(userId); badge.Err != nil {
msg.Badge = 0
mlog.Error(fmt.Sprint("We could not get the unread message count for the user", userId, badge.Err), mlog.String("user_id", userId))
} else {
msg.Badge = int(badge.Data.(int64))
func (a *App) CreatePushNotificationsHub() {
hub := PushNotificationsHub{
Channels: []chan PushNotification{},
}
for x := 0; x < PUSH_NOTIFICATION_HUB_WORKERS; x++ {
hub.Channels = append(hub.Channels, make(chan PushNotification, PUSH_NOTIFICATIONS_HUB_BUFFER_PER_WORKER))
}
a.PushNotificationsHub = hub
}
func (a *App) pushNotificationWorker(notifications chan PushNotification) {
for notification := range notifications {
switch notification.notificationType {
case NOTIFICATION_TYPE_CLEAR:
a.ClearPushNotificationSync(notification.userId, notification.channelId)
case NOTIFICATION_TYPE_MESSAGE:
a.sendPushNotificationSync(
notification.post,
notification.user,
notification.channel,
notification.channelName,
notification.senderName,
notification.explicitMention,
notification.channelWideMention,
notification.replyToThreadType,
)
default:
mlog.Error(fmt.Sprintf("Invalid notification type %v", notification.notificationType))
}
}
}
mlog.Debug(fmt.Sprintf("Clearing push notification to %v with channel_id %v", msg.DeviceId, msg.ChannelId))
func (a *App) StartPushNotificationsHubWorkers() {
for x := 0; x < PUSH_NOTIFICATION_HUB_WORKERS; x++ {
channel := a.PushNotificationsHub.Channels[x]
a.Go(func() { a.pushNotificationWorker(channel) })
}
}
for _, session := range sessions {
tmpMessage := *model.PushNotificationFromJson(strings.NewReader(msg.ToJson()))
tmpMessage.SetDeviceIdAndPlatform(session.DeviceId)
a.Go(func() {
a.sendToPushProxy(tmpMessage, session)
})
}
})
func (a *App) StopPushNotificationsHubWorkers() {
for _, channel := range a.PushNotificationsHub.Channels {
close(channel)
}
}
func (a *App) sendToPushProxy(msg model.PushNotification, session *model.Session) {
......
......@@ -251,7 +251,7 @@ func (api *PluginAPI) AddChannelMember(channelId, userId string) (*model.Channel
return nil, err
}
return api.app.AddChannelMember(userId, channel, userRequestorId, postRootId)
return api.app.AddChannelMember(userId, channel, userRequestorId, postRootId, false)
}
func (api *PluginAPI) GetChannelMember(channelId, userId string) (*model.ChannelMember, *model.AppError) {
......
......@@ -24,7 +24,7 @@ import (
"golang.org/x/net/html/charset"
)
func (a *App) CreatePostAsUser(post *model.Post) (*model.Post, *model.AppError) {
func (a *App) CreatePostAsUser(post *model.Post, clearPushNotifications bool) (*model.Post, *model.AppError) {
// Check that channel has not been deleted
var channel *model.Channel
if result := <-a.Srv.Store.Channel().Get(post.ChannelId, true); result.Err != nil {
......@@ -78,14 +78,8 @@ func (a *App) CreatePostAsUser(post *model.Post) (*model.Post, *model.AppError)
} else {
// Update the LastViewAt only if the post does not have from_webhook prop set (eg. Zapier app)
if _, ok := post.Props["from_webhook"]; !ok {
if result := <-a.Srv.Store.Channel().UpdateLastViewedAt([]string{post.ChannelId}, post.UserId); result.Err != nil {
mlog.Error(fmt.Sprintf("Encountered error updating last viewed, channel_id=%s, user_id=%s, err=%v", post.ChannelId, post.UserId, result.Err))
}
if *a.Config().ServiceSettings.EnableChannelViewedMessages {
message := model.NewWebSocketEvent(model.WEBSOCKET_EVENT_CHANNEL_VIEWED, "", "", post.UserId, nil)
message.Add("channel_id", post.ChannelId)
a.Publish(message)
if _, err := a.MarkChannelsAsViewed([]string{post.ChannelId}, post.UserId, clearPushNotifications); err != nil {
mlog.Error(fmt.Sprintf("Encountered error updating last viewed, channel_id=%s, user_id=%s, err=%v", post.ChannelId, post.UserId, err))
}
}
......
......@@ -115,7 +115,7 @@ func TestPostReplyToPostWhereRootPosterLeftChannel(t *testing.T) {
CreateAt: 0,
}
if _, err := th.App.CreatePostAsUser(&replyPost); err != nil {
if _, err := th.App.CreatePostAsUser(&replyPost, false); err != nil {
t.Fatal(err)
}
}
......@@ -175,7 +175,7 @@ func TestPostAction(t *testing.T) {
},
}
post, err := th.App.CreatePostAsUser(&interactivePost)
post, err := th.App.CreatePostAsUser(&interactivePost, false)
require.Nil(t, err)
attachments, ok := post.Props["attachments"].([]*model.SlackAttachment)
......@@ -212,7 +212,7 @@ func TestPostAction(t *testing.T) {
},
}
post2, err := th.App.CreatePostAsUser(&menuPost)
post2, err := th.App.CreatePostAsUser(&menuPost, false)
require.Nil(t, err)
attachments2, ok := post2.Props["attachments"].([]*model.SlackAttachment)
......@@ -267,7 +267,7 @@ func TestPostAction(t *testing.T) {
},
}
postplugin, err := th.App.CreatePostAsUser(&interactivePostPlugin)
postplugin, err := th.App.CreatePostAsUser(&interactivePostPlugin, false)
require.Nil(t, err)
attachmentsPlugin, ok := postplugin.Props["attachments"].([]*model.SlackAttachment)
......@@ -308,7 +308,7 @@ func TestPostAction(t *testing.T) {
},
}
postSiteURL, err := th.App.CreatePostAsUser(&interactivePostSiteURL)
postSiteURL, err := th.App.CreatePostAsUser(&interactivePostSiteURL, false)
require.Nil(t, err)
attachmentsSiteURL, ok := postSiteURL.Props["attachments"].([]*model.SlackAttachment)
......@@ -350,7 +350,7 @@ func TestPostAction(t *testing.T) {
},
}
postSubpath, err := th.App.CreatePostAsUser(&interactivePostSubpath)
postSubpath, err := th.App.CreatePostAsUser(&interactivePostSubpath, false)
require.Nil(t, err)
attachmentsSubpath, ok := postSubpath.Props["attachments"].([]*model.SlackAttachment)
......@@ -389,7 +389,7 @@ func TestPostChannelMentions(t *testing.T) {
CreateAt: 0,
}
result, err := th.App.CreatePostAsUser(post)
result, err := th.App.CreatePostAsUser(post, false)
require.Nil(t, err)
assert.Equal(t, map[string]interface{}{
"mention-test": map[string]interface{}{
......
......@@ -949,6 +949,16 @@ func (us SqlUserStore) GetUnreadCountForChannel(userId string, channelId string)
})
}
func (us SqlUserStore) GetAnyUnreadPostCountForChannel(userId string, channelId string) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
if count, err := us.GetReplica().SelectInt("SELECT SUM(c.TotalMsgCount - cm.MsgCount) FROM Channels c INNER JOIN ChannelMembers cm ON c.Id = :ChannelId AND cm.ChannelId = :ChannelId AND cm.UserId = :UserId", map[string]interface{}{"ChannelId": channelId, "UserId": userId}); err != nil {
result.Err = model.NewAppError("SqlUserStore.GetMentionCountForChannel", "store.sql_user.get_unread_count_for_channel.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
result.Data = count
}
})
}
func (us SqlUserStore) Search(teamId string, term string, options map[string]bool) store.StoreChannel {
return store.Do(func(result *store.StoreResult) {
searchQuery := ""
......
......@@ -265,6 +265,7 @@ type UserStore interface {
AnalyticsActiveCount(time int64) StoreChannel
GetUnreadCount(userId string) StoreChannel
GetUnreadCountForChannel(userId string, channelId string) StoreChannel
GetAnyUnreadPostCountForChannel(userId string, channelId string) StoreChannel
GetRecentlyActiveUsersForTeam(teamId string, offset, limit int) StoreChannel
GetNewUsersForTeam(teamId string, offset, limit int) StoreChannel
Search(teamId string, term string, options map[string]bool) StoreChannel
......
......@@ -194,6 +194,22 @@ func (_m *UserStore) GetAllUsingAuthService(authService string) store.StoreChann
return r0
}
// GetAnyUnreadPostCountForChannel provides a mock function with given fields: userId, channelId
func (_m *UserStore) GetAnyUnreadPostCountForChannel(userId string, channelId string) store.StoreChannel {
ret := _m.Called(userId, channelId)
var r0 store.StoreChannel
if rf, ok := ret.Get(0).(func(string, string) store.StoreChannel); ok {
r0 = rf(userId, channelId)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(store.StoreChannel)
}
}
return r0
}
// GetByAuth provides a mock function with given fields: authData, authService
func (_m *UserStore) GetByAuth(authData *string, authService string) store.StoreChannel {
ret := _m.Called(authData, authService)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment