Skip to content

Commit

Permalink
fix authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
nektro committed Sep 26, 2020
1 parent 7e27f09 commit bdde741
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"strconv"
"strings"

"github.com/gorilla/sessions"
"github.com/nektro/go-util/util"
"github.com/nektro/go-util/vflag"
dbstorage "github.com/nektro/go.dbstorage"
Expand Down Expand Up @@ -103,7 +102,8 @@ func main() {
//

htp.Register("/portal", "GET", mw(func(w http.ResponseWriter, r *http.Request) {
_, u, err := pageInit(r, w, http.MethodGet, true, true, false, true)
c := htp.GetController(r)
u, err := pageInit(c, r, w, http.MethodGet, true, true, false, true)
if err != nil {
return
}
Expand All @@ -115,7 +115,8 @@ func main() {
}))

htp.Register("/upload", "GET", mw(func(w http.ResponseWriter, r *http.Request) {
_, u, err := pageInit(r, w, http.MethodGet, true, true, false, true)
c := htp.GetController(r)
u, err := pageInit(c, r, w, http.MethodGet, true, true, false, true)
if err != nil {
return
}
Expand All @@ -124,8 +125,9 @@ func main() {
})
}))

htp.Register("/p/{hash}", "GET", mw(func(w http.ResponseWriter, r *http.Request) {
_, _, err := pageInit(r, w, http.MethodGet, false, false, false, true)
htp.Register("/p/{hash:[0-9a-f]+}", "GET", mw(func(w http.ResponseWriter, r *http.Request) {
c := htp.GetController(r)
_, err := pageInit(c, r, w, http.MethodGet, false, false, false, true)
if err != nil {
return
}
Expand All @@ -150,7 +152,8 @@ func main() {
}))

htp.Register("/users", "GET", mw(func(w http.ResponseWriter, r *http.Request) {
_, u, err := pageInit(r, w, http.MethodGet, true, true, true, true)
c := htp.GetController(r)
u, err := pageInit(c, r, w, http.MethodGet, true, true, true, true)
if err != nil {
return
}
Expand All @@ -162,7 +165,8 @@ func main() {
//

htp.Register("/b/upload", "POST", mw(func(w http.ResponseWriter, r *http.Request) {
_, u, err := pageInit(r, w, http.MethodPost, true, true, false, false)
c := htp.GetController(r)
u, err := pageInit(c, r, w, http.MethodPost, true, true, false, false)
if err != nil {
return
}
Expand Down Expand Up @@ -222,7 +226,8 @@ func main() {
}))

htp.Register("/b/users/update/*", "PUT", mw(func(w http.ResponseWriter, r *http.Request) {
_, _, err := pageInit(r, w, http.MethodPut, true, true, true, false)
c := htp.GetController(r)
_, err := pageInit(c, r, w, http.MethodPut, true, true, true, false)
if err != nil {
writeJson(w, map[string]interface{}{})
return
Expand Down Expand Up @@ -285,38 +290,35 @@ func isLoggedIn(r *http.Request) bool {
return isLoggedInS(etc.GetSession(r))
}

func pageInit(r *http.Request, w http.ResponseWriter, method string, requireLogin bool, requireMember bool, requireAdmin bool, htmlOut bool) (*sessions.Session, *User, error) {
func pageInit(c *htp.Controller, r *http.Request, w http.ResponseWriter, method string, requireLogin bool, requireMember bool, requireAdmin bool, htmlOut bool) (*User, error) {
if r.Method != method {
writeResponse(r, w, htmlOut, "Forbidden Method", F("%s is not allowed on this endpoint.", r.Method), "", "")
return nil, nil, E("bad http method")
return nil, E("bad http method")
}
if method == http.MethodPost {
r.ParseMultipartForm(int64(config.MaxFileSize * Megabyte))
}
if method == http.MethodPut || method == http.MethodPatch {
if method == http.MethodPut {
r.ParseForm()
}
if !requireLogin {
return nil, nil, nil
return nil, nil
}

s := etc.GetSession(r)
if requireLogin && !isLoggedInS(s) {
writeResponse(r, w, htmlOut, "Authentication Required", "You must log in to access this site.", "/login", "Please Log In")
return s, nil, E("not logged in")
}
s := etc.JWTGetClaims(c, r)
sp := strings.SplitN(s["sub"].(string), "\n", 2)

u := queryUserBySnowflake(s.Values["provider"].(string), s.Values["user"].(string))
u := queryUserBySnowflake(sp[0], sp[1])
if requireMember && !u.IsMember {
writeResponse(r, w, htmlOut, "Access Forbidden", "You must be a member to view this page.", "", "")
return s, u, E("not a member")
return u, E("not a member")
}
if requireAdmin && !u.IsAdmin {
writeResponse(r, w, htmlOut, "Access Forbidden", "You must be an admin to view this page.", "", "")
return s, u, E("not an admin")
return u, E("not an admin")
}

return s, u, nil
return u, nil
}

func writeResponse(r *http.Request, w http.ResponseWriter, htmlOut bool, title string, message string, url string, link string) {
Expand Down

0 comments on commit bdde741

Please sign in to comment.