package main import "C" import ( "context" "os" "strconv" "strings" "time" bes "github.com/iegomez/mosquitto-go-auth/backends" "github.com/iegomez/mosquitto-go-auth/cache" "github.com/iegomez/mosquitto-go-auth/hashing" log "github.com/sirupsen/logrus" ) type AuthPlugin struct { backends *bes.Backends useCache bool logLevel log.Level logDest string logFile string ctx context.Context cache cache.Store hasher hashing.HashComparer retryCount int } // errors to signal mosquitto const ( AuthRejected = 0 AuthGranted = 1 AuthError = 2 ) var authOpts map[string]string //Options passed by mosquitto. var authPlugin AuthPlugin //General struct with options and conf. //export AuthPluginInit func AuthPluginInit(keys []*C.char, values []*C.char, authOptsNum int, version *C.char) { log.SetFormatter(&log.TextFormatter{ FullTimestamp: true, }) //Initialize auth plugin struct with default and given values. authPlugin = AuthPlugin{ logLevel: log.InfoLevel, ctx: context.Background(), } authOpts = make(map[string]string) for i := 0; i < authOptsNum; i++ { authOpts[C.GoString(keys[i])] = C.GoString(values[i]) } if retryCount, ok := authOpts["retry_count"]; ok { retry, err := strconv.ParseInt(retryCount, 10, 64) if err == nil { authPlugin.retryCount = int(retry) } else { log.Warningf("couldn't parse retryCount (err: %s), defaulting to 0", err) } } //Check if log level is given. Set level if any valid option is given. if logLevel, ok := authOpts["log_level"]; ok { logLevel = strings.Replace(logLevel, " ", "", -1) switch logLevel { case "debug": authPlugin.logLevel = log.DebugLevel case "info": authPlugin.logLevel = log.InfoLevel case "warn": authPlugin.logLevel = log.WarnLevel case "error": authPlugin.logLevel = log.ErrorLevel case "fatal": authPlugin.logLevel = log.FatalLevel case "panic": authPlugin.logLevel = log.PanicLevel default: log.Info("log_level unkwown, using default info level") } } if logDest, ok := authOpts["log_dest"]; ok { switch logDest { case "stdout": log.SetOutput(os.Stdout) case "file": if logFile, ok := authOpts["log_file"]; ok { file, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err == nil { log.SetOutput(file) } else { log.Errorf("failed to log to file, using default stderr: %s", err) } } default: log.Info("log_dest unknown, using default stderr") } } var err error authPlugin.backends, err = bes.Initialize(authOpts, authPlugin.logLevel, C.GoString(version)) if err != nil { log.Fatalf("error initializing backends: %s", err) } if cache, ok := authOpts["cache"]; ok && strings.Replace(cache, " ", "", -1) == "true" { log.Info("redisCache activated") authPlugin.useCache = true } else { log.Info("No cache set.") authPlugin.useCache = false } if authPlugin.useCache { setCache(authOpts) } } func setCache(authOpts map[string]string) { var aclCacheSeconds int64 = 30 var authCacheSeconds int64 = 30 var authJitterSeconds int64 = 0 var aclJitterSeconds int64 = 0 if authCacheSec, ok := authOpts["auth_cache_seconds"]; ok { authSec, err := strconv.ParseInt(authCacheSec, 10, 64) if err == nil { authCacheSeconds = authSec } else { log.Warningf("couldn't parse authCacheSeconds (err: %s), defaulting to %d", err, authCacheSeconds) } } if authJitterSec, ok := authOpts["auth_jitter_seconds"]; ok { authSec, err := strconv.ParseInt(authJitterSec, 10, 64) if err == nil { authJitterSeconds = authSec } else { log.Warningf("couldn't parse authJitterSeconds (err: %s), defaulting to %d", err, authJitterSeconds) } } if authJitterSeconds > authCacheSeconds { authJitterSeconds = authCacheSeconds log.Warningf("authJitterSeconds is larger than authCacheSeconds, defaulting to %d", authJitterSeconds) } if aclCacheSec, ok := authOpts["acl_cache_seconds"]; ok { aclSec, err := strconv.ParseInt(aclCacheSec, 10, 64) if err == nil { aclCacheSeconds = aclSec } else { log.Warningf("couldn't parse aclCacheSeconds (err: %s), defaulting to %d", err, aclCacheSeconds) } } if aclJitterSec, ok := authOpts["acl_jitter_seconds"]; ok { aclSec, err := strconv.ParseInt(aclJitterSec, 10, 64) if err == nil { aclJitterSeconds = aclSec } else { log.Warningf("couldn't parse aclJitterSeconds (err: %s), defaulting to %d", err, aclJitterSeconds) } } if aclJitterSeconds > aclCacheSeconds { aclJitterSeconds = aclCacheSeconds log.Warningf("aclJitterSeconds is larger than aclCacheSeconds, defaulting to %d", aclJitterSeconds) } reset := false if cacheReset, ok := authOpts["cache_reset"]; ok && cacheReset == "true" { reset = true } refreshExpiration := false if refresh, ok := authOpts["cache_refresh"]; ok && refresh == "true" { refreshExpiration = true } switch authOpts["cache_type"] { case "redis": host := "localhost" port := "6379" db := 3 password := "" cluster := false if authOpts["cache_mode"] == "true" { cluster = true } if cachePassword, ok := authOpts["cache_password"]; ok { password = cachePassword } if cluster { addressesOpt := authOpts["redis_cluster_addresses"] if addressesOpt == "" { log.Errorln("cache Redis cluster addresses missing, defaulting to no cache.") authPlugin.useCache = false return } // Take the given addresses and trim spaces from them. addresses := strings.Split(addressesOpt, ",") for i := 0; i < len(addresses); i++ { addresses[i] = strings.TrimSpace(addresses[i]) } authPlugin.cache = cache.NewRedisClusterStore( password, addresses, time.Duration(authCacheSeconds)*time.Second, time.Duration(aclCacheSeconds)*time.Second, time.Duration(authJitterSeconds)*time.Second, time.Duration(aclJitterSeconds)*time.Second, refreshExpiration, ) } else { if cacheHost, ok := authOpts["cache_host"]; ok { host = cacheHost } if cachePort, ok := authOpts["cache_port"]; ok { port = cachePort } if cacheDB, ok := authOpts["cache_db"]; ok { parsedDB, err := strconv.ParseInt(cacheDB, 10, 32) if err == nil { db = int(parsedDB) } else { log.Warningf("couldn't parse cache db (err: %s), defaulting to %d", err, db) } } authPlugin.cache = cache.NewSingleRedisStore( host, port, password, db, time.Duration(authCacheSeconds)*time.Second, time.Duration(aclCacheSeconds)*time.Second, time.Duration(authJitterSeconds)*time.Second, time.Duration(aclJitterSeconds)*time.Second, refreshExpiration, ) } default: authPlugin.cache = cache.NewGoStore( time.Duration(authCacheSeconds)*time.Second, time.Duration(aclCacheSeconds)*time.Second, time.Duration(authJitterSeconds)*time.Second, time.Duration(aclJitterSeconds)*time.Second, refreshExpiration, ) } if !authPlugin.cache.Connect(authPlugin.ctx, reset) { authPlugin.cache = nil authPlugin.useCache = false log.Infoln("couldn't start cache, defaulting to no cache") } } //export AuthUnpwdCheck func AuthUnpwdCheck(username, password, clientid *C.char) uint8 { var ok bool var err error for try := 0; try <= authPlugin.retryCount; try++ { ok, err = authUnpwdCheck(C.GoString(username), C.GoString(password), C.GoString(clientid)) if err == nil { break } } if err != nil { log.Error(err) return AuthError } if ok { return AuthGranted } return AuthRejected } func authUnpwdCheck(username, password, clientid string) (bool, error) { var authenticated bool var cached bool var granted bool var err error if authPlugin.useCache { log.Debugf("checking auth cache for %s", username) cached, granted = authPlugin.cache.CheckAuthRecord(authPlugin.ctx, username, password) if cached { log.Debugf("found in cache: %s", username) return granted, nil } } authenticated, err = authPlugin.backends.AuthUnpwdCheck(username, password, clientid) if authPlugin.useCache && err == nil { authGranted := "false" if authenticated { authGranted = "true" } log.Debugf("setting auth cache for %s", username) if setAuthErr := authPlugin.cache.SetAuthRecord(authPlugin.ctx, username, password, authGranted); setAuthErr != nil { log.Errorf("set auth cache: %s", setAuthErr) return false, setAuthErr } } return authenticated, err } //export AuthAclCheck func AuthAclCheck(clientid, username, topic *C.char, acc C.int) uint8 { var ok bool var err error for try := 0; try <= authPlugin.retryCount; try++ { ok, err = authAclCheck(C.GoString(clientid), C.GoString(username), C.GoString(topic), int(acc)) if err == nil { break } } if err != nil { log.Error(err) return AuthError } if ok { return AuthGranted } return AuthRejected } func authAclCheck(clientid, username, topic string, acc int) (bool, error) { var aclCheck bool var cached bool var granted bool var err error if authPlugin.useCache { log.Debugf("checking acl cache for %s", username) cached, granted = authPlugin.cache.CheckACLRecord(authPlugin.ctx, username, topic, clientid, acc) if cached { log.Debugf("found in cache: %s", username) return granted, nil } } aclCheck, err = authPlugin.backends.AuthAclCheck(clientid, username, topic, acc) if authPlugin.useCache && err == nil { authGranted := "false" if aclCheck { authGranted = "true" } log.Debugf("setting acl cache (granted = %s) for %s", authGranted, username) if setACLErr := authPlugin.cache.SetACLRecord(authPlugin.ctx, username, topic, clientid, acc, authGranted); setACLErr != nil { log.Errorf("set acl cache: %s", setACLErr) return false, setACLErr } } log.Debugf("Acl is %t for user %s", aclCheck, username) return aclCheck, err } //export AuthPskKeyGet func AuthPskKeyGet() bool { return true } //export AuthPluginCleanup func AuthPluginCleanup() { log.Info("Cleaning up plugin") //If cache is set, close cache connection. if authPlugin.cache != nil { authPlugin.cache.Close() } authPlugin.backends.Halt() } func main() {}