mirror of
				https://github.com/dcarrillo/whatismyip.git
				synced 2025-11-04 04:39:09 +00:00 
			
		
		
		
	New whatismydns feature (#29)
This commit is contained in:
		
							
								
								
									
										165
									
								
								resolver/setup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								resolver/setup.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,165 @@
 | 
			
		||||
package resolver
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/dcarrillo/whatismyip/internal/setting"
 | 
			
		||||
	"github.com/dcarrillo/whatismyip/internal/validator/uuid"
 | 
			
		||||
	"github.com/miekg/dns"
 | 
			
		||||
	"github.com/patrickmn/go-cache"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Resolver struct {
 | 
			
		||||
	handler *dns.ServeMux
 | 
			
		||||
	store   *cache.Cache
 | 
			
		||||
	domain  string
 | 
			
		||||
	rr      []string
 | 
			
		||||
	ipv4    []net.IP
 | 
			
		||||
	ipv6    []net.IP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ensureDotSuffix(s string) string {
 | 
			
		||||
	if !strings.HasSuffix(s, ".") {
 | 
			
		||||
		return s + "."
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Setup(store *cache.Cache) *Resolver {
 | 
			
		||||
	var ipv4, ipv6 []net.IP
 | 
			
		||||
	for _, ip := range setting.App.Resolver.Ipv4 {
 | 
			
		||||
		ipv4 = append(ipv4, net.ParseIP(ip))
 | 
			
		||||
	}
 | 
			
		||||
	for _, ip := range setting.App.Resolver.Ipv6 {
 | 
			
		||||
		ipv6 = append(ipv6, net.ParseIP(ip))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resolver := &Resolver{
 | 
			
		||||
		handler: dns.NewServeMux(),
 | 
			
		||||
		store:   store,
 | 
			
		||||
		domain:  ensureDotSuffix(setting.App.Resolver.Domain),
 | 
			
		||||
		rr:      setting.App.Resolver.ResourceRecords,
 | 
			
		||||
		ipv4:    ipv4,
 | 
			
		||||
		ipv6:    ipv6,
 | 
			
		||||
	}
 | 
			
		||||
	resolver.handler.HandleFunc(resolver.domain, resolver.resolve)
 | 
			
		||||
	resolver.handler.HandleFunc(".", resolver.blackHole)
 | 
			
		||||
 | 
			
		||||
	return resolver
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsv *Resolver) Handler() *dns.ServeMux {
 | 
			
		||||
	return rsv.handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsv *Resolver) blackHole(w dns.ResponseWriter, r *dns.Msg) {
 | 
			
		||||
	msg := startReply(r)
 | 
			
		||||
	msg.SetRcode(r, dns.RcodeRefused)
 | 
			
		||||
	w.WriteMsg(msg)
 | 
			
		||||
	logger(w, r.Question[0], msg.Rcode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsv *Resolver) resolve(w dns.ResponseWriter, r *dns.Msg) {
 | 
			
		||||
	msg := startReply(r)
 | 
			
		||||
	q := r.Question[0]
 | 
			
		||||
	ip, _, _ := net.SplitHostPort(w.RemoteAddr().String())
 | 
			
		||||
 | 
			
		||||
	for _, res := range rsv.rr {
 | 
			
		||||
		t := strings.Split(res, " ")[2]
 | 
			
		||||
		if q.Qtype == dns.StringToType[t] {
 | 
			
		||||
			brr, err := buildRR(rsv.domain + " " + res)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				msg.SetRcode(r, dns.RcodeServerFailure)
 | 
			
		||||
				logger(w, q, msg.Rcode, err.Error())
 | 
			
		||||
			} else {
 | 
			
		||||
				msg.Answer = append(msg.Answer, brr)
 | 
			
		||||
				logger(w, q, msg.Rcode)
 | 
			
		||||
			}
 | 
			
		||||
			w.WriteMsg(msg)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	lowerName := strings.ToLower(q.Name) // lowercase because of dns-0x20
 | 
			
		||||
	subDomain := strings.Split(lowerName, ".")[0]
 | 
			
		||||
	switch {
 | 
			
		||||
	case uuid.IsValid(subDomain):
 | 
			
		||||
		msg.SetRcode(r, rsv.getIP(q, msg))
 | 
			
		||||
		rsv.store.Add(subDomain, ip, cache.DefaultExpiration)
 | 
			
		||||
	case lowerName == rsv.domain:
 | 
			
		||||
		msg.SetRcode(r, rsv.getIP(q, msg))
 | 
			
		||||
	default:
 | 
			
		||||
		msg.SetRcode(r, dns.RcodeRefused)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	w.WriteMsg(msg)
 | 
			
		||||
	logger(w, q, msg.Rcode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (rsv *Resolver) getIP(question dns.Question, msg *dns.Msg) int {
 | 
			
		||||
	if question.Qtype == dns.TypeA && len(rsv.ipv4) > 0 {
 | 
			
		||||
		for _, ip := range rsv.ipv4 {
 | 
			
		||||
			msg.Answer = append(msg.Answer, &dns.A{
 | 
			
		||||
				Hdr: setHdr(question),
 | 
			
		||||
				A:   ip,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		return dns.RcodeSuccess
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if question.Qtype == dns.TypeAAAA && len(rsv.ipv6) > 0 {
 | 
			
		||||
		for _, ip := range rsv.ipv6 {
 | 
			
		||||
			msg.Answer = append(msg.Answer, &dns.AAAA{
 | 
			
		||||
				Hdr:  setHdr(question),
 | 
			
		||||
				AAAA: ip,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		return dns.RcodeSuccess
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return dns.RcodeRefused
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func buildRR(rrs string) (dns.RR, error) {
 | 
			
		||||
	rr, err := dns.NewRR(rrs)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return rr, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setHdr(q dns.Question) dns.RR_Header {
 | 
			
		||||
	return dns.RR_Header{
 | 
			
		||||
		Name:   q.Name,
 | 
			
		||||
		Rrtype: q.Qtype,
 | 
			
		||||
		Class:  dns.ClassINET,
 | 
			
		||||
		Ttl:    60,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func startReply(r *dns.Msg) *dns.Msg {
 | 
			
		||||
	msg := new(dns.Msg)
 | 
			
		||||
	msg.SetReply(r)
 | 
			
		||||
	msg.Authoritative = true
 | 
			
		||||
 | 
			
		||||
	return msg
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func logger(w dns.ResponseWriter, q dns.Question, code int, err ...string) {
 | 
			
		||||
	emsg := ""
 | 
			
		||||
	if len(err) > 0 {
 | 
			
		||||
		emsg = " - " + strings.Join(err, " ")
 | 
			
		||||
	}
 | 
			
		||||
	ip, _, _ := net.SplitHostPort(w.RemoteAddr().String())
 | 
			
		||||
	log.Printf(
 | 
			
		||||
		"DNS %s - %s - %s - %s%s",
 | 
			
		||||
		ip,
 | 
			
		||||
		dns.TypeToString[q.Qtype],
 | 
			
		||||
		q.Name,
 | 
			
		||||
		dns.RcodeToString[code],
 | 
			
		||||
		emsg,
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user