Skip to content

Commit

Permalink
Handle redis cluster MOVED errors in cache and redis backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
iegomez committed Jun 26, 2020
1 parent e8b112b commit 1371452
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 67 deletions.
153 changes: 114 additions & 39 deletions backends/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package backends

import (
"context"
"errors"
"fmt"
"strconv"
"strings"
Expand All @@ -21,6 +22,17 @@ type RedisClient interface {
Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *goredis.StatusCmd
SAdd(ctx context.Context, key string, members ...interface{}) *goredis.IntCmd
Expire(ctx context.Context, key string, expiration time.Duration) *goredis.BoolCmd
ReloadState(ctx context.Context) error
}

type SingleRedisClient struct {
*goredis.Client
}

var SingleClientError = errors.New("unsupported reload state operation for Redis single client")

func (c SingleRedisClient) ReloadState(ctx context.Context) error {
return SingleClientError
}

type Redis struct {
Expand Down Expand Up @@ -106,7 +118,7 @@ func NewRedis(authOpts map[string]string, logLevel log.Level) (Redis, error) {
Password: redis.Password,
DB: int(redis.DB),
})
redis.conn = redisClient
redis.conn = SingleRedisClient{redisClient}
}

for {
Expand All @@ -122,48 +134,122 @@ func NewRedis(authOpts map[string]string, logLevel log.Level) (Redis, error) {

}

// Checks if an error was caused by a moved record in a cluster.
func IsMovedError(err error) bool {
s := err.Error()
if strings.HasPrefix(s, "MOVED ") || strings.HasPrefix(s, "ASK ") {
return true
}

return false
}

//GetUser checks that the username exists and the given password hashes to the same password.
func (o Redis) GetUser(username, password, clientid string) bool {
func (o Redis) GetUser(username, password, _ string) bool {
ok, err := o.getUser(username, password)
if err == nil {
return ok
}

pwHash, err := o.conn.Get(o.ctx, username).Result()
//If using Redis Cluster, reload state and attempt once more.
if IsMovedError(err) {
err = o.conn.ReloadState(o.ctx)
if err != nil {
log.Debugf("redis reload state error: %s", err)
return false
}

//Retry once.
ok, err = o.getUser(username, password)
}

if err != nil {
log.Debugf("Redis get user error: %s", err)
return false
log.Debugf("redis get user error: %s", err)
}
return ok
}

if common.HashCompare(password, pwHash, o.SaltEncoding) {
return true
func (o Redis) getUser(username, password string) (bool, error) {
pwHash, err := o.conn.Get(o.ctx, username).Result()
if err != nil {
return false, err
}

return false
if common.HashCompare(password, pwHash, o.SaltEncoding) {
return true, nil
}

return false, nil
}

//GetSuperuser checks that the key username:su exists and has value "true".
func (o Redis) GetSuperuser(username string) bool {

if o.disableSuperuser {
return false
}

isSuper, err := o.conn.Get(o.ctx, fmt.Sprintf("%s:su", username)).Result()
ok, err := o.getSuperuser(username)
if err == nil {
return ok
}

//If using Redis Cluster, reload state and attempt once more.
if IsMovedError(err) {
err = o.conn.ReloadState(o.ctx)
if err != nil {
log.Debugf("redis reload state error: %s", err)
return false
}

//Retry once.
ok, err = o.getSuperuser(username)
}

if err != nil {
log.Debugf("Redis get superuser error: %s", err)
return false
log.Debugf("redis get superuser error: %s", err)
}
return ok
}

func (o Redis) getSuperuser(username string) (bool, error) {
isSuper, err := o.conn.Get(o.ctx, fmt.Sprintf("%s:su", username)).Result()
if err != nil {
return false, err
}

if isSuper == "true" {
return true
return true, nil
}

return false
return false, nil
}

func (o Redis) CheckAcl(username, topic, clientid string, acc int32) bool {
ok, err := o.checkAcl(username, topic, clientid, acc)
if err == nil {
return ok
}

//If using Redis Cluster, reload state and attempt once more.
if IsMovedError(err) {
err = o.conn.ReloadState(o.ctx)
if err != nil {
log.Debugf("redis reload state error: %s", err)
return false
}

//Retry once.
ok, err = o.checkAcl(username, topic, clientid, acc)
}

if err != nil {
log.Debugf("redis check acl error: %s", err)
}
return ok
}

//CheckAcl gets all acls for the username and tries to match against topic, acc, and username/clientid if needed.
func (o Redis) CheckAcl(username, topic, clientid string, acc int32) bool {
func (o Redis) checkAcl(username, topic, clientid string, acc int32) (bool, error) {

var acls []string //User specific acls.
var commonAcls []string //Common acls.
Expand All @@ -175,40 +261,34 @@ func (o Redis) CheckAcl(username, topic, clientid string, acc int32) bool {
var err error
acls, err = o.conn.SMembers(o.ctx, fmt.Sprintf("%s:sacls", username)).Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}

//Get common subscribe acls.
commonAcls, err = o.conn.SMembers(o.ctx, "common:sacls").Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}

case MOSQ_ACL_READ:
//Get all user read and readwrite acls.
urAcls, err := o.conn.SMembers(o.ctx, fmt.Sprintf("%s:racls", username)).Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}
urwAcls, err := o.conn.SMembers(o.ctx, fmt.Sprintf("%s:rwacls", username)).Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}

//Get common read and readwrite acls
rAcls, err := o.conn.SMembers(o.ctx, "common:racls").Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}
rwAcls, err := o.conn.SMembers(o.ctx, "common:rwacls").Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}

acls = make([]string, len(urAcls)+len(urwAcls))
Expand All @@ -222,25 +302,21 @@ func (o Redis) CheckAcl(username, topic, clientid string, acc int32) bool {
//Get all user write and readwrite acls.
uwAcls, err := o.conn.SMembers(o.ctx, fmt.Sprintf("%s:wacls", username)).Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}
urwAcls, err := o.conn.SMembers(o.ctx, fmt.Sprintf("%s:rwacls", username)).Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}

//Get common write and readwrite acls
wAcls, err := o.conn.SMembers(o.ctx, "common:wacls").Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}
rwAcls, err := o.conn.SMembers(o.ctx, "common:rwacls").Result()
if err != nil {
log.Debugf("Redis check acl error: %s", err)
return false
return false, err
}

acls = make([]string, len(uwAcls)+len(urwAcls))
Expand All @@ -255,20 +331,19 @@ func (o Redis) CheckAcl(username, topic, clientid string, acc int32) bool {
//Now loop through acls looking for a match.
for _, acl := range acls {
if common.TopicsMatch(acl, topic) {
return true
return true, nil
}
}

for _, acl := range commonAcls {
aclTopic := strings.Replace(acl, "%c", clientid, -1)
aclTopic = strings.Replace(aclTopic, "%u", username, -1)
if common.TopicsMatch(aclTopic, topic) {
return true
return true, nil
}
}

return false

return false, nil
}

//GetName returns the backend's name
Expand Down
Loading

0 comments on commit 1371452

Please sign in to comment.