Unverified Commit 0027d998 authored by Harrison Healey's avatar Harrison Healey Committed by GitHub

MM-11855 Add App.HTTPService to allow mocking of HTTP client (#9359)

* MM-11855 Add App.HTTPService to allow mocking of HTTP client

* Initialize HTTPService earlier
parent 29100070
......@@ -7,12 +7,10 @@ import (
"crypto/ecdsa"
"fmt"
"html/template"
"net"
"net/http"
"path"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
......@@ -100,6 +98,8 @@ type App struct {
diagnosticId string
phase2PermissionsMigrationComplete bool
HTTPService HTTPService
}
var appCount = 0
......@@ -125,6 +125,9 @@ func New(options ...Option) (outApp *App, outErr error) {
clientConfig: make(map[string]string),
licenseListeners: map[string]func(){},
}
app.HTTPService = MakeHTTPService(app)
defer func() {
if outErr != nil {
app.Shutdown()
......@@ -285,6 +288,8 @@ func (a *App) Shutdown() {
mlog.Info("Server stopped")
a.DisableConfigWatch()
a.HTTPService.Close()
}
var accountMigrationInterface func(*App) einterfaces.AccountMigrationInterface
......@@ -505,43 +510,6 @@ func (a *App) HTMLTemplates() *template.Template {
return nil
}
func (a *App) HTTPClient(trustURLs bool) *http.Client {
insecure := a.Config().ServiceSettings.EnableInsecureOutgoingConnections != nil && *a.Config().ServiceSettings.EnableInsecureOutgoingConnections
if trustURLs {
return utils.NewHTTPClient(insecure, nil, nil)
}
allowHost := func(host string) bool {
if a.Config().ServiceSettings.AllowedUntrustedInternalConnections == nil {
return false
}
for _, allowed := range strings.Fields(*a.Config().ServiceSettings.AllowedUntrustedInternalConnections) {
if host == allowed {
return true
}
}
return false
}
allowIP := func(ip net.IP) bool {
if !utils.IsReservedIP(ip) {
return true
}
if a.Config().ServiceSettings.AllowedUntrustedInternalConnections == nil {
return false
}
for _, allowed := range strings.Fields(*a.Config().ServiceSettings.AllowedUntrustedInternalConnections) {
if _, ipRange, err := net.ParseCIDR(allowed); err == nil && ipRange.Contains(ip) {
return true
}
}
return false
}
return utils.NewHTTPClient(insecure, allowHost, allowIP)
}
func (a *App) Handle404(w http.ResponseWriter, r *http.Request) {
err := model.NewAppError("Handle404", "api.context.404.app_error", nil, "", http.StatusNotFound)
......
......@@ -6,6 +6,8 @@ package app
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"time"
......@@ -33,6 +35,8 @@ type TestHelper struct {
tempConfigPath string
tempWorkspace string
MockedHTTPService *MockedHTTPService
}
type persistentTestStore struct {
......@@ -163,6 +167,13 @@ func (me *TestHelper) InitSystemAdmin() *TestHelper {
return me
}
func (me *TestHelper) MockHTTPService(handler http.Handler) *TestHelper {
me.MockedHTTPService = MakeMockedHTTPService(handler)
me.App.HTTPService = me.MockedHTTPService
return me
}
func (me *TestHelper) MakeEmail() string {
return "success_" + model.NewId() + "@simulator.amazonses.com"
}
......@@ -503,3 +514,22 @@ func (me *FakeClusterInterface) sendClearRoleCacheMessage() {
Event: model.CLUSTER_EVENT_INVALIDATE_CACHE_FOR_ROLES,
})
}
type MockedHTTPService struct {
Server *httptest.Server
}
func MakeMockedHTTPService(handler http.Handler) *MockedHTTPService {
return &MockedHTTPService{
Server: httptest.NewServer(handler),
}
}
func (h *MockedHTTPService) MakeClient(trustURLs bool) *http.Client {
return h.Server.Client()
}
func (h *MockedHTTPService) Close() {
h.Server.CloseClientConnections()
h.Server.Close()
}
......@@ -244,7 +244,7 @@ func (a *App) ExecuteCommand(args *model.CommandArgs) (*model.CommandResponse, *
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
if resp, err := a.HTTPClient(false).Do(req); err != nil {
if resp, err := a.HTTPService.MakeClient(false).Do(req); err != nil {
return nil, model.NewAppError("command", "api.command.execute_command.failed.app_error", map[string]interface{}{"Trigger": trigger}, err.Error(), http.StatusInternalServerError)
} else {
if resp.StatusCode == http.StatusOK {
......
// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See License.txt for license information.
package app
import (
"net"
"net/http"
"strings"
"github.com/mattermost/mattermost-server/utils"
)
// Wraps the functionality for creating a new http.Client to encapsulate that and allow it to be mocked when testing
type HTTPService interface {
MakeClient(trustURLs bool) *http.Client
Close()
}
type HTTPServiceImpl struct {
app *App
}
func MakeHTTPService(app *App) HTTPService {
return &HTTPServiceImpl{app}
}
func (h *HTTPServiceImpl) MakeClient(trustURLs bool) *http.Client {
insecure := h.app.Config().ServiceSettings.EnableInsecureOutgoingConnections != nil && *h.app.Config().ServiceSettings.EnableInsecureOutgoingConnections
if trustURLs {
return utils.NewHTTPClient(insecure, nil, nil)
}
allowHost := func(host string) bool {
if h.app.Config().ServiceSettings.AllowedUntrustedInternalConnections == nil {
return false
}
for _, allowed := range strings.Fields(*h.app.Config().ServiceSettings.AllowedUntrustedInternalConnections) {
if host == allowed {
return true
}
}
return false
}
allowIP := func(ip net.IP) bool {
if !utils.IsReservedIP(ip) {
return true
}
if h.app.Config().ServiceSettings.AllowedUntrustedInternalConnections == nil {
return false
}
for _, allowed := range strings.Fields(*h.app.Config().ServiceSettings.AllowedUntrustedInternalConnections) {
if _, ipRange, err := net.ParseCIDR(allowed); err == nil && ipRange.Contains(ip) {
return true
}
}
return false
}
return utils.NewHTTPClient(insecure, allowHost, allowIP)
}
func (h *HTTPServiceImpl) Close() {
// Does nothing, but allows this to be overridden when mocking the service
}
package app
import (
"io/ioutil"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMockHTTPService(t *testing.T) {
getCalled := false
putCalled := false
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/get" && r.Method == http.MethodGet {
getCalled = true
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
} else if r.URL.Path == "/put" && r.Method == http.MethodPut {
putCalled = true
w.WriteHeader(http.StatusCreated)
w.Write([]byte("CREATED"))
} else {
w.WriteHeader(http.StatusNotFound)
}
})
th := Setup().MockHTTPService(handler)
defer th.TearDown()
url := th.MockedHTTPService.Server.URL
t.Run("GET", func(t *testing.T) {
client := th.App.HTTPService.MakeClient(false)
resp, err := client.Get(url + "/get")
defer consumeAndClose(resp)
bodyContents, _ := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "OK", string(bodyContents))
assert.True(t, getCalled)
})
t.Run("PUT", func(t *testing.T) {
client := th.App.HTTPService.MakeClient(false)
request, _ := http.NewRequest(http.MethodPut, url+"/put", nil)
resp, err := client.Do(request)
defer consumeAndClose(resp)
bodyContents, _ := ioutil.ReadAll(resp.Body)
require.Nil(t, err)
assert.Equal(t, http.StatusCreated, resp.StatusCode)
assert.Equal(t, "CREATED", string(bodyContents))
assert.True(t, putCalled)
})
}
......@@ -184,7 +184,7 @@ func (a *App) sendToPushProxy(msg model.PushNotification, session *model.Session
request, _ := http.NewRequest("POST", strings.TrimRight(*a.Config().EmailSettings.PushNotificationServer, "/")+model.API_URL_SUFFIX_V1+"/send_push", strings.NewReader(msg.ToJson()))
if resp, err := a.HTTPClient(true).Do(request); err != nil {
if resp, err := a.HTTPService.MakeClient(true).Do(request); err != nil {
mlog.Error(fmt.Sprintf("Device push reported as error for UserId=%v SessionId=%v message=%v", session.UserId, session.Id, err.Error()), mlog.String("user_id", session.UserId))
} else {
pushResponse := model.PushResponseFromJson(resp.Body)
......
......@@ -761,7 +761,7 @@ func (a *App) AuthorizeOAuthUser(w http.ResponseWriter, r *http.Request, service
var ar *model.AccessResponse
var bodyBytes []byte
if resp, err := a.HTTPClient(true).Do(req); err != nil {
if resp, err := a.HTTPService.MakeClient(true).Do(req); err != nil {
return nil, "", stateProps, model.NewAppError("AuthorizeOAuthUser", "api.user.authorize_oauth_user.token_failed.app_error", nil, err.Error(), http.StatusInternalServerError)
} else {
bodyBytes, _ = ioutil.ReadAll(resp.Body)
......@@ -791,7 +791,7 @@ func (a *App) AuthorizeOAuthUser(w http.ResponseWriter, r *http.Request, service
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+ar.AccessToken)
if resp, err := a.HTTPClient(true).Do(req); err != nil {
if resp, err := a.HTTPService.MakeClient(true).Do(req); err != nil {
return nil, "", stateProps, model.NewAppError("AuthorizeOAuthUser", "api.user.authorize_oauth_user.service.app_error", map[string]interface{}{"Service": service}, err.Error(), http.StatusInternalServerError)
} else {
bodyBytes, _ = ioutil.ReadAll(resp.Body)
......
......@@ -777,7 +777,7 @@ func (a *App) GetFileInfosForPost(postId string, readFromMaster bool) ([]*model.
func (a *App) GetOpenGraphMetadata(requestURL string) *opengraph.OpenGraph {
og := opengraph.NewOpenGraph()
res, err := a.HTTPClient(false).Get(requestURL)
res, err := a.HTTPService.MakeClient(false).Get(requestURL)
if err != nil {
mlog.Error(fmt.Sprintf("GetOpenGraphMetadata request failed for url=%v with err=%v", requestURL, err.Error()))
return og
......@@ -890,9 +890,9 @@ func (a *App) DoPostAction(postId, actionId, userId, selectedOption string) *mod
siteURL, _ := url.Parse(*a.Config().ServiceSettings.SiteURL)
subpath, _ := utils.GetSubpathFromConfig(a.Config())
if (url.Hostname() == "localhost" || url.Hostname() == "127.0.0.1" || url.Hostname() == siteURL.Hostname()) && strings.HasPrefix(url.Path, path.Join(subpath, "plugins")) {
httpClient = a.HTTPClient(true)
httpClient = a.HTTPService.MakeClient(true)
} else {
httpClient = a.HTTPClient(false)
httpClient = a.HTTPService.MakeClient(false)
}
resp, err := httpClient.Do(req)
......
......@@ -107,7 +107,7 @@ func (a *App) TriggerWebhook(payload *model.OutgoingWebhookPayload, hook *model.
req, _ := http.NewRequest("POST", url, body)
req.Header.Set("Content-Type", contentType)
req.Header.Set("Accept", "application/json")
if resp, err := a.HTTPClient(false).Do(req); err != nil {
if resp, err := a.HTTPService.MakeClient(false).Do(req); err != nil {
mlog.Error(fmt.Sprintf("Event POST failed, err=%s", err.Error()))
} else {
defer consumeAndClose(resp)
......
......@@ -59,7 +59,7 @@ func (a *App) GetWebrtcToken(sessionId string) (string, *model.AppError) {
rq, _ := http.NewRequest("POST", *a.Config().WebrtcSettings.GatewayAdminUrl, strings.NewReader(model.MapToJson(data)))
rq.Header.Set("Content-Type", "application/json")
if rp, err := a.HTTPClient(true).Do(rq); err != nil {
if rp, err := a.HTTPService.MakeClient(true).Do(rq); err != nil {
return "", model.NewAppError("WebRTC.Token", "model.client.connecting.app_error", nil, err.Error(), http.StatusInternalServerError)
} else if rp.StatusCode >= 300 {
defer consumeAndClose(rp)
......@@ -93,5 +93,5 @@ func (a *App) RevokeWebrtcToken(sessionId string) {
rq.Header.Set("Content-Type", "application/json")
// we do not care about the response
a.HTTPClient(true).Do(rq)
a.HTTPService.MakeClient(true).Do(rq)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment