Skip to content

Commit

Permalink
all: dnsfilter rm config embed
Browse files Browse the repository at this point in the history
  • Loading branch information
Mizzick committed Aug 30, 2023
1 parent a2ca8b5 commit 5aa6212
Show file tree
Hide file tree
Showing 20 changed files with 324 additions and 222 deletions.
7 changes: 3 additions & 4 deletions internal/dnsforward/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,9 +492,9 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
s.serverLock.RLock()
defer s.serverLock.RUnlock()

disabledUntil = s.dnsFilter.ProtectionDisabledUntil
enabled, disabledUntil = s.dnsFilter.ProtectionStatus()
if disabledUntil == nil {
return s.dnsFilter.ProtectionEnabled, nil
return enabled, nil
}

if time.Now().Before(*disabledUntil) {
Expand Down Expand Up @@ -526,8 +526,7 @@ func (s *Server) enableProtectionAfterPause() {
s.serverLock.Lock()
defer s.serverLock.Unlock()

s.dnsFilter.ProtectionEnabled = true
s.dnsFilter.ProtectionDisabledUntil = nil
s.dnsFilter.UpdateProtectionStatus(true, nil)

log.Info("dns: protection is restarted after pause")
}
7 changes: 2 additions & 5 deletions internal/dnsforward/dnsforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,8 @@ func (s *Server) setupLocalResolvers() (err error) {
func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.conf = *conf

err = validateBlockingMode(
s.dnsFilter.BlockingMode,
s.dnsFilter.BlockingIPv4,
s.dnsFilter.BlockingIPv6,
)
mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode()
err = validateBlockingMode(mode, bIPv4, bIPv6)
if err != nil {
return fmt.Errorf("checking blocking mode: %w", err)
}
Expand Down
42 changes: 26 additions & 16 deletions internal/dnsforward/dnsforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ func createTestServer(
})
require.NoError(t, err)

if s.dnsFilter.BlockingMode == "" {
s.dnsFilter.BlockingMode = filtering.BlockingModeDefault
}

err = s.Prepare(&forwardConf)
require.NoError(t, err)

Expand Down Expand Up @@ -178,7 +174,9 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
var keyPem []byte
_, certPem, keyPem = createServerTLSConfig(t)

s = createTestServer(t, &filtering.Config{}, ServerConfig{
s = createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
Expand Down Expand Up @@ -351,9 +349,8 @@ func TestServer_timeout(t *testing.T) {
},
}

s, err := NewServer(DNSCreateParams{DNSFilter: &filtering.DNSFilter{}})
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
require.NoError(t, err)
s.dnsFilter.BlockingMode = filtering.BlockingModeDefault

err = s.Prepare(srvConf)
require.NoError(t, err)
Expand All @@ -362,10 +359,9 @@ func TestServer_timeout(t *testing.T) {
})

t.Run("default", func(t *testing.T) {
s, err := NewServer(DNSCreateParams{DNSFilter: &filtering.DNSFilter{}})
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)})
require.NoError(t, err)

s.dnsFilter.BlockingMode = filtering.BlockingModeDefault
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{
Enabled: false,
}
Expand All @@ -377,7 +373,9 @@ func TestServer_timeout(t *testing.T) {
}

func TestServerWithProtectionDisabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
Expand Down Expand Up @@ -490,6 +488,7 @@ func TestSafeSearch(t *testing.T) {
}

filterConf := &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
ProtectionEnabled: true,
SafeSearchConf: safeSearchConf,
SafeSearchCacheSize: 1000,
Expand Down Expand Up @@ -564,7 +563,9 @@ func TestSafeSearch(t *testing.T) {
}

func TestInvalidRequest(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
Expand Down Expand Up @@ -631,7 +632,9 @@ func TestServerCustomClientUpstream(t *testing.T) {
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) {
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
Expand Down Expand Up @@ -674,7 +677,9 @@ var testIPv4 = map[string][]net.IP{
}

func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t, &filtering.Config{}, ServerConfig{
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
Expand Down Expand Up @@ -789,7 +794,9 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
},
},
}
s := createTestServer(t, &filtering.Config{}, forwardConf, nil)
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{
CName: testCNAMEs,
Expand Down Expand Up @@ -901,8 +908,10 @@ func TestBlockedCustomIP(t *testing.T) {
err = s.Prepare(conf)
assert.Error(t, err)

s.dnsFilter.BlockingIPv4 = netip.AddrFrom4([4]byte{0, 0, 0, 1})
s.dnsFilter.BlockingIPv6 = netip.MustParseAddr("::1")
s.dnsFilter.UpdateBlockingMode(
filtering.BlockingModeCustomIP,
netip.AddrFrom4([4]byte{0, 0, 0, 1}),
netip.MustParseAddr("::1"))

err = s.Prepare(conf)
require.NoError(t, err)
Expand Down Expand Up @@ -980,6 +989,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
ans4, _ := aghtest.HostToIPs(hostname)

filterConf := &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
ProtectionEnabled: true,
SafeBrowsingEnabled: true,
SafeBrowsingChecker: sbChecker,
Expand Down
11 changes: 8 additions & 3 deletions internal/dnsforward/dnsrewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,14 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
}

// Helper functions and entities.
srv := &Server{
dnsFilter: &filtering.DNSFilter{},
}
srv := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
}, nil)

makeQ := func(qtype rules.RRType) (req *dns.Msg) {
return &dns.Msg{
Question: []dns.Question{{
Expand Down
22 changes: 9 additions & 13 deletions internal/dnsforward/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
upstreamFile := s.conf.UpstreamDNSFileName
bootstraps := stringutil.CloneSliceOrEmpty(s.conf.BootstrapDNS)
fallbacks := stringutil.CloneSliceOrEmpty(s.conf.FallbackDNS)
blockingMode := s.dnsFilter.BlockingMode
blockingIPv4 := s.dnsFilter.BlockingIPv4
blockingIPv6 := s.dnsFilter.BlockingIPv6
blockingMode, blockingIPv4, blockingIPv6 := s.dnsFilter.BlockingMode()
ratelimit := s.conf.Ratelimit

customIP := s.conf.EDNSClientSubnet.CustomIP
Expand Down Expand Up @@ -320,11 +318,11 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
defer s.serverLock.Unlock()

if dc.BlockingMode != nil {
s.dnsFilter.BlockingMode = *dc.BlockingMode
if *dc.BlockingMode == filtering.BlockingModeCustomIP {
s.dnsFilter.BlockingIPv4 = dc.BlockingIPv4
s.dnsFilter.BlockingIPv6 = dc.BlockingIPv6
}
s.dnsFilter.UpdateBlockingMode(*dc.BlockingMode, dc.BlockingIPv4, dc.BlockingIPv6)
}

if dc.ProtectionEnabled != nil {
s.dnsFilter.UpdateProtectionEnabled(*dc.ProtectionEnabled)
}

if dc.UpstreamMode != nil {
Expand All @@ -336,7 +334,6 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
s.conf.EDNSClientSubnet.CustomIP = dc.EDNSCSCustomIP
}

setIfNotNil(&s.dnsFilter.ProtectionEnabled, dc.ProtectionEnabled)
setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled)
setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6)

Expand Down Expand Up @@ -690,8 +687,8 @@ func (s *Server) parseUpstreamLine(
}

// dnsFilter can be nil during application update.
if s.dnsFilter != nil && s.dnsFilter.EtcHosts != nil {
recs := s.dnsFilter.EtcHosts.MatchName(extractUpstreamHost(upstreamAddr))
if s.dnsFilter != nil {
recs := s.dnsFilter.EtcHostsRecords(extractUpstreamHost(upstreamAddr))
for _, rec := range recs {
opts.ServerIPAddrs = append(opts.ServerIPAddrs, rec.Addr.AsSlice())
}
Expand Down Expand Up @@ -832,8 +829,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
s.serverLock.Lock()
defer s.serverLock.Unlock()

s.dnsFilter.ProtectionEnabled = protectionReq.Enabled
s.dnsFilter.ProtectionDisabledUntil = disabledUntil
s.dnsFilter.UpdateProtectionStatus(protectionReq.Enabled, disabledUntil)
}()

s.conf.ConfigModified()
Expand Down
6 changes: 3 additions & 3 deletions internal/dnsforward/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
s.sysResolvers = &fakeSystemResolvers{}

defaultConf := s.conf
defaultFilterConf := filterConf

err := s.Start()
assert.NoError(t, err)
Expand Down Expand Up @@ -248,7 +247,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {

t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
s.dnsFilter.Config = *defaultFilterConf
s.dnsFilter.UpdateBlockingMode(filtering.BlockingModeDefault, netip.Addr{}, netip.Addr{})
s.conf = defaultConf
s.conf.Config.EDNSClientSubnet = &EDNSClientSubnet{}
})
Expand Down Expand Up @@ -500,7 +499,8 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
require.NoError(t, err)

srv := createTestServer(t, &filtering.Config{
EtcHosts: hc,
BlockingMode: filtering.BlockingModeDefault,
EtcHosts: hc,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Expand Down
20 changes: 11 additions & 9 deletions internal/dnsforward/msg.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ func (s *Server) genDNSFilterMessage(
req := dctx.Req
qt := req.Question[0].Qtype
if qt != dns.TypeA && qt != dns.TypeAAAA {
if s.dnsFilter.BlockingMode == filtering.BlockingModeNullIP {
m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP {
return s.makeResponse(req)
}

Expand All @@ -59,9 +60,9 @@ func (s *Server) genDNSFilterMessage(

switch res.Reason {
case filtering.FilteredSafeBrowsing:
return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost, dctx)
return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
case filtering.FilteredParental:
return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost, dctx)
return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx)
case filtering.FilteredSafeSearch:
// If Safe Search generated the necessary IP addresses, use them.
// Otherwise, if there were no errors, there are no addresses for the
Expand All @@ -76,13 +77,13 @@ func (s *Server) genDNSFilterMessage(
// blocking mode.
func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
qt := req.Question[0].Qtype
switch m := s.dnsFilter.BlockingMode; m {
switch m, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); m {
case filtering.BlockingModeCustomIP:
switch qt {
case dns.TypeA:
return s.genARecord(req, s.dnsFilter.BlockingIPv4)
return s.genARecord(req, bIPv4)
case dns.TypeAAAA:
return s.genAAAARecord(req, s.dnsFilter.BlockingIPv6)
return s.genAAAARecord(req, bIPv6)
default:
// Generally shouldn't happen, since the types are checked in
// genDNSFilterMessage.
Expand All @@ -103,7 +104,8 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req)
default:
log.Error("dns: invalid blocking mode %q", s.dnsFilter.BlockingMode)
mode, _, _ := s.dnsFilter.BlockingMode()
log.Error("dns: invalid blocking mode %q", mode)

return s.makeResponse(req)
}
Expand Down Expand Up @@ -132,7 +134,7 @@ func (s *Server) hdr(req *dns.Msg, rrType rules.RRType) (h dns.RR_Header) {
return dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: rrType,
Ttl: s.dnsFilter.BlockedResponseTTL,
Ttl: s.dnsFilter.BlockedResponseTTL(),
Class: dns.ClassINET,
}
}
Expand Down Expand Up @@ -352,7 +354,7 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
Hdr: dns.RR_Header{
Name: zone,
Rrtype: dns.TypeSOA,
Ttl: s.dnsFilter.BlockedResponseTTL,
Ttl: s.dnsFilter.BlockedResponseTTL(),
Class: dns.ClassINET,
},
Mbox: "hostmaster.", // zone will be appended later if it's not empty or "."
Expand Down
2 changes: 1 addition & 1 deletion internal/dnsforward/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
Rrtype: dns.TypePTR,
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
// https://github.com/AdguardTeam/AdGuardHome/issues/3932.
Ttl: s.dnsFilter.BlockedResponseTTL,
Ttl: s.dnsFilter.BlockedResponseTTL(),
Class: dns.ClassINET,
},
Ptr: dns.Fqdn(strings.Join([]string{host, s.localDomainSuffix}, ".")),
Expand Down
Loading

0 comments on commit 5aa6212

Please sign in to comment.