Unverified Commit 55261a2b authored by Chris's avatar Chris Committed by GitHub

fix scheduled task race (#8355)

parent 5fb4b1bc
......@@ -7,6 +7,7 @@ import (
"fmt"
"html/template"
"strconv"
"sync"
"time"
"github.com/mattermost/mattermost-server/model"
......@@ -57,6 +58,8 @@ type EmailBatchingJob struct {
app *App
newNotifications chan *batchedNotification
pendingNotifications map[string][]*batchedNotification
task *model.ScheduledTask
taskMutex sync.Mutex
}
func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob {
......@@ -68,12 +71,17 @@ func NewEmailBatchingJob(a *App, bufferSize int) *EmailBatchingJob {
}
func (job *EmailBatchingJob) Start() {
if task := model.GetTaskByName(EMAIL_BATCHING_TASK_NAME); task != nil {
task.Cancel()
}
l4g.Debug(utils.T("api.email_batching.start.starting"), *job.app.Config().EmailSettings.EmailBatchingInterval)
model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second)
newTask := model.CreateRecurringTask(EMAIL_BATCHING_TASK_NAME, job.CheckPendingEmails, time.Duration(*job.app.Config().EmailSettings.EmailBatchingInterval)*time.Second)
job.taskMutex.Lock()
oldTask := job.task
job.task = newTask
job.taskMutex.Unlock()
if oldTask != nil {
oldTask.Cancel()
}
}
func (job *EmailBatchingJob) Add(user *model.User, post *model.Post, team *model.Team) bool {
......
......@@ -5,7 +5,6 @@ package model
import (
"fmt"
"sync"
"time"
)
......@@ -15,89 +14,57 @@ type ScheduledTask struct {
Name string `json:"name"`
Interval time.Duration `json:"interval"`
Recurring bool `json:"recurring"`
function TaskFunc
timer *time.Timer
}
var taskMutex = sync.Mutex{}
var tasks = make(map[string]*ScheduledTask)
func addTask(task *ScheduledTask) {
taskMutex.Lock()
defer taskMutex.Unlock()
tasks[task.Name] = task
}
func removeTaskByName(name string) {
taskMutex.Lock()
defer taskMutex.Unlock()
delete(tasks, name)
}
func GetTaskByName(name string) *ScheduledTask {
taskMutex.Lock()
defer taskMutex.Unlock()
if task, ok := tasks[name]; ok {
return task
}
return nil
}
func GetAllTasks() *map[string]*ScheduledTask {
taskMutex.Lock()
defer taskMutex.Unlock()
return &tasks
function func()
cancel chan struct{}
cancelled chan struct{}
}
func CreateTask(name string, function TaskFunc, timeToExecution time.Duration) *ScheduledTask {
task := &ScheduledTask{
Name: name,
Interval: timeToExecution,
Recurring: false,
function: function,
}
taskRunner := func() {
go task.function()
removeTaskByName(task.Name)
}
task.timer = time.AfterFunc(timeToExecution, taskRunner)
addTask(task)
return task
return createTask(name, function, timeToExecution, false)
}
func CreateRecurringTask(name string, function TaskFunc, interval time.Duration) *ScheduledTask {
return createTask(name, function, interval, true)
}
func createTask(name string, function TaskFunc, interval time.Duration, recurring bool) *ScheduledTask {
task := &ScheduledTask{
Name: name,
Interval: interval,
Recurring: true,
Recurring: recurring,
function: function,
cancel: make(chan struct{}),
cancelled: make(chan struct{}),
}
taskRecurer := func() {
go task.function()
task.timer.Reset(task.Interval)
}
go func() {
defer close(task.cancelled)
task.timer = time.AfterFunc(interval, taskRecurer)
ticker := time.NewTicker(interval)
defer func() {
ticker.Stop()
}()
addTask(task)
for {
select {
case <-ticker.C:
function()
case <-task.cancel:
return
}
if !task.Recurring {
break
}
}
}()
return task
}
func (task *ScheduledTask) Cancel() {
task.timer.Stop()
removeTaskByName(task.Name)
}
// Executes the task immediatly. A recurring task will be run regularally after interval.
func (task *ScheduledTask) Execute() {
task.function()
task.timer.Reset(task.Interval)
close(task.cancel)
<-task.cancelled
}
func (task *ScheduledTask) String() string {
......
......@@ -4,185 +4,72 @@
package model
import (
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestCreateTask(t *testing.T) {
TASK_NAME := "Test Task"
TASK_TIME := time.Second * 3
TASK_TIME := time.Second * 2
testValue := 0
executionCount := new(int32)
testFunc := func() {
testValue = 1
atomic.AddInt32(executionCount, 1)
}
task := CreateTask(TASK_NAME, testFunc, TASK_TIME)
if testValue != 0 {
t.Fatal("Unexpected execuition of task")
}
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
time.Sleep(TASK_TIME + time.Second)
if testValue != 1 {
t.Fatal("Task did not execute")
}
if task.Name != TASK_NAME {
t.Fatal("Bad name")
}
if task.Interval != TASK_TIME {
t.Fatal("Bad interval")
}
if task.Recurring {
t.Fatal("should not reccur")
}
assert.EqualValues(t, 1, atomic.LoadInt32(executionCount))
assert.Equal(t, TASK_NAME, task.Name)
assert.Equal(t, TASK_TIME, task.Interval)
assert.False(t, task.Recurring)
}
func TestCreateRecurringTask(t *testing.T) {
TASK_NAME := "Test Recurring Task"
TASK_TIME := time.Second * 3
TASK_TIME := time.Second * 2
testValue := 0
executionCount := new(int32)
testFunc := func() {
testValue += 1
atomic.AddInt32(executionCount, 1)
}
task := CreateRecurringTask(TASK_NAME, testFunc, TASK_TIME)
if testValue != 0 {
t.Fatal("Unexpected execuition of task")
}
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
time.Sleep(TASK_TIME + time.Second)
if testValue != 1 {
t.Fatal("Task did not execute")
}
assert.EqualValues(t, 1, atomic.LoadInt32(executionCount))
time.Sleep(TASK_TIME)
if testValue != 2 {
t.Fatal("Task did not re-execute")
}
if task.Name != TASK_NAME {
t.Fatal("Bad name")
}
if task.Interval != TASK_TIME {
t.Fatal("Bad interval")
}
if !task.Recurring {
t.Fatal("should reccur")
}
assert.EqualValues(t, 2, atomic.LoadInt32(executionCount))
assert.Equal(t, TASK_NAME, task.Name)
assert.Equal(t, TASK_TIME, task.Interval)
assert.True(t, task.Recurring)
task.Cancel()
}
func TestCancelTask(t *testing.T) {
TASK_NAME := "Test Task"
TASK_TIME := time.Second * 3
TASK_TIME := time.Second
testValue := 0
executionCount := new(int32)
testFunc := func() {
testValue = 1
atomic.AddInt32(executionCount, 1)
}
task := CreateTask(TASK_NAME, testFunc, TASK_TIME)
if testValue != 0 {
t.Fatal("Unexpected execuition of task")
}
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
task.Cancel()
time.Sleep(TASK_TIME + time.Second)
if testValue != 0 {
t.Fatal("Unexpected execuition of task")
}
}
func TestGetAllTasks(t *testing.T) {
doNothing := func() {}
CreateTask("Task1", doNothing, time.Hour)
CreateTask("Task2", doNothing, time.Second)
CreateRecurringTask("Task3", doNothing, time.Second)
task4 := CreateRecurringTask("Task4", doNothing, time.Second)
task4.Cancel()
time.Sleep(time.Second * 3)
tasks := *GetAllTasks()
if len(tasks) != 2 {
t.Fatal("Wrong number of tasks got: ", len(tasks))
}
for _, task := range tasks {
if task.Name != "Task1" && task.Name != "Task3" {
t.Fatal("Wrong tasks")
}
}
}
func TestExecuteTask(t *testing.T) {
TASK_NAME := "Test Task"
TASK_TIME := time.Second * 5
testValue := 0
testFunc := func() {
testValue += 1
}
task := CreateTask(TASK_NAME, testFunc, TASK_TIME)
if testValue != 0 {
t.Fatal("Unexpected execuition of task")
}
task.Execute()
if testValue != 1 {
t.Fatal("Task did not execute")
}
time.Sleep(TASK_TIME + time.Second)
if testValue != 2 {
t.Fatal("Task re-executed")
}
}
func TestExecuteTaskRecurring(t *testing.T) {
TASK_NAME := "Test Recurring Task"
TASK_TIME := time.Second * 5
testValue := 0
testFunc := func() {
testValue += 1
}
task := CreateRecurringTask(TASK_NAME, testFunc, TASK_TIME)
if testValue != 0 {
t.Fatal("Unexpected execuition of task")
}
time.Sleep(time.Second * 3)
task.Execute()
if testValue != 1 {
t.Fatal("Task did not execute")
}
time.Sleep(time.Second * 3)
if testValue != 1 {
t.Fatal("Task should not have executed before 5 seconds")
}
time.Sleep(time.Second * 3)
if testValue != 2 {
t.Fatal("Task did not re-execute after forced execution")
}
assert.EqualValues(t, 0, atomic.LoadInt32(executionCount))
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment