Skip to content

Commit

Permalink
fix: auto redirect to summary page for any authentication mechanism (r…
Browse files Browse the repository at this point in the history
…esolve #589)
  • Loading branch information
muety committed Jan 6, 2024
1 parent fae0128 commit f04508c
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 5 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func main() {
settingsHandler := routes.NewSettingsHandler(userService, heartbeatService, summaryService, aliasService, aggregationService, languageMappingService, projectLabelService, keyValueService, mailService)
subscriptionHandler := routes.NewSubscriptionHandler(userService, mailService, keyValueService)
projectsHandler := routes.NewProjectsHandler(userService, heartbeatService)
homeHandler := routes.NewHomeHandler(keyValueService)
homeHandler := routes.NewHomeHandler(userService, keyValueService)
loginHandler := routes.NewLoginHandler(userService, mailService)
imprintHandler := routes.NewImprintHandler(keyValueService)
leaderboardHandler := condition.TernaryOperator[bool, routes.Handler](config.App.LeaderboardEnabled, routes.NewLeaderboardHandler(userService, leaderboardService), routes.NewNoopHandler())
Expand Down
35 changes: 35 additions & 0 deletions mocks/key_value_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package mocks

import (
"github.com/muety/wakapi/models"
"github.com/stretchr/testify/mock"
)

type KeyValueServiceMock struct {
mock.Mock
}

func (m *KeyValueServiceMock) GetString(s string) (*models.KeyStringValue, error) {
args := m.Called(s)
return args.Get(0).(*models.KeyStringValue), args.Error(1)
}

func (m *KeyValueServiceMock) MustGetString(s string) *models.KeyStringValue {
args := m.Called(s)
return args.Get(0).(*models.KeyStringValue)
}

func (m *KeyValueServiceMock) GetByPrefix(s string) ([]*models.KeyStringValue, error) {
args := m.Called(s)
return args.Get(0).([]*models.KeyStringValue), args.Error(1)
}

func (m *KeyValueServiceMock) PutString(v *models.KeyStringValue) error {
args := m.Called(v)
return args.Error(0)
}

func (m *KeyValueServiceMock) DeleteString(s string) error {
args := m.Called(s)
return args.Error(0)
}
13 changes: 9 additions & 4 deletions routes/home.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/gorilla/schema"
conf "github.com/muety/wakapi/config"
"github.com/muety/wakapi/models"
"github.com/muety/wakapi/middlewares"
"github.com/muety/wakapi/models/view"
routeutils "github.com/muety/wakapi/routes/utils"
"github.com/muety/wakapi/services"
Expand All @@ -18,30 +18,35 @@ import (

type HomeHandler struct {
config *conf.Config
userSrvc services.IUserService
keyValueSrvc services.IKeyValueService
}

var loginDecoder = schema.NewDecoder()
var signupDecoder = schema.NewDecoder()
var resetPasswordDecoder = schema.NewDecoder()

func NewHomeHandler(keyValueService services.IKeyValueService) *HomeHandler {
func NewHomeHandler(userService services.IUserService, keyValueService services.IKeyValueService) *HomeHandler {
return &HomeHandler{
config: conf.Get(),
userSrvc: userService,
keyValueSrvc: keyValueService,
}
}

func (h *HomeHandler) RegisterRoutes(router chi.Router) {
router.Get("/", h.GetIndex)
router.Group(func(r chi.Router) {
r.Use(middlewares.NewAuthenticateMiddleware(h.userSrvc).WithOptionalFor("/").Handler)
r.Get("/", h.GetIndex)
})
}

func (h *HomeHandler) GetIndex(w http.ResponseWriter, r *http.Request) {
if h.config.IsDev() {
loadTemplates()
}

if cookie, err := r.Cookie(models.AuthCookieKey); err == nil && cookie.Value != "" {
if user := middlewares.GetPrincipal(r); user != nil {
http.Redirect(w, r, fmt.Sprintf("%s/summary", h.config.Server.BasePath), http.StatusFound)
return
}
Expand Down
123 changes: 123 additions & 0 deletions routes/home_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package routes

import (
"github.com/go-chi/chi/v5"
"github.com/muety/wakapi/config"
"github.com/muety/wakapi/middlewares"
"github.com/muety/wakapi/mocks"
"github.com/muety/wakapi/models"
"github.com/stretchr/testify/assert"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)

var (
user1 = models.User{
ID: "user1",
ShareDataMaxDays: 30,
ShareLanguages: true,
ApiKey: "fakekey",
}
)

func TestHomeHandler_Get_NotLoggedIn(t *testing.T) {
config.Set(config.Empty())
config.Get().Env = "dev"

if cwd, _ := os.Getwd(); strings.HasSuffix(cwd, "routes") {
os.Chdir("..")
}

router := chi.NewRouter()
router.Use(middlewares.NewPrincipalMiddleware())

userServiceMock := new(mocks.UserServiceMock)
userServiceMock.On("GetUserById", user1.ID).Return(&user1, nil)

keyValueServiceMock := new(mocks.KeyValueServiceMock)
keyValueServiceMock.On("GetString", config.KeyLatestTotalTime).Return(&models.KeyStringValue{Key: config.KeyLatestTotalTime, Value: "0"}, nil)
keyValueServiceMock.On("GetString", config.KeyLatestTotalUsers).Return(&models.KeyStringValue{Key: config.KeyLatestTotalUsers, Value: "0"}, nil)
keyValueServiceMock.On("GetString", config.KeyNewsbox).Return(&models.KeyStringValue{Key: config.KeyNewsbox, Value: ""}, nil)

homeHandler := NewHomeHandler(userServiceMock, keyValueServiceMock)
homeHandler.RegisterRoutes(router)

t.Run("when requesting frontpage", func(t *testing.T) {
t.Run("should display it without authentication", func(t *testing.T) {
rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)

router.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()

assert.Equal(t, http.StatusOK, res.StatusCode)

data, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("unextected error. Error: %s", err)
}

assert.Contains(t, string(data), "<a href=\"login\" class=\"btn-primary\">")
keyValueServiceMock.AssertNumberOfCalls(t, "GetString", 3)
})
})
}

func TestHomeHandler_Get_LoggedIn(t *testing.T) {
config.Set(config.Empty())

router := chi.NewRouter()
router.Use(middlewares.NewPrincipalMiddleware())

userServiceMock := new(mocks.UserServiceMock)
userServiceMock.On("GetUserByKey", user1.ApiKey).Return(&user1, nil)
userServiceMock.On("GetUserById", user1.ID).Return(&user1, nil)

keyValueServiceMock := new(mocks.KeyValueServiceMock)

homeHandler := NewHomeHandler(userServiceMock, keyValueServiceMock)
homeHandler.RegisterRoutes(router)

t.Run("when requesting frontpage", func(t *testing.T) {
t.Run("should redirect in case of api key auth", func(t *testing.T) {
rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)
q := req.URL.Query()
q.Set("api_key", user1.ApiKey)
req.URL.RawQuery = q.Encode()

router.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()

assert.Equal(t, http.StatusFound, res.StatusCode)
})

t.Run("should redirect in case of trusted header auth", func(t *testing.T) {
c := config.Get()
c.Security.TrustedHeaderAuth = true
c.Security.TrustedHeaderAuthKey = "Remote-User"
c.Security.TrustReverseProxyIps = "127.0.0.1"
c.Security.ParseTrustReverseProxyIPs()

rec := httptest.NewRecorder()

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("Remote-User", user1.ID)
req.RemoteAddr = "127.0.0.1:12345"

router.ServeHTTP(rec, req)
res := rec.Result()
defer res.Body.Close()

assert.Equal(t, http.StatusFound, res.StatusCode)
})
})
}

0 comments on commit f04508c

Please sign in to comment.