Skip to content

Commit

Permalink
feat: add url deny handler (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomMoulard authored May 4, 2024
1 parent 0b7e597 commit cec5182
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ lint:
golangci-lint run

.PHONY: test
TEST_ARGS ?= -v -cover -race -tags DEBUG
TEST_ARGS ?= -v -cover -race -tags DEBUG,TEST
test:
go test ${TEST_ARGS} ./...

Expand Down
93 changes: 53 additions & 40 deletions fail2ban.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@ import (
"github.com/tomMoulard/fail2ban/pkg/chain"
"github.com/tomMoulard/fail2ban/pkg/data"
"github.com/tomMoulard/fail2ban/pkg/files"
"github.com/tomMoulard/fail2ban/pkg/ipchecking"
lAllow "github.com/tomMoulard/fail2ban/pkg/list/allow"
lDeny "github.com/tomMoulard/fail2ban/pkg/list/deny"
logger "github.com/tomMoulard/fail2ban/pkg/log"
uAllow "github.com/tomMoulard/fail2ban/pkg/url/allow"
uDeny "github.com/tomMoulard/fail2ban/pkg/url/deny"
)

func init() {
log.SetOutput(os.Stdout)
}

// IPViewed struct.
type IPViewed struct {
viewed time.Time
nb int
denied bool
}

// Urlregexp struct.
type Urlregexp struct {
Regexp string `yaml:"regexp"`
Expand Down Expand Up @@ -149,7 +144,7 @@ type Fail2Ban struct {
rules RulesTransformed

muIP sync.Mutex
ipViewed map[string]IPViewed
ipViewed map[string]ipchecking.IPViewed
}

// ImportIP extract all ip from config sources.
Expand Down Expand Up @@ -233,17 +228,22 @@ func New(_ context.Context, next http.Handler, config *Config, name string) (htt

log.Println("Plugin: FailToBan is up and running")

f2b := &Fail2Ban{
next: next,
name: name,
rules: rules,
ipViewed: make(map[string]ipchecking.IPViewed),
}

urlDeny := uDeny.New(rules.URLRegexpBan, &f2b.muIP, &f2b.ipViewed)

return chain.New(
next,
denyHandler,
allowHandler,
urlDeny,
urlAllow,
&Fail2Ban{
next: next,
name: name,
rules: rules,
ipViewed: make(map[string]IPViewed),
},
f2b,
), nil
}

Expand All @@ -255,73 +255,86 @@ func (u *Fail2Ban) ServeHTTP(rw http.ResponseWriter, req *http.Request) (*chain.
return nil, errors.New("failed to get data from request context")
}

if !u.shouldAllow(data.RemoteIP, req.URL.String()) {
if !u.shouldAllow(data.RemoteIP) {
return &chain.Status{Return: true}, nil
}

return nil, nil
}

// shouldAllow check if the request should be allowed.
func (u *Fail2Ban) shouldAllow(remoteIP, reqURL string) bool {
// Urlregexp ban
func (u *Fail2Ban) shouldAllow(remoteIP string) bool {
u.muIP.Lock()
defer u.muIP.Unlock()

ip, foundIP := u.ipViewed[remoteIP]
urlBytes := []byte(reqURL)

for _, reg := range u.rules.URLRegexpBan {
if reg.Match(urlBytes) {
u.ipViewed[remoteIP] = IPViewed{time.Now(), ip.nb + 1, true}

LoggerDEBUG.Printf("Url (%q) was matched by regexpBan: %q for %q", reqURL, reg.String(), remoteIP)

return false
}
}

// Fail2Ban
if !foundIP {
u.ipViewed[remoteIP] = IPViewed{time.Now(), 1, false}
u.ipViewed[remoteIP] = ipchecking.IPViewed{
Viewed: time.Now(),
Count: 1,
}

LoggerDEBUG.Printf("welcome %q", remoteIP)

return true
}

if ip.denied {
if time.Now().Before(ip.viewed.Add(u.rules.Bantime)) {
u.ipViewed[remoteIP] = IPViewed{ip.viewed, ip.nb + 1, true}
if ip.Denied {
if time.Now().Before(ip.Viewed.Add(u.rules.Bantime)) {
u.ipViewed[remoteIP] = ipchecking.IPViewed{
Viewed: ip.Viewed,
Count: ip.Count + 1,
Denied: true,
}

LoggerDEBUG.Printf("%q is still banned since %q, %d request",
remoteIP, ip.viewed.Format(time.RFC3339), ip.nb+1)
remoteIP, ip.Viewed.Format(time.RFC3339), ip.Count+1)

return false
}

u.ipViewed[remoteIP] = IPViewed{time.Now(), 1, false}
u.ipViewed[remoteIP] = ipchecking.IPViewed{
Viewed: time.Now(),
Count: 1,
Denied: false,
}

LoggerDEBUG.Println(remoteIP + " is no longer banned")

return true
}

if time.Now().Before(ip.viewed.Add(u.rules.Findtime)) {
if ip.nb+1 >= u.rules.MaxRetry {
u.ipViewed[remoteIP] = IPViewed{time.Now(), ip.nb + 1, true}
if time.Now().Before(ip.Viewed.Add(u.rules.Findtime)) {
if ip.Count+1 >= u.rules.MaxRetry {
u.ipViewed[remoteIP] = ipchecking.IPViewed{
Viewed: time.Now(),
Count: ip.Count + 1,
Denied: true,
}

LoggerDEBUG.Println(remoteIP + " is now banned temporarily")

return false
}

u.ipViewed[remoteIP] = IPViewed{ip.viewed, ip.nb + 1, false}
LoggerDEBUG.Printf("welcome back %q for the %d time", remoteIP, ip.nb+1)
u.ipViewed[remoteIP] = ipchecking.IPViewed{
Viewed: ip.Viewed,
Count: ip.Count + 1,
Denied: false,
}

LoggerDEBUG.Printf("welcome back %q for the %d time", remoteIP, ip.Count+1)

return true
}

u.ipViewed[remoteIP] = IPViewed{time.Now(), 1, false}
u.ipViewed[remoteIP] = ipchecking.IPViewed{
Viewed: time.Now(),
Count: 1,
Denied: false,
}

LoggerDEBUG.Printf("welcome back %q", remoteIP)

Expand Down
52 changes: 20 additions & 32 deletions fail2ban_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ import (
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"sync/atomic"
"testing"
"time"

"github.com/tomMoulard/fail2ban/pkg/ipchecking"
"golang.org/x/net/websocket"
)

Expand Down Expand Up @@ -373,23 +373,22 @@ func TestShouldAllow(t *testing.T) {
name string
cfg *Fail2Ban
remoteIP string
reqURL string
expect bool
}{
{
name: "first request",
cfg: &Fail2Ban{
ipViewed: map[string]IPViewed{},
ipViewed: map[string]ipchecking.IPViewed{},
},
expect: true,
},
{
name: "second request",
cfg: &Fail2Ban{
ipViewed: map[string]IPViewed{
ipViewed: map[string]ipchecking.IPViewed{
"10.0.0.0": {
viewed: time.Now(),
nb: 1,
Viewed: time.Now(),
Count: 1,
},
},
},
Expand All @@ -402,11 +401,11 @@ func TestShouldAllow(t *testing.T) {
rules: RulesTransformed{
Bantime: 300 * time.Second,
},
ipViewed: map[string]IPViewed{
ipViewed: map[string]ipchecking.IPViewed{
"10.0.0.0": {
viewed: time.Now(),
nb: 1,
denied: true,
Viewed: time.Now(),
Count: 1,
Denied: true,
},
},
},
Expand All @@ -419,11 +418,11 @@ func TestShouldAllow(t *testing.T) {
rules: RulesTransformed{
Bantime: 300 * time.Second,
},
ipViewed: map[string]IPViewed{
ipViewed: map[string]ipchecking.IPViewed{
"10.0.0.0": {
viewed: time.Now().Add(-600 * time.Second),
nb: 1,
denied: true,
Viewed: time.Now().Add(-600 * time.Second),
Count: 1,
Denied: true,
},
},
},
Expand All @@ -437,10 +436,10 @@ func TestShouldAllow(t *testing.T) {
MaxRetry: 1,
Findtime: 300 * time.Second,
},
ipViewed: map[string]IPViewed{
ipViewed: map[string]ipchecking.IPViewed{
"10.0.0.0": {
viewed: time.Now().Add(600 * time.Second),
nb: 1,
Viewed: time.Now().Add(600 * time.Second),
Count: 1,
},
},
},
Expand All @@ -454,34 +453,23 @@ func TestShouldAllow(t *testing.T) {
MaxRetry: 3,
Findtime: 300 * time.Second,
},
ipViewed: map[string]IPViewed{
ipViewed: map[string]ipchecking.IPViewed{
"10.0.0.0": {
viewed: time.Now().Add(600 * time.Second),
nb: 1,
Viewed: time.Now().Add(600 * time.Second),
Count: 1,
},
},
},
remoteIP: "10.0.0.0",
expect: true,
},
{
name: "block regexp",
cfg: &Fail2Ban{
rules: RulesTransformed{
URLRegexpBan: []*regexp.Regexp{regexp.MustCompile("/test")}, // comment me.
},
ipViewed: map[string]IPViewed{},
},
reqURL: "/test",
expect: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

got := test.cfg.shouldAllow(test.remoteIP, test.reqURL)
got := test.cfg.shouldAllow(test.remoteIP)
if test.expect != got {
t.Errorf("wanted '%t' got '%t'", test.expect, got)
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/ipchecking/ipChecking.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@ import (
"log"
"net/netip"
"strings"
"time"
)

// IPViewed struct.
type IPViewed struct {
Viewed time.Time
Count int
Denied bool
}

// NetIP struct that holds an NetIP IP address, and a IP network.
// If the network is nil, the NetIP is a single IP.
type NetIP struct {
Expand Down
4 changes: 2 additions & 2 deletions pkg/list/deny/deny.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ func New(ipList []string) (*deny, error) {
return &deny{list: list}, nil
}

func (a *deny) ServeHTTP(w http.ResponseWriter, r *http.Request) (*chain.Status, error) {
func (d *deny) ServeHTTP(w http.ResponseWriter, r *http.Request) (*chain.Status, error) {
data := data.GetData(r)
if data == nil {
return nil, errors.New("failed to get data from request context")
}

l.Printf("data: %+v", data)

if a.list.Contains(data.RemoteIP) {
if d.list.Contains(data.RemoteIP) {
l.Printf("IP %s is denied", data.RemoteIP)

return &chain.Status{Return: true}, nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/list/deny/deny_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ func TestDeny(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()

a, err := New(test.ipList)
d, err := New(test.ipList)
require.NoError(t, err)

recorder := &httptest.ResponseRecorder{}
req := httptest.NewRequest(http.MethodGet, "https://example.com/foo", nil)
req, err = data.ServeHTTP(recorder, req)
require.NoError(t, err)

got, err := a.ServeHTTP(recorder, req)
got, err := d.ServeHTTP(recorder, req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatus, got)
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/url/allow/allow.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package allow is a middleware
// Package allow is a middleware that force allows requests from a list of regexps.
package allow

import (
Expand Down
Loading

0 comments on commit cec5182

Please sign in to comment.