Skip to content

Commit

Permalink
实现分布式节点缓存,已进行测试
Browse files Browse the repository at this point in the history
  • Loading branch information
peanutzhen committed Jul 30, 2021
1 parent 2c0117b commit 7e8b84c
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 15 deletions.
51 changes: 51 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2021 Peanutzhen. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.

package peanutcache

import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
)

// client 模块实现peanutcache访问其他远程节点
// 从而获取缓存的能力

type Client struct {
reqURL string
}

// Fetch 从remote peer获取对应缓存值
func (c *Client) Fetch(group string, key string) ([]byte, error) {
// 构造请求url
u := fmt.Sprintf(
"%s%s/%s",
c.reqURL,
url.QueryEscape(group),
url.QueryEscape(key),
)

resp, err := http.Get(u)
if err != nil {
return nil, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("peer Statuscode: %d", resp.StatusCode)
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body failed, %v", err)
}
return body, nil
}

func NewClient(reqURL string) *Client {
return &Client{reqURL: reqURL}
}

var _ Fetcher = (*Client)(nil)
22 changes: 22 additions & 0 deletions peanutcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Group struct {
name string
cache *cache
retriever Retriever
server Picker
}

// NewGroup 创建一个新的缓存空间
Expand All @@ -54,6 +55,14 @@ func NewGroup(name string, maxBytes int64, retriever Retriever) *Group {
return g
}

// RegisterSvr 为 Group 注册 Server
func (g *Group) RegisterSvr(p Picker) {
if g.server != nil {
panic("group had been registered server")
}
g.server = p
}

// GetGroup 获取对应命名空间的缓存
func GetGroup(name string) *Group {
mu.RLock()
Expand All @@ -70,13 +79,25 @@ func (g *Group) Get(key string) (ByteView, error) {
log.Println("cache hit")
return value, nil
}
// cache missing, get it another way
return g.load(key)
}

func (g *Group) load(key string) (ByteView, error) {
if g.server != nil {
if fetcher, ok := g.server.Pick(key); ok {
bytes, err := fetcher.Fetch(g.name, key)
if err == nil {
return ByteView{b: cloneBytes(bytes)}, nil
}
log.Printf("fail to get *%s* from peer, %s.\n", key, err.Error())
return ByteView{}, err
}
}
return g.getLocally(key)
}

// getLocally 本地向Retriever取回数据并填充缓存
func (g *Group) getLocally(key string) (ByteView, error) {
bytes, err := g.retriever.retrieve(key)
if err != nil {
Expand All @@ -87,6 +108,7 @@ func (g *Group) getLocally(key string) (ByteView, error) {
return value, nil
}

// populateCache 提供填充缓存的能力
func (g *Group) populateCache(key string, value ByteView) {
g.cache.add(key, value)
}
18 changes: 18 additions & 0 deletions peers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Copyright 2021 Peanutzhen. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.

package peanutcache

// peers 模块

// Picker 定义了获取分布式节点的能力
type Picker interface {
Pick(key string) (Fetcher, bool)
}

// Fetcher 定义了从远端获取缓存的能力
// 所以每个Peer应实现这个接口
type Fetcher interface {
Fetch(group string, key string) ([]byte, error)
}
68 changes: 62 additions & 6 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,41 @@ import (
"fmt"
"log"
"net/http"
"net/url"
"strings"
"sync"

"github.com/peanutzhen/peanutcache/consistenthash"
)

// server 模块为peanutcache提供http通信能力
// server 模块为peanutcache之间提供http通信能力
// 这样部署在其他机器上的cache可以通过http访问获取缓存
// 至于找哪台主机 那是一致性哈希的工作了
// 注意: peer间通信采用http协议

const defaultBasePath = "/_pcache/"
const (
defaultBasePath = "/_pcache/"
defaultReplicas = 50
)

// Server 和 Group 是解耦合的 所以Server要自己实现并发控制
type Server struct {
addr string
addr string // format: ip:port
basePath string

mu sync.Mutex
consHash *consistenthash.Consistency
clients map[string]*Client
}

func NewServer(addr string) *Server {
func NewServer(addr string) (*Server, error) {
if len(strings.Split(addr, ":")) != 2 {
return nil, fmt.Errorf("server addr format-> ip:port")
}
return &Server{
addr: addr,
basePath: defaultBasePath,
}
}, nil
}

func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -76,4 +92,44 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if err != nil {
panic("ByteView write to response failed")
}
}
}

// SetPeers 将各个远端主机IP配置到Server里
// 这样Server就可以Pick他们了
// 注意: 此操作是*覆写*操作!
// 注意: peersIP必须满足 http:https://x.x.x.x:port的格式
func (s *Server) SetPeers(peersURL ...string) {
s.mu.Lock()
defer s.mu.Unlock()
s.consHash = consistenthash.New(defaultReplicas, nil)
s.consHash.Register(peersURL...)
s.clients = make(map[string]*Client)
for _, peerURL := range peersURL {
if !validPeerURL(peerURL) {
panic(fmt.Sprintf("[peer %s] using not a http protocol or containing a path.", peerURL))
}
s.clients[peerURL] = NewClient(peerURL + defaultBasePath)
}
}

// Pick 根据一致性哈希选举出key应存放在的cache
// return false 代表从本地获取cache
func (s *Server) Pick(key string) (Fetcher, bool) {
s.mu.Lock()
defer s.mu.Unlock()

peerURL := s.consHash.GetPeer(key)
u, err := url.Parse(peerURL)
if err != nil {
return nil, false
}
// Pick itself
if u.Host == s.addr {
log.Printf("ooh! pick myself, I am %s\n", s.addr)
return nil, false
}
log.Printf("Pick remote peer: %s\n", peerURL)
return s.clients[peerURL], true
}

var _ Picker = (*Server)(nil)
18 changes: 9 additions & 9 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"testing"
)

func createTestSvr() *httptest.Server{
func createTestSvr() *httptest.Server {
mysql := map[string]string{
"Tom": "630",
"Jack": "589",
Expand All @@ -30,18 +30,18 @@ func createTestSvr() *httptest.Server{
return nil, fmt.Errorf("%s not exist", key)
}))

svr := NewServer("")
svr, _ := NewServer("localhost:9999")
ts := httptest.NewServer(svr)
svr.addr = ts.URL
return ts
}

func TestServer_GetExistsKey(t *testing.T) {
ts := createTestSvr()
res, _ := http.Get(fmt.Sprintf("%s%sscores/Tom", ts.URL,defaultBasePath))
res, _ := http.Get(fmt.Sprintf("%s%sscores/Tom", ts.URL, defaultBasePath))
body, _ := ioutil.ReadAll(res.Body)
if !reflect.DeepEqual(string(body), "630") {
t.Errorf("Tom %s(actual)/%s(ok)", string(body),"630")
t.Errorf("Tom %s(actual)/%s(ok)", string(body), "630")
}
res.Body.Close()
ts.Close()
Expand All @@ -54,7 +54,7 @@ func TestServer_GetBadPath(t *testing.T) {
t.Errorf("Status code should be 500.")
}
body, _ := ioutil.ReadAll(res.Body)
t.Log("错误basePath查询Tom返回: "+string(body))
t.Log("错误basePath查询Tom返回: " + string(body))
res.Body.Close()
ts.Close()
}
Expand All @@ -66,7 +66,7 @@ func TestServer_GetUnknownKey(t *testing.T) {
t.Errorf("Status code should be 500.")
}
body, _ := ioutil.ReadAll(res.Body)
t.Log("正确规则查询不存在key: "+string(body))
t.Log("正确规则查询不存在key: " + string(body))
res.Body.Close()
}

Expand All @@ -77,7 +77,7 @@ func TestServer_GetUnknownGroup(t *testing.T) {
t.Errorf("Status code should be 500.")
}
body, _ := ioutil.ReadAll(res.Body)
t.Log("正确规则查询不存在group: "+string(body))
t.Log("正确规则查询不存在group: " + string(body))
res.Body.Close()
}

Expand All @@ -88,6 +88,6 @@ func TestServer_GetNoKey(t *testing.T) {
t.Errorf("Status code should be 500.")
}
body, _ := ioutil.ReadAll(res.Body)
t.Log("错误规则不填key返回: "+string(body))
t.Log("错误规则不填key返回: " + string(body))
res.Body.Close()
}
}
13 changes: 13 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"runtime"
"strings"
"net/url"
)

// 显示错误时运行堆栈
Expand All @@ -25,3 +26,15 @@ func trace(errorMessage string) string {
}
return str.String()
}

// 判断是否满足http:https://x.x.x.x:port的格式
func validPeerURL(URL string) bool {
u, err := url.Parse(URL)
if err != nil {
return false
}
if u.Scheme != "http" || len(u.Path) != 0 {
return false
}
return true
}

0 comments on commit 7e8b84c

Please sign in to comment.