diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go index c2a959d1ad..483d82c76e 100644 --- a/app/dns/dohdns.go +++ b/app/dns/dohdns.go @@ -12,6 +12,7 @@ import ( "sync" "sync/atomic" "time" + dns_feature "v2ray.com/core/features/dns" "golang.org/x/net/dns/dnsmessage" "v2ray.com/core/common" @@ -213,9 +214,13 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { if updated { s.ips[req.domain] = rec - s.pub.Publish(req.domain, nil) } - + switch req.reqType { + case dnsmessage.TypeA: + s.pub.Publish(req.domain+"4", nil) + case dnsmessage.TypeAAAA: + s.pub.Publish(req.domain+"6", nil) + } s.Unlock() common.Must(s.cleanup.Start()) } @@ -336,12 +341,15 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net. return nil, lastErr } + if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { + return nil, dns_feature.ErrEmptyResponse + } + return nil, errRecordNotFound } // QueryIP is called from dns.Server->queryIPTimeout func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { - fqdn := Fqdn(domain) ips, err := s.findIPsForDomain(fqdn, option) @@ -350,9 +358,32 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOpt return ips, err } - sub := s.pub.Subscribe(fqdn) - defer sub.Close() - + // ipv4 and ipv6 belong to different subscription groups + var sub4, sub6 *pubsub.Subscriber + if option.IPv4Enable { + sub4 = s.pub.Subscribe(fqdn + "4") + defer sub4.Close() + } + if option.IPv6Enable { + sub6 = s.pub.Subscribe(fqdn + "6") + defer sub6.Close() + } + done := make(chan interface{}) + go func() { + if sub4 != nil { + select { + case <-sub4.Wait(): + case <-ctx.Done(): + } + } + if sub6 != nil { + select { + case <-sub6.Wait(): + case <-ctx.Done(): + } + } + close(done) + }() s.sendQuery(ctx, fqdn, option) for { @@ -364,7 +395,7 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOpt select { case <-ctx.Done(): return nil, ctx.Err() - case <-sub.Wait(): + case <-done: } } } diff --git a/app/dns/udpns.go b/app/dns/udpns.go index f5f527e09a..70148fdec9 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -158,9 +158,13 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) { if updated { s.ips[domain] = rec - s.pub.Publish(domain, nil) } - + if newRec.A != nil { + s.pub.Publish(domain+"4", nil) + } + if newRec.AAAA != nil { + s.pub.Publish(domain+"6", nil) + } s.Unlock() common.Must(s.cleanup.Start()) } @@ -245,9 +249,32 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I return ips, err } - sub := s.pub.Subscribe(fqdn) - defer sub.Close() - + // ipv4 and ipv6 belong to different subscription groups + var sub4, sub6 *pubsub.Subscriber + if option.IPv4Enable { + sub4 = s.pub.Subscribe(fqdn + "4") + defer sub4.Close() + } + if option.IPv6Enable { + sub6 = s.pub.Subscribe(fqdn + "6") + defer sub6.Close() + } + done := make(chan interface{}) + go func() { + if sub4 != nil { + select { + case <-sub4.Wait(): + case <-ctx.Done(): + } + } + if sub6 != nil { + select { + case <-sub6.Wait(): + case <-ctx.Done(): + } + } + close(done) + }() s.sendQuery(ctx, fqdn, option) for { @@ -259,7 +286,7 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I select { case <-ctx.Done(): return nil, ctx.Err() - case <-sub.Wait(): + case <-done: } } }