From dc3cf9d49553eb748b60786dbc351c94ba682738 Mon Sep 17 00:00:00 2001 From: Burak Sezer Date: Wed, 9 Nov 2022 23:39:25 +0300 Subject: [PATCH] fix: AverageLoad() function panics with "divide by zero" when no members are in the hash ring #19 --- consistent.go | 15 +++++++++++++-- consistent_test.go | 30 ++++++++++++++++++++---------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/consistent.go b/consistent.go index e468a25..77b89f3 100644 --- a/consistent.go +++ b/consistent.go @@ -136,7 +136,7 @@ func New(members []Member, config Config) *Consistent { return c } -// GetMembers returns a thread-safe copy of members. +// GetMembers returns a thread-safe copy of members. If there are no members, it returns an empty slice of Member. func (c *Consistent) GetMembers() []Member { c.mu.RLock() defer c.mu.RUnlock() @@ -151,12 +151,23 @@ func (c *Consistent) GetMembers() []Member { // AverageLoad exposes the current average load. func (c *Consistent) AverageLoad() float64 { + c.mu.RLock() + defer c.mu.RUnlock() + + return c.averageLoad() +} + +func (c *Consistent) averageLoad() float64 { + if len(c.members) == 0 { + return 0 + } + avgLoad := float64(c.partitionCount/uint64(len(c.members))) * c.config.Load return math.Ceil(avgLoad) } func (c *Consistent) distributeWithLoad(partID, idx int, partitions map[int]*Member, loads map[string]float64) { - avgLoad := c.AverageLoad() + avgLoad := c.averageLoad() var count int for { count++ diff --git a/consistent_test.go b/consistent_test.go index bd11f68..d4a7fce 100644 --- a/consistent_test.go +++ b/consistent_test.go @@ -95,22 +95,32 @@ func TestConsistentRemove(t *testing.T) { } func TestConsistentLoad(t *testing.T) { - members := []Member{} + var members []Member for i := 0; i < 8; i++ { member := testMember(fmt.Sprintf("node%d.olric", i)) members = append(members, member) } cfg := newConfig() - c := New(members, cfg) - if len(c.GetMembers()) != len(members) { - t.Fatalf("inserted member count is different") - } - maxLoad := c.AverageLoad() - for member, load := range c.LoadDistribution() { - if load > maxLoad { - t.Fatalf("%s exceeds max load. Its load: %f, max load: %f", member, load, maxLoad) + + t.Run("Average load should be greater than the member's load", func(t *testing.T) { + c := New(members, cfg) + if len(c.GetMembers()) != len(members) { + t.Fatalf("inserted member count is different") } - } + maxLoad := c.AverageLoad() + for member, load := range c.LoadDistribution() { + if load > maxLoad { + t.Fatalf("%s exceeds max load. Its load: %f, max load: %f", member, load, maxLoad) + } + } + }) + + t.Run("Average load should equal to zero if there are no members", func(t *testing.T) { + c := New(nil, cfg) + if c.AverageLoad() != 0 { + t.Fatalf("AverageLoad should equal to zero") + } + }) } func TestConsistentLocateKey(t *testing.T) {