Skip to content

Commit

Permalink
Merge pull request #968 from go-kivik/opts
Browse files Browse the repository at this point in the history
Standardize use of options methods
  • Loading branch information
flimzy committed May 10, 2024
2 parents aa1978d + 686b34c commit 0bcafec
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 61 deletions.
2 changes: 1 addition & 1 deletion x/sqlite/changes.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type normalChanges struct {
var _ driver.Changes = &normalChanges{}

func (d *db) newNormalChanges(ctx context.Context, opts optsMap, since, lastSeq *uint64, sinceNow bool, feed string) (*normalChanges, error) {
limit, err := opts.limit()
limit, err := opts.changesLimit()
if err != nil {
return nil, err
}
Expand Down
39 changes: 10 additions & 29 deletions x/sqlite/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import (
)

func (d *db) Get(ctx context.Context, id string, options driver.Options) (*driver.Document, error) {
opts := map[string]interface{}{}
options.Apply(opts)
opts := newOpts(options)

var r revision

Expand All @@ -35,40 +34,22 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive
}
defer tx.Rollback()

var (
optConflicts, _ = opts["conflicts"].(bool)
optDeletedConflicts, _ = opts["deleted_conflicts"].(bool)
optRevsInfo, _ = opts["revs_info"].(bool)
optRevs, _ = opts["revs"].(bool) // TODO: opts.revs()
optLocalSeq, _ = opts["local_seq"].(bool)
optAttachments, _ = opts["attachments"].(bool)
optAttsSince, _ = opts["atts_since"].([]string)
optsRev, _ = opts["rev"].(string)
latest, _ = opts["latest"].(bool)
)

if optsRev != "" {
r, err = parseRev(optsRev)
if opts.rev() != "" {
r, err = parseRev(opts.rev())
if err != nil {
return nil, err
}
}
toMerge, r, err := d.getCoreDoc(ctx, tx, id, r, latest, false)
toMerge, r, err := d.getCoreDoc(ctx, tx, id, r, opts.latest(), false)
if err != nil {
return nil, err
}

if !optLocalSeq {
if !opts.localSeq() {
toMerge.LocalSeq = 0
}

if meta, _ := opts["meta"].(bool); meta {
optConflicts = true
optDeletedConflicts = true
optRevsInfo = true
}

if optConflicts {
if opts.conflicts() {
revs, err := d.conflicts(ctx, tx, id, r, false)
if err != nil {
return nil, err
Expand All @@ -77,7 +58,7 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive
toMerge.Conflicts = revs
}

if optDeletedConflicts {
if opts.deletedConflicts() {
revs, err := d.conflicts(ctx, tx, id, r, true)
if err != nil {
return nil, err
Expand All @@ -86,7 +67,7 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive
toMerge.DeletedConflicts = revs
}

if optRevsInfo || optRevs {
if opts.revsInfo() || opts.revs() {
rows, err := tx.QueryContext(ctx, d.query(`
WITH RECURSIVE Ancestors AS (
-- Base case: Select the starting node for ancestors
Expand Down Expand Up @@ -133,7 +114,7 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive
if err := rows.Err(); err != nil {
return nil, err
}
if optRevsInfo {
if opts.revsInfo() {
info := make([]map[string]string, 0, len(revs))
for _, r := range revs {
info = append(info, map[string]string{
Expand All @@ -155,7 +136,7 @@ func (d *db) Get(ctx context.Context, id string, options driver.Options) (*drive
}
}

atts, err := d.getAttachments(ctx, tx, id, r, optAttachments, optAttsSince)
atts, err := d.getAttachments(ctx, tx, id, r, opts.attachments(), opts.attsSince())
if err != nil {
return nil, err
}
Expand Down
11 changes: 4 additions & 7 deletions x/sqlite/getrev.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ import (
)

func (d *db) GetRev(ctx context.Context, id string, options driver.Options) (string, error) {
opts := map[string]interface{}{}
options.Apply(opts)
opts := newOpts(options)

var r revision

Expand All @@ -34,15 +33,13 @@ func (d *db) GetRev(ctx context.Context, id string, options driver.Options) (str
}
defer tx.Rollback()

optsRev, _ := opts["rev"].(string)
latest, _ := opts["latest"].(bool)
if optsRev != "" {
r, err = parseRev(optsRev)
if opts.rev() != "" {
r, err = parseRev(opts.rev())
if err != nil {
return "", err
}
}
_, r, err = d.getCoreDoc(ctx, d.db, id, r, latest, true)
_, r, err = d.getCoreDoc(ctx, d.db, id, r, opts.latest(), true)
if err != nil {
return "", err
}
Expand Down
123 changes: 119 additions & 4 deletions x/sqlite/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package sqlite

import (
"math"
"net/http"
"strconv"

Expand Down Expand Up @@ -92,9 +93,9 @@ func (o optsMap) since() (bool, *uint64, error) {
return false, &since, err
}

// limit returns 0 if the limit is unset, or the limit value as a uint64. An
// explicit limit of 0 is converted to 1, as per CouchDB docs.
func (o optsMap) limit() (uint64, error) {
// changesLimit returns the changesLimit value as a uint64, or 0 if the changesLimit is unset. An
// explicit changesLimit of 0 is converted to 1, as per CouchDB docs.
func (o optsMap) changesLimit() (uint64, error) {
in, ok := o["limit"]
if !ok {
return 0, nil
Expand All @@ -109,6 +110,26 @@ func (o optsMap) limit() (uint64, error) {
return limit, nil
}

// limit returns the limit value as an int64, or -1 if the limit is unset.
// If the limit is invalid, an error is returned with status 400.
func (o optsMap) limit() (int64, error) {
in, ok := o["limit"]
if !ok {
return -1, nil
}
return toInt64(in, "malformed 'limit' parameter")
}

// skip returns the skip value as an int64, or 0 if the skip is unset.
// If the skip is invalid, an error is returned with status 400.
func (o optsMap) skip() (int64, error) {
in, ok := o["skip"]
if !ok {
return 0, nil
}
return toInt64(in, "malformed 'skip' parameter")
}

// toUint64 converts the input to a uint64. If the input is malformed, it
// returns an error with msg as the message, and 400 as the status code.
func toUint64(in interface{}, msg string) (uint64, error) {
Expand Down Expand Up @@ -162,6 +183,56 @@ func toUint64(in interface{}, msg string) (uint64, error) {
}
}

// toInt64 converts the input to a int64. If the input is malformed, it
// returns an error with msg as the message, and 400 as the status code.
func toInt64(in interface{}, msg string) (int64, error) {
switch t := in.(type) {
case int:
return int64(t), nil
case int64:
return t, nil
case int8:
return int64(t), nil
case int16:
return int64(t), nil
case int32:
return int64(t), nil
case uint:
return int64(t), nil
case uint8:
return int64(t), nil
case uint16:
return int64(t), nil
case uint32:
return int64(t), nil
case uint64:
if t > math.MaxInt64 {
return 0, &internal.Error{Status: http.StatusBadRequest, Message: msg}
}
return int64(t), nil
case string:
i, err := strconv.ParseInt(t, 10, 64)
if err != nil {
return 0, &internal.Error{Status: http.StatusBadRequest, Message: msg}
}
return i, nil
case float32:
i := int64(t)
if float32(i) != t {
return 0, &internal.Error{Status: http.StatusBadRequest, Message: msg}
}
return i, nil
case float64:
i := int64(t)
if float64(i) != t {
return 0, &internal.Error{Status: http.StatusBadRequest, Message: msg}
}
return i, nil
default:
return 0, &internal.Error{Status: http.StatusBadRequest, Message: msg}
}
}

func toBool(in interface{}) (value bool, ok bool) {
switch t := in.(type) {
case bool:
Expand All @@ -174,8 +245,13 @@ func toBool(in interface{}) (value bool, ok bool) {
}
}

func (o optsMap) descending() bool {
v, _ := toBool(o["descending"])
return v
}

func (o optsMap) direction() string {
if v, _ := toBool(o["descending"]); v {
if o.descending() {
return "DESC"
}
return "ASC"
Expand Down Expand Up @@ -268,3 +344,42 @@ func (o optsMap) groupLevel() (uint64, error) {
}
return toUint64(raw, "invalid value for `group_level`")
}

func (o optsMap) conflicts() bool {
if o.meta() {
return true
}
v, _ := toBool(o["conflicts"])
return v
}

func (o optsMap) meta() bool {
v, _ := toBool(o["meta"])
return v
}

func (o optsMap) deletedConflicts() bool {
if o.meta() {
return true
}
v, _ := toBool(o["deleted_conflicts"])
return v
}

func (o optsMap) revsInfo() bool {
if o.meta() {
return true
}
v, _ := toBool(o["revs_info"])
return v
}

func (o optsMap) localSeq() bool {
v, _ := toBool(o["local_seq"])
return v
}

func (o optsMap) attsSince() []string {
attsSince, _ := o["atts_since"].([]string)
return attsSince
}
32 changes: 12 additions & 20 deletions x/sqlite/views.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,7 @@ func (d *db) DesignDocs(ctx context.Context, options driver.Options) (driver.Row
func (d *db) queryView(ctx context.Context, view string, options driver.Options) (driver.Rows, error) {
opts := newOpts(options)

var (
optConflicts, _ = opts["conflicts"].(bool)
optDescending, _ = opts["descending"].(bool)
optIncludeDocs, _ = opts["include_docs"].(bool)
optLimit, _ = opts["limit"].(int)
optSkip, _ = opts["skip"].(int)
)

direction := "ASC"
if optDescending {
direction = "DESC"
}

args := []interface{}{optIncludeDocs}
args := []interface{}{opts.includeDocs()}

where := []string{"rev.rank = 1"}
switch view {
Expand All @@ -90,15 +77,20 @@ func (d *db) queryView(ctx context.Context, view string, options driver.Options)
where = append(where, "rev.id LIKE '_design/%'")
}
if endkey := opts.endKey(); endkey != "" {
where = append(where, fmt.Sprintf("rev.id %s $%d", endKeyOp(optDescending, opts.inclusiveEnd()), len(args)+1))
where = append(where, fmt.Sprintf("rev.id %s $%d", endKeyOp(opts.descending(), opts.inclusiveEnd()), len(args)+1))
args = append(args, endkey)
}
if startkey := opts.startKey(); startkey != "" {
where = append(where, fmt.Sprintf("rev.id %s $%d", startKeyOp(optDescending), len(args)+1))
where = append(where, fmt.Sprintf("rev.id %s $%d", startKeyOp(opts.descending()), len(args)+1))
args = append(args, startkey)
}
if optLimit == 0 {
optLimit = -1
limit, err := opts.limit()
if err != nil {
return nil, err
}
skip, err := opts.skip()
if err != nil {
return nil, err
}

query := fmt.Sprintf(d.query(`
Expand Down Expand Up @@ -134,7 +126,7 @@ func (d *db) queryView(ctx context.Context, view string, options driver.Options)
GROUP BY rev.id, rev.rev, rev.rev_id
ORDER BY id %[1]s
LIMIT %[3]d OFFSET %[4]d
`), direction, strings.Join(where, " AND "), optLimit, optSkip)
`), opts.direction(), strings.Join(where, " AND "), limit, skip)
results, err := d.db.QueryContext(ctx, query, args...) //nolint:rowserrcheck // Err checked in Next
if err != nil {
return nil, err
Expand All @@ -144,7 +136,7 @@ func (d *db) queryView(ctx context.Context, view string, options driver.Options)
ctx: ctx,
db: d,
rows: results,
conflicts: optConflicts,
conflicts: opts.conflicts(),
}, nil
}

Expand Down
11 changes: 11 additions & 0 deletions x/sqlite/views_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package sqlite
import (
"context"
"io"
"net/http"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -560,6 +561,16 @@ func TestDBAllDocs(t *testing.T) {
},
}
})
tests.Add("invalid limit value", test{
options: kivik.Params(map[string]interface{}{"limit": "chicken"}),
wantErr: "malformed 'limit' parameter",
wantStatus: http.StatusBadRequest,
})
tests.Add("invalid skip value", test{
options: kivik.Params(map[string]interface{}{"skip": "chicken"}),
wantErr: "malformed 'skip' parameter",
wantStatus: http.StatusBadRequest,
})

/*
TODO:
Expand Down

0 comments on commit 0bcafec

Please sign in to comment.