From 4c2a201622839930bfb7ad481d38bd9fccf11d10 Mon Sep 17 00:00:00 2001 From: marek Date: Sun, 17 Oct 2021 17:03:15 +0200 Subject: [PATCH] Verbesserungen DNS-Server --- dns/dns.go | 69 ++++++++++++++++++++++++++++++++++++++---------------- main.go | 4 +++- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/dns/dns.go b/dns/dns.go index be3f57f..edc0d37 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -11,23 +11,23 @@ import ( ) type domain struct { - Root string - Mutv4 sync.RWMutex - Mutv6 sync.RWMutex - Ipv4 map[string]string - Ipv6 map[string]string + root string + mutv4 sync.RWMutex + mutv6 sync.RWMutex + ipv4 map[string]string + ipv6 map[string]string } -var Domains = []*domain{} +var domains = []*domain{} func parseQuery(m *dns.Msg, currentDomain *domain) { for _, q := range m.Question { switch q.Qtype { case dns.TypeA: log.Printf("Query for A record of %s\n", q.Name) - currentDomain.Mutv4.RLock() - ip := currentDomain.Ipv4[q.Name] - currentDomain.Mutv4.RUnlock() + currentDomain.mutv4.RLock() + ip := currentDomain.ipv4[q.Name] + currentDomain.mutv4.RUnlock() if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) if err == nil { @@ -36,9 +36,9 @@ func parseQuery(m *dns.Msg, currentDomain *domain) { } case dns.TypeAAAA: log.Printf("Query for AAAA record of %s\n", q.Name) - currentDomain.Mutv6.RLock() - ip := currentDomain.Ipv6[q.Name] - currentDomain.Mutv6.RUnlock() + currentDomain.mutv6.RLock() + ip := currentDomain.ipv6[q.Name] + currentDomain.mutv6.RUnlock() if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip)) if err == nil { @@ -64,21 +64,50 @@ func handleDnsRequest(currentDomain *domain) func(w dns.ResponseWriter, r *dns.M } } -func Run() error { +func Load() { + for _, currentDomain := range config.Config.Domains { + domains = append(domains, &domain{ + root: currentDomain, + }) + } +} + +func Run() (*dns.Server, error) { // attach request handler func - for _, currentDomain := range Domains { - dns.HandleFunc(currentDomain.Root, handleDnsRequest(currentDomain)) + for _, currentDomain := range domains { + dns.HandleFunc(currentDomain.root, handleDnsRequest(currentDomain)) } // start server - port := config.Config.Port - server := &dns.Server{Addr: ":" + strconv.Itoa(int(port)), Net: "udp"} - log.Printf("Starting DNS at %d\n", port) + server := &dns.Server{Addr: ":" + strconv.Itoa(int(config.Config.Port)), Net: "udp"} + log.Printf("Starting DNS at %d\n", config.Config.Port) err := server.ListenAndServe() if err != nil { server.Shutdown() - return err + return server, err } - return nil + return server, nil +} + +func UpdateIpv6(domain string, ipv6 string) { + for _, currentDomain := range domains { + if dns.IsSubDomain(currentDomain.root, domain) { + currentDomain.mutv6.Lock() + currentDomain.ipv6[domain] = ipv6 + currentDomain.mutv6.Unlock() + break + } + } +} + +func UpdateIpv4(domain string, ipv4 string) { + for _, currentDomain := range domains { + if dns.IsSubDomain(currentDomain.root, domain) { + currentDomain.mutv4.Lock() + currentDomain.ipv4[domain] = ipv4 + currentDomain.mutv4.Unlock() + break + } + } } diff --git a/main.go b/main.go index 7070785..a9685dc 100644 --- a/main.go +++ b/main.go @@ -20,8 +20,10 @@ func main() { log.Fatalf("Failed to load configuration: %s\n ", err.Error()) } - err = dns.Run() + dns.Load() + server, err := dns.Run() if err != nil { log.Fatalf("Failed to start DNS server: %s\n ", err.Error()) } + defer server.Shutdown() }