Skip to content

Commit

Permalink
fix(api-ext): use same save logic as the api (#518)
Browse files Browse the repository at this point in the history
* switched session priority to header then cookie

* fix(db): GetBookmarks handle no rows error

* fix(api-ext): using same save logic as the api
  • Loading branch information
fmartingr authored Oct 15, 2022
1 parent 5f1adc6 commit d1f0ce8
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 31 deletions.
31 changes: 30 additions & 1 deletion internal/database/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ func testDatabase(t *testing.T, dbFactory testDatabaseFactory) {
"testCreateBookmark": testCreateBookmark,
"testCreateBookmarkTwice": testCreateBookmarkTwice,
"testCreateBookmarkWithTag": testCreateBookmarkWithTag,
"testCreateTwoDifferentBookmarks": testCreateTwoDifferentBookmarks,
"testUpdateBookmark": testUpdateBookmark,
"testGetBookmark": testGetBookmark,
"testGetBookmarkNotExistant": testGetBookmarkNotExistant,
"testGetBookmarks": testGetBookmarks,
"testGetBookmarksCount": testGetBookmarksCount,
"testCreateTwoDifferentBookmarks": testCreateTwoDifferentBookmarks,
}

for testName, testCase := range tests {
Expand Down Expand Up @@ -124,6 +126,33 @@ func testUpdateBookmark(t *testing.T, db DB) {
assert.Equal(t, savedBookmark.ID, result[0].ID)
}

func testGetBookmark(t *testing.T, db DB) {
ctx := context.TODO()

book := model.Bookmark{
URL: "https://github.com/go-shiori/shiori",
Title: "shiori",
}

result, err := db.SaveBookmarks(ctx, true, book)
assert.NoError(t, err, "Save bookmarks must not fail")

savedBookmark, exists, err := db.GetBookmark(ctx, result[0].ID, "")
assert.True(t, exists, "Bookmark should exist")
assert.NoError(t, err, "Get bookmark should not fail")
assert.Equal(t, result[0].ID, savedBookmark.ID, "Retrieved bookmark should be the same")
assert.Equal(t, book.URL, savedBookmark.URL, "Retrieved bookmark should be the same")
}

func testGetBookmarkNotExistant(t *testing.T, db DB) {
ctx := context.TODO()

savedBookmark, exists, err := db.GetBookmark(ctx, 1, "")
assert.NoError(t, err, "Get bookmark should not fail")
assert.False(t, exists, "Bookmark should not exist")
assert.Equal(t, model.Bookmark{}, savedBookmark)
}

func testGetBookmarks(t *testing.T, db DB) {
ctx := context.TODO()

Expand Down
2 changes: 1 addition & 1 deletion internal/database/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ func (db *MySQLDatabase) GetBookmark(ctx context.Context, id int, url string) (m
}

book := model.Bookmark{}
if err := db.GetContext(ctx, &book, query, args...); err != nil {
if err := db.GetContext(ctx, &book, query, args...); err != nil && err != sql.ErrNoRows {
return book, false, errors.WithStack(err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/database/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ func (db *PGDatabase) GetBookmark(ctx context.Context, id int, url string) (mode
}

book := model.Bookmark{}
if err := db.GetContext(ctx, &book, query, args...); err != nil {
if err := db.GetContext(ctx, &book, query, args...); err != nil && err != sql.ErrNoRows {
return book, false, errors.WithStack(err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/database/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ func (db *SQLiteDatabase) GetBookmark(ctx context.Context, id int, url string) (
}

book := model.Bookmark{}
if err := db.GetContext(ctx, &book, query, args...); err != nil {
if err := db.GetContext(ctx, &book, query, args...); err != nil && err != sql.ErrNoRows {
return book, false, errors.WithStack(err)
}

Expand Down
33 changes: 16 additions & 17 deletions internal/webserver/handler-api-ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,8 @@ func (h *handler) apiInsertViaExtension(w http.ResponseWriter, r *http.Request,
book.Tags = append(book.Tags, newTag)
}
}
} else {
book = request
book.ID, err = h.DB.CreateNewID(ctx, "bookmark")
if err != nil {
panic(fmt.Errorf("failed to create ID: %v", err))
}
} else if request.Title == "" {
request.Title = request.URL
}

// Since we are using extension, the extension might send the HTML content
Expand All @@ -76,6 +72,15 @@ func (h *handler) apiInsertViaExtension(w http.ResponseWriter, r *http.Request,
contentBuffer = bytes.NewBufferString(book.HTML)
}

// Save the bookmark with whatever we already have downloaded
// since we need the ID in order to download the archive
books, err := h.DB.SaveBookmarks(ctx, true, request)
if err != nil {
log.Printf("error saving bookmark before downloading content: %s", err)
return
}
book = books[0]

// At this point the web page already downloaded.
// Time to process it.
if contentBuffer != nil {
Expand All @@ -94,20 +99,14 @@ func (h *handler) apiInsertViaExtension(w http.ResponseWriter, r *http.Request,
tmp.Close()
}

// If we can't process or update the saved bookmark, just log it and continue on with the
// request.
if err != nil && isFatalErr {
panic(fmt.Errorf("failed to process bookmark: %v", err))
log.Printf("failed to process bookmark: %v", err)
} else if _, err := h.DB.SaveBookmarks(ctx, false, book); err != nil {
log.Printf("error saving bookmark after downloading content: %s", err)
}
}
if _, err := h.DB.SaveBookmarks(ctx, true, book); err != nil {
log.Printf("error saving bookmark after downloading content: %s", err)
}

// Save bookmark to database
results, err := h.DB.SaveBookmarks(ctx, false, book)
if err != nil || len(results) == 0 {
panic(fmt.Errorf("failed to save bookmark: %v", err))
}
book = results[0]

// Return the new bookmark
w.Header().Set("Content-Type", "application/json")
Expand Down
16 changes: 6 additions & 10 deletions internal/webserver/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,17 @@ func (h *handler) prepareTemplates() error {
}

func (h *handler) getSessionID(r *http.Request) string {
// Get session-id from header and cookie
headerSessionID := r.Header.Get("X-Session-Id")
cookieSessionID := func() string {
// Try to get session ID from the header
sessionID := r.Header.Get("X-Session-Id")

// If not, try it from the cookie
if sessionID == "" {
cookie, err := r.Cookie("session-id")
if err != nil {
return ""
}

return cookie.Value
}()

// Session ID in cookie is more priority than in header
sessionID := headerSessionID
if cookieSessionID != "" {
sessionID = cookieSessionID
sessionID = cookie.Value
}

return sessionID
Expand Down

0 comments on commit d1f0ce8

Please sign in to comment.