Skip to content

Commit

Permalink
new connection store and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
andyollylarkin committed Nov 22, 2023
1 parent af11f97 commit 96ae108
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 57 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.20
require (
github.com/gorilla/mux v1.8.1
github.com/gorilla/websocket v1.5.1
github.com/hashicorp/golang-lru v1.0.2
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
)
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c=
github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
Expand Down
75 changes: 75 additions & 0 deletions transport/ws_transport/internal/conn_store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package internal

import (
"fmt"
"net"
"sync"
)

type ConnectionStore struct {
conns map[string]*WsConnAdapter
mu sync.RWMutex
}

func NewConnectionStore() *ConnectionStore {
return &ConnectionStore{
conns: make(map[string]*WsConnAdapter),
}
}

// ConnCacheSet store connection in cache.
func (cs *ConnectionStore) ConnCacheSet(addr net.Addr, conn *WsConnAdapter) error {
h, err := extractIpFromAddr(addr)
if err != nil {
return fmt.Errorf("cant set conn cache for %s, %w", addr.String(), err)
}

cs.mu.Lock()
defer cs.mu.Unlock()

cs.conns[h] = conn

return nil
}

// ConnCacheRemove remove connection from cache.
func (cs *ConnectionStore) ConnCacheRemove(addr net.Addr) bool {
h, err := extractIpFromAddr(addr)
if err != nil {
return false
}

cs.mu.Lock()
defer cs.mu.Unlock()

delete(cs.conns, h)

return true
}

// ConnCacheGet get connection from cache.
func (cs *ConnectionStore) ConnCacheGet(addr net.Addr) (*WsConnAdapter, bool, error) {
h, err := extractIpFromAddr(addr)
if err != nil {
return nil, false, fmt.Errorf("cant get conn for addr %s from cache, %w", addr.String(), err)
}

cs.mu.Lock()
defer cs.mu.Unlock()

conn, ok := cs.conns[h]
if !ok {
return nil, false, nil
}

return conn, true, nil
}

func extractIpFromAddr(addr net.Addr) (string, error) {
h, _, err := net.SplitHostPort(addr.String())
if err != nil {
return "", err
}

return h, nil
}
164 changes: 164 additions & 0 deletions transport/ws_transport/internal/conn_store_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package internal

import (
"errors"
"net"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestConnectionStore_ConnCacheSet(t *testing.T) {
type fields struct {
conns map[string]*WsConnAdapter
mu sync.RWMutex
}
type args struct {
addr net.Addr
conn *WsConnAdapter
}
tests := []struct {
name string
fields fields
expectedAddr string
args args
err error
}{
{
name: "Set ok",
fields: fields{
conns: make(map[string]*WsConnAdapter),
},
args: args{
addr: &WsAddr{
WsAddrTCP: func() net.TCPAddr {
a, _ := net.ResolveTCPAddr("tcp", "192.168.1.1:8888")
return *a
}(),
},
conn: &WsConnAdapter{
realRemoteAddr: func() *net.TCPAddr {
a, _ := net.ResolveTCPAddr("tcp", "192.168.1.1:8888")
return a
}(),
},
},
expectedAddr: "192.168.1.1:8888",
err: nil,
},
{
name: "Set cant parse addr",
fields: fields{
conns: make(map[string]*WsConnAdapter),
},
args: args{
addr: &WsAddr{
WsAddrTCP: func() net.TCPAddr {
a := net.TCPAddr{
IP: net.IPv4(192, 168, 1, 1),
}
return a
}(),
},
conn: &WsConnAdapter{
realRemoteAddr: func() *net.TCPAddr {
a := net.TCPAddr{
IP: net.IPv4(192, 168, 1, 1),
}
return &a
}(),
},
},
expectedAddr: "192.168.1.1:0",
err: errors.New(""),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := &ConnectionStore{
conns: tt.fields.conns,
mu: tt.fields.mu,
}
err := cs.ConnCacheSet(tt.args.addr, tt.args.conn)
if err == nil {
require.NoError(t, err)
} else {
require.Errorf(t, err, err.Error())
}
require.NotNil(t, tt.args.conn)
assert.Equal(t, tt.expectedAddr, tt.args.conn.RemoteAddr().String())
})
}
}

func TestConnectionStore_ConnCacheGet(t *testing.T) {
type fields struct {
conns map[string]*WsConnAdapter
mu sync.RWMutex
}
type args struct {
addr net.Addr
conn *WsConnAdapter
setConnectionToCache bool
}
tests := []struct {
name string
fields fields
args args
want *WsConnAdapter
want1 bool
wantErr bool
}{
{
name: "Get connection ok",
fields: fields{
conns: make(map[string]*WsConnAdapter),
},
args: args{
addr: &WsAddr{
WsAddrTCP: net.TCPAddr{
IP: net.IPv4(192, 168, 1, 1),
},
},
setConnectionToCache: true,
conn: &WsConnAdapter{},
},
want: &WsConnAdapter{},
want1: true,
wantErr: false,
},
{
name: "Get connection when connection not set",
fields: fields{
conns: make(map[string]*WsConnAdapter),
},
args: args{
addr: &WsAddr{
WsAddrTCP: net.TCPAddr{
IP: net.IPv4(192, 168, 1, 1),
},
},
setConnectionToCache: false,
conn: &WsConnAdapter{},
},
want: &WsConnAdapter{},
want1: false,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := &ConnectionStore{
conns: tt.fields.conns,
}
if tt.args.setConnectionToCache {
cs.ConnCacheSet(tt.args.addr, tt.want)
}
_, ok, err := cs.ConnCacheGet(tt.args.addr)
require.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, tt.want1, ok)
})
}
}
62 changes: 8 additions & 54 deletions transport/ws_transport/ws_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/andyollylarkin/smudge-custom-transport/transport"
"github.com/andyollylarkin/smudge-custom-transport/transport/ws_transport/internal"
"github.com/gorilla/websocket"
lru "github.com/hashicorp/golang-lru"
)

const (
Expand All @@ -25,61 +24,16 @@ var (
)

type WsTransport struct {
cache *lru.Cache
cache *internal.ConnectionStore
wg sync.WaitGroup
remoteWsServerPort *int
wsBasePath string
connChan chan *internal.WsConnAdapter
logger smudge.Logger
}

// connCacheSet store connection in LRU cache.
func (wst *WsTransport) connCacheSet(addr net.Addr, conn *internal.WsConnAdapter) (bool, error) {
h, _, err := net.SplitHostPort(addr.String())
if err != nil {
return false, fmt.Errorf("cant set conn cache for %s, %w", addr.String(), err)
}

return wst.cache.Add(h, conn), nil
}

// connCacheRemove remove connection from LRU cache.
func (wst *WsTransport) connCacheRemove(addr net.Addr) bool {
h, _, err := net.SplitHostPort(addr.String())
if err != nil {
return false
}

return wst.cache.Remove(h)
}

// connCacheGet get connection from LRU cache.
func (wst *WsTransport) connCacheGet(addr net.Addr) (*internal.WsConnAdapter, bool, error) {
h, _, err := net.SplitHostPort(addr.String())
if err != nil {
return nil, false, fmt.Errorf("cant get conn for addr %s from cache, %w", addr.String(), err)
}

conn, ok := wst.cache.Get(h)
if !ok {
wst.logger.Logf(smudge.LogDebug, "connection get cache miss for %s", addr.String())

return nil, false, nil
}

wsConn, ok := conn.(*internal.WsConnAdapter)
if !ok {
return nil, false, fmt.Errorf("cant get conn for addr %s from cache. Conn type isn't WsConnAdapter", addr.String())
}

return wsConn, true, nil
}

func NewWsTransport(logger smudge.Logger, remoteWsServerPort *int, wsBasePath string) (*WsTransport, error) {
cache, err := lru.New(MaxLRUCacheItems)
if err != nil {
return nil, fmt.Errorf("cant create connections cache: %w", err)
}
cache := internal.NewConnectionStore()

t := new(WsTransport)

Expand All @@ -100,7 +54,7 @@ func (wst *WsTransport) UpgageWebsocket(w http.ResponseWriter, r *http.Request)
return fmt.Errorf("cant upgrade websocket connection: %w", err)
}

_, ok, err := wst.connCacheGet(wsconn.RemoteAddr())
_, ok, err := wst.cache.ConnCacheGet(wsconn.RemoteAddr())
if err != nil {
return err
}
Expand All @@ -114,7 +68,7 @@ func (wst *WsTransport) UpgageWebsocket(w http.ResponseWriter, r *http.Request)
return err
}

_, err = wst.connCacheSet(adapter.RemoteAddr(), adapter)
err = wst.cache.ConnCacheSet(adapter.RemoteAddr(), adapter)
if err != nil {
return err
}
Expand Down Expand Up @@ -145,14 +99,14 @@ func (wst *WsTransport) Listen(network string, addr transport.SockAddr) (transpo

func (wst *WsTransport) connCloseMonitor(connErrChan chan net.Addr) {
for addr := range connErrChan {
conn, ok, err := wst.connCacheGet(addr)
conn, ok, err := wst.cache.ConnCacheGet(addr)
if err != nil || !ok {
continue
}

conn.ActuallyClose()

wst.connCacheRemove(addr)
wst.cache.ConnCacheRemove(addr)

wst.logger.Logf(smudge.LogDebug, "Actually close %s", conn.RemoteAddr().String())
}
Expand All @@ -161,7 +115,7 @@ func (wst *WsTransport) connCloseMonitor(connErrChan chan net.Addr) {
func (wst *WsTransport) Dial(ctx context.Context, laddr transport.SockAddr,
raddr transport.SockAddr,
) (transport.GenericConn, error) {
c, ok, err := wst.connCacheGet(raddr)
c, ok, err := wst.cache.ConnCacheGet(raddr)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -212,7 +166,7 @@ func (wst *WsTransport) Dial(ctx context.Context, laddr transport.SockAddr,
return nil, err
}

_, err = wst.connCacheSet(adapter.RemoteAddr(), adapter)
err = wst.cache.ConnCacheSet(adapter.RemoteAddr(), adapter)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 96ae108

Please sign in to comment.