diff --git a/go.mod b/go.mod index 1768ad26..4c02913f 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/glebarez/sqlite v1.10.0 github.com/go-playground/validator/v10 v10.16.0 github.com/google/uuid v1.5.0 + github.com/gorilla/securecookie v1.1.2 github.com/gorilla/sessions v1.2.2 github.com/hashicorp/go-memdb v1.3.4 github.com/labstack/echo/v4 v4.11.4 @@ -57,7 +58,6 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/gorilla/mux v1.8.1 // indirect - github.com/gorilla/securecookie v1.1.2 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/golang-lru v1.0.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/internal/cli/main.go b/internal/cli/main.go index a78a6d87..31b641c6 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -87,6 +87,9 @@ func Initialize(ctx *cli.Context) { log.Fatal().Err(err).Msg("Failed to create symlinks") } + if err := os.MkdirAll(filepath.Join(homePath, "sessions"), 0755); err != nil { + log.Fatal().Err(err).Send() + } if err := os.MkdirAll(filepath.Join(homePath, "repos"), 0755); err != nil { log.Fatal().Err(err).Send() } diff --git a/internal/utils/session.go b/internal/utils/session.go new file mode 100644 index 00000000..74eadc00 --- /dev/null +++ b/internal/utils/session.go @@ -0,0 +1,26 @@ +package utils + +import ( + "github.com/gorilla/securecookie" + "github.com/rs/zerolog/log" + "os" +) + +func ReadKey(filePath string) []byte { + key, err := os.ReadFile(filePath) + if err == nil { + return key + } + + key = securecookie.GenerateRandomKey(32) + if key == nil { + log.Fatal().Msg("Failed to generate a new key for sessions") + } + + err = os.WriteFile(filePath, key, 0600) + if err != nil { + log.Fatal().Err(err).Msgf("Failed to save the key to %s", filePath) + } + + return key +} diff --git a/internal/web/server.go b/internal/web/server.go index 701da1cd..8a04b585 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -35,10 +35,11 @@ import ( ) var ( - dev bool - store *sessions.CookieStore - re = regexp.MustCompile("[^a-z0-9]+") - fm = template.FuncMap{ + dev bool + flashStore *sessions.CookieStore // session store for flash messages + userStore *sessions.FilesystemStore // session store for user sessions + re = regexp.MustCompile("[^a-z0-9]+") + fm = template.FuncMap{ "split": strings.Split, "indexByte": strings.IndexByte, "toInt": func(i string) int { @@ -160,8 +161,13 @@ type Server struct { func NewServer(isDev bool) *Server { dev = isDev - store = sessions.NewCookieStore([]byte("opengist")) - gothic.Store = store + flashStore = sessions.NewCookieStore([]byte("opengist")) + userStore = sessions.NewFilesystemStore(path.Join(config.GetHomeDir(), "sessions"), + utils.ReadKey(path.Join(config.GetHomeDir(), "sessions", "session-auth.key")), + utils.ReadKey(path.Join(config.GetHomeDir(), "sessions", "session-encrypt.key")), + ) + userStore.MaxLength(10 * 1024) + gothic.Store = userStore e := echo.New() e.HideBanner = true diff --git a/internal/web/test/server.go b/internal/web/test/server.go index 25aaaede..5bcf2243 100644 --- a/internal/web/test/server.go +++ b/internal/web/test/server.go @@ -142,6 +142,9 @@ func setup(t *testing.T) { homePath := config.GetHomeDir() log.Info().Msg("Data directory: " + homePath) + err = os.MkdirAll(filepath.Join(homePath, "sessions"), 0755) + require.NoError(t, err, "Could not create sessions directory") + err = os.MkdirAll(filepath.Join(homePath, "tmp", "repos"), 0755) require.NoError(t, err, "Could not create tmp repos directory") diff --git a/internal/web/util.go b/internal/web/util.go index 2ec1af58..ecc7cd3f 100644 --- a/internal/web/util.go +++ b/internal/web/util.go @@ -68,7 +68,7 @@ func getUserLogged(ctx echo.Context) *db.User { } func setErrorFlashes(ctx echo.Context) { - sess, _ := store.Get(ctx.Request(), "flash") + sess, _ := flashStore.Get(ctx.Request(), "flash") setData(ctx, "flashErrors", sess.Flashes("error")) setData(ctx, "flashSuccess", sess.Flashes("success")) @@ -77,13 +77,13 @@ func setErrorFlashes(ctx echo.Context) { } func addFlash(ctx echo.Context, flashMessage string, flashType string) { - sess, _ := store.Get(ctx.Request(), "flash") + sess, _ := flashStore.Get(ctx.Request(), "flash") sess.AddFlash(flashMessage, flashType) _ = sess.Save(ctx.Request(), ctx.Response()) } func getSession(ctx echo.Context) *sessions.Session { - sess, _ := store.Get(ctx.Request(), "session") + sess, _ := userStore.Get(ctx.Request(), "session") return sess }