From 0b19890ea831c4daea287849974306cfebd30702 Mon Sep 17 00:00:00 2001 From: Tyler Davis Date: Fri, 8 Oct 2021 00:16:55 +0000 Subject: [PATCH] geoip2: watch and update Add goroutine to watch file modification time each minute. Lock and update the files independently if the modification time is newer than the last time it was loaded. Also ancillary fixes for type signatures, and formatting test files. Fixes #116, #16 --- http_test.go | 2 - targeting/geoip2/geoip2.go | 96 ++++++++++++++++++++++++++++++------- targeting/targeting_test.go | 9 ++-- 3 files changed, 84 insertions(+), 23 deletions(-) diff --git a/http_test.go b/http_test.go index 6e834204..a67edd4f 100644 --- a/http_test.go +++ b/http_test.go @@ -15,7 +15,6 @@ import ( ) func TestHTTP(t *testing.T) { - geoprovider, err := geoip2.New(geoip2.FindDB()) if err == nil { targeting.Setup(geoprovider) @@ -43,5 +42,4 @@ func TestHTTP(t *testing.T) { t.Log("/version didn't start with 'GeoDNS '") t.Fail() } - } diff --git a/targeting/geoip2/geoip2.go b/targeting/geoip2/geoip2.go index 60f7ed74..3df781c2 100644 --- a/targeting/geoip2/geoip2.go +++ b/targeting/geoip2/geoip2.go @@ -2,12 +2,14 @@ package geoip2 import ( "fmt" + "io/fs" "log" "net" "os" "path/filepath" "strings" "sync" + "time" "github.com/abh/geodns/countries" "github.com/abh/geodns/targeting/geo" @@ -29,17 +31,25 @@ var dbFiles map[geoType][]string type GeoIP2 struct { dir string - country *geoip2.Reader - city *geoip2.Reader - asn *geoip2.Reader - mu sync.RWMutex + country geodb + city geodb + asn geodb + + mu sync.RWMutex +} + +type geodb struct { + db *geoip2.Reader // Database reader + fp string // FilePath + lastModified int64 // Epoch time + // l sync.Mutex // Individual lock for separate DB access and reload -- Future? } func init() { dbFiles = map[geoType][]string{ - countryDB: []string{"GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb"}, - asnDB: []string{"GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb"}, - cityDB: []string{"GeoIP2-City.mmdb", "GeoLite2-City.mmdb"}, + countryDB: {"GeoIP2-Country.mmdb", "GeoLite2-Country.mmdb"}, + asnDB: {"GeoIP2-ASN.mmdb", "GeoLite2-ASN.mmdb"}, + cityDB: {"GeoIP2-City.mmdb", "GeoLite2-City.mmdb"}, } } @@ -64,14 +74,15 @@ func FindDB() string { } func (g *GeoIP2) open(t geoType, db string) (*geoip2.Reader, error) { - fileName := filepath.Join(g.dir, db) + var fi fs.FileInfo if len(db) == 0 { found := false for _, f := range dbFiles[t] { + var err error fileName = filepath.Join(g.dir, f) - if _, err := os.Stat(fileName); err == nil { + if fi, err = os.Stat(fileName); err == nil { found = true break } @@ -90,11 +101,17 @@ func (g *GeoIP2) open(t geoType, db string) (*geoip2.Reader, error) { switch t { case countryDB: - g.country = n + g.country.db = n + g.country.lastModified = fi.ModTime().UTC().Unix() + g.country.fp = fileName case cityDB: - g.city = n + g.city.db = n + g.city.lastModified = fi.ModTime().UTC().Unix() + g.city.fp = fileName case asnDB: - g.asn = n + g.asn.db = n + g.asn.lastModified = fi.ModTime().UTC().Unix() + g.asn.fp = fileName } return n, nil } @@ -106,11 +123,11 @@ func (g *GeoIP2) get(t geoType, db string) (*geoip2.Reader, error) { switch t { case countryDB: - r = g.country + r = g.country.db case cityDB: - r = g.city + r = g.city.db case asnDB: - r = g.asn + r = g.asn.db } // unlock so the g.open() call below won't lock @@ -123,6 +140,50 @@ func (g *GeoIP2) get(t geoType, db string) (*geoip2.Reader, error) { return g.open(t, db) } +func (g *GeoIP2) watchFiles() { + // Not worried about goroutines leaking because only one geoip2.New call is made in geodns.go (outside of testing) + ticker := time.NewTicker(1 * time.Minute) + go func() { + for { + select { + case <-ticker.C: + g.checkForUpdate() + default: + time.Sleep(12 * time.Second) // Sleep to avoid constant looping + } + } + }() +} + +func (g *GeoIP2) checkForUpdate() { + // Iterate through each file, check modtime. If new, reload file + d := []*geodb{&g.country, &g.city, &g.asn} // Slice of pointers is kinda gross, but want to directly reference struct values (per const type) + for _, v := range d { + fi, err := os.Stat(v.fp) + if err != nil { + log.Printf("unable to stat DB file at %s :: %v", v.fp, err) + continue + } + if fi.ModTime().UTC().Unix() > v.lastModified { + g.mu.Lock() + e := v.db.Close() + if e != nil { + g.mu.Unlock() + log.Printf("unable to close DB file %s : %v", v.fp, e) + continue + } + n, e := geoip2.Open(v.fp) + if e != nil { + g.mu.Unlock() + log.Printf("unable to reopen DB file %s : %v", v.fp, e) + continue + } + v.db = n + g.mu.Unlock() + } + } +} + // New returns a new GeoIP2 provider func New(dir string) (*GeoIP2, error) { g := &GeoIP2{ @@ -133,6 +194,8 @@ func New(dir string) (*GeoIP2, error) { return nil, err } + go g.watchFiles() // Launch goroutine to monitor + return g, nil } @@ -205,7 +268,7 @@ func (g *GeoIP2) HasLocation() (bool, error) { // GetLocation returns a geo.Location object for the given IP func (g *GeoIP2) GetLocation(ip net.IP) (l *geo.Location, err error) { - c, err := g.city.City(ip) + c, err := g.city.db.City(ip) if err != nil { log.Printf("Could not lookup CountryRegion for '%s': %s", ip.String(), err) return @@ -229,5 +292,4 @@ func (g *GeoIP2) GetLocation(ip net.IP) (l *geo.Location, err error) { } return - } diff --git a/targeting/targeting_test.go b/targeting/targeting_test.go index b6f109a3..c2da5aee 100644 --- a/targeting/targeting_test.go +++ b/targeting/targeting_test.go @@ -31,9 +31,9 @@ func TestTargetParse(t *testing.T) { } tests := [][]string{ - []string{"@ continent country asn", "@ continent country asn"}, - []string{"asn country", "country asn"}, - []string{"continent @ country", "@ continent country"}, + {"@ continent country asn", "@ continent country asn"}, + {"asn country", "country asn"}, + {"continent @ country", "@ continent country"}, } for _, strs := range tests { @@ -102,7 +102,8 @@ func TestGetTargets(t *testing.T) { if ok, _ := g.HasASN(); ok { tests = append(tests, - test{"@ continent regiongroup country region asn ip", + test{ + "@ continent regiongroup country region asn ip", []string{"[98.248.0.1]", "[98.248.0.0]", "as7922", "us-ca", "us-west", "us", "north-america", "@"}, "98.248.0.1", },