diff --git a/.gitignore b/.gitignore index 07d4fbe94f..5297469b9e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ *.DS_Store bazel-* +.idea \ No newline at end of file diff --git a/app/dns/dohdns.go b/app/dns/dohdns.go index 61237ba8cc..483d82c76e 100644 --- a/app/dns/dohdns.go +++ b/app/dns/dohdns.go @@ -12,6 +12,7 @@ import ( "sync" "sync/atomic" "time" + dns_feature "v2ray.com/core/features/dns" "golang.org/x/net/dns/dnsmessage" "v2ray.com/core/common" @@ -185,7 +186,6 @@ func (s *DoHNameServer) Cleanup() error { func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { elapsed := time.Since(req.start) - newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() s.Lock() rec := s.ips[req.domain] @@ -198,17 +198,29 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { updated = true } case dnsmessage.TypeAAAA: + addr := make([]net.Address, 0) + for _, ip := range ipRec.IP { + if len(ip.IP()) == net.IPv6len { + addr = append(addr, ip) + } + } + ipRec.IP = addr if isNewer(rec.AAAA, ipRec) { rec.AAAA = ipRec updated = true } } + newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() if updated { s.ips[req.domain] = rec - s.pub.Publish(req.domain, nil) } - + switch req.reqType { + case dnsmessage.TypeA: + s.pub.Publish(req.domain+"4", nil) + case dnsmessage.TypeAAAA: + s.pub.Publish(req.domain+"6", nil) + } s.Unlock() common.Must(s.cleanup.Start()) } @@ -329,12 +341,15 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net. return nil, lastErr } + if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) { + return nil, dns_feature.ErrEmptyResponse + } + return nil, errRecordNotFound } // QueryIP is called from dns.Server->queryIPTimeout func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) { - fqdn := Fqdn(domain) ips, err := s.findIPsForDomain(fqdn, option) @@ -343,9 +358,32 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOpt return ips, err } - sub := s.pub.Subscribe(fqdn) - defer sub.Close() - + // ipv4 and ipv6 belong to different subscription groups + var sub4, sub6 *pubsub.Subscriber + if option.IPv4Enable { + sub4 = s.pub.Subscribe(fqdn + "4") + defer sub4.Close() + } + if option.IPv6Enable { + sub6 = s.pub.Subscribe(fqdn + "6") + defer sub6.Close() + } + done := make(chan interface{}) + go func() { + if sub4 != nil { + select { + case <-sub4.Wait(): + case <-ctx.Done(): + } + } + if sub6 != nil { + select { + case <-sub6.Wait(): + case <-ctx.Done(): + } + } + close(done) + }() s.sendQuery(ctx, fqdn, option) for { @@ -357,7 +395,7 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOpt select { case <-ctx.Done(): return nil, ctx.Err() - case <-sub.Wait(): + case <-done: } } } diff --git a/app/dns/udpns.go b/app/dns/udpns.go index f5f527e09a..70148fdec9 100644 --- a/app/dns/udpns.go +++ b/app/dns/udpns.go @@ -158,9 +158,13 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) { if updated { s.ips[domain] = rec - s.pub.Publish(domain, nil) } - + if newRec.A != nil { + s.pub.Publish(domain+"4", nil) + } + if newRec.AAAA != nil { + s.pub.Publish(domain+"6", nil) + } s.Unlock() common.Must(s.cleanup.Start()) } @@ -245,9 +249,32 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I return ips, err } - sub := s.pub.Subscribe(fqdn) - defer sub.Close() - + // ipv4 and ipv6 belong to different subscription groups + var sub4, sub6 *pubsub.Subscriber + if option.IPv4Enable { + sub4 = s.pub.Subscribe(fqdn + "4") + defer sub4.Close() + } + if option.IPv6Enable { + sub6 = s.pub.Subscribe(fqdn + "6") + defer sub6.Close() + } + done := make(chan interface{}) + go func() { + if sub4 != nil { + select { + case <-sub4.Wait(): + case <-ctx.Done(): + } + } + if sub6 != nil { + select { + case <-sub6.Wait(): + case <-ctx.Done(): + } + } + close(done) + }() s.sendQuery(ctx, fqdn, option) for { @@ -259,7 +286,7 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I select { case <-ctx.Done(): return nil, ctx.Err() - case <-sub.Wait(): + case <-done: } } }