Files
gertdns/dns/dns.go

291 lines
6.6 KiB
Go

package dns
import (
"bufio"
"errors"
"fmt"
"log"
"os"
"path"
"strconv"
"strings"
"sync"
"time"
"github.com/MarekWojt/gertdns/config"
"github.com/MarekWojt/gertdns/util"
"github.com/gookit/color"
"github.com/miekg/dns"
)
// Package-level state shared by every function in this file.
var (
	// domains holds one entry per configured domain root; Init appends to it.
	domains = make([]*domain, 0)
	// saveTicker fires once per second and drives the background save loop
	// started by Init.
	saveTicker = time.NewTicker(time.Second)
	// saving serializes calls to Save.
	saving sync.Mutex
	// currentDataPath is the directory the address files are read from and
	// written to; it is set by Init.
	currentDataPath string
)
// IPV4_FILE is the suffix appended to a domain root to form the name of the
// file that stores that root's IPv4 addresses.
const IPV4_FILE = "v4.csv"

// IPV6_FILE is the suffix appended to a domain root to form the name of the
// file that stores that root's IPv6 addresses.
const IPV6_FILE = "v6.csv"
// domain holds the address tables served for one configured domain root,
// together with the locks that guard them and the flags that record whether
// a table has been modified since it was last written to disk.
type domain struct {
// Root is the domain root; it is used as the DNS handler pattern and as the
// prefix of the data file names.
Root string
// Mutv4 guards Ipv4.
Mutv4 sync.RWMutex
// Mutv6 guards Ipv6.
Mutv6 sync.RWMutex
// Mutv4Changed guards Ipv4Changed.
Mutv4Changed sync.RWMutex
// Mutv6Changed guards Ipv6Changed.
Mutv6Changed sync.RWMutex
// Ipv4 maps a queried name to its IPv4 address in string form.
Ipv4 map[string]string
// Ipv4Changed reports whether Ipv4 has unsaved modifications.
Ipv4Changed bool
// Ipv6 maps a queried name to its IPv6 address in string form.
Ipv6 map[string]string
// Ipv6Changed reports whether Ipv6 has unsaved modifications.
Ipv6Changed bool
}
// IsV4Changed reports whether the IPv4 table is flagged as modified. The flag
// is read under the Mutv4Changed read lock.
func (currentDomain *domain) IsV4Changed() bool {
	currentDomain.Mutv4Changed.RLock()
	defer currentDomain.Mutv4Changed.RUnlock()
	return currentDomain.Ipv4Changed
}
// IsV6Changed reports whether the IPv6 table is flagged as modified. The flag
// is read under the Mutv6Changed read lock.
func (currentDomain *domain) IsV6Changed() bool {
	currentDomain.Mutv6Changed.RLock()
	defer currentDomain.Mutv6Changed.RUnlock()
	return currentDomain.Ipv6Changed
}
// MarkV6Changed sets the IPv6 modified flag to changed while holding the
// Mutv6Changed write lock.
func (currentDomain *domain) MarkV6Changed(changed bool) {
	currentDomain.Mutv6Changed.Lock()
	defer currentDomain.Mutv6Changed.Unlock()
	currentDomain.Ipv6Changed = changed
}
// MarkV4Changed sets the IPv4 modified flag to changed while holding the
// Mutv4Changed write lock.
func (currentDomain *domain) MarkV4Changed(changed bool) {
	currentDomain.Mutv4Changed.Lock()
	defer currentDomain.Mutv4Changed.Unlock()
	currentDomain.Ipv4Changed = changed
}
// SetV4 stores ipv4 as the IPv4 address for domain, holding the Mutv4 write
// lock for the duration of the map write. It does not touch the changed flag.
func (currentDomain *domain) SetV4(domain string, ipv4 string) {
	currentDomain.Mutv4.Lock()
	defer currentDomain.Mutv4.Unlock()
	currentDomain.Ipv4[domain] = ipv4
}
// SetV6 stores ipv6 as the IPv6 address for domain, holding the Mutv6 write
// lock for the duration of the map write. It does not touch the changed flag.
func (currentDomain *domain) SetV6(domain string, ipv6 string) {
	currentDomain.Mutv6.Lock()
	defer currentDomain.Mutv6.Unlock()
	currentDomain.Ipv6[domain] = ipv6
}
// loadFile reads the stored addresses of one family for currentDomain from
// disk into the matching map.
//
// ty must be IPV4_FILE or IPV6_FILE; any other value is a programming error
// and panics. The file is looked up at currentDataPath/<Root><ty> and is
// expected to contain one "name<TAB>address" pair per line. Lines with fewer
// than two tab-separated columns are reported and skipped. A file that cannot
// be opened or read is reported as a warning and is not fatal.
//
// The maps are written without taking the domain's locks.
// NOTE(review): this looks safe only because Init calls it before the domain
// is served — confirm no other caller exists.
func loadFile(ty string, currentDomain *domain) {
	var (
		target map[string]string
		family string
	)
	switch ty {
	case IPV4_FILE:
		target, family = currentDomain.Ipv4, "ipv4"
	case IPV6_FILE:
		target, family = currentDomain.Ipv6, "ipv6"
	default:
		panic("type passed to loadFile must be either IPV4_FILE or IPV6_FILE")
	}

	filePath := path.Join(currentDataPath, currentDomain.Root+ty)
	f, err := os.Open(filePath)
	if err != nil {
		color.Warnf("Could not load file for domain %s: %s\n", currentDomain.Root, err)
		return
	}
	defer f.Close()

	log.Printf("Reading file: %s", filePath)
	scanner := bufio.NewScanner(f)
	lineCounter := 0
	for scanner.Scan() {
		lineCounter++
		cols := strings.Split(scanner.Text(), "\t")
		if len(cols) < 2 {
			color.Warnf("Error reading line %d of %s addresses for domain %s: too few columns\n", lineCounter, family, currentDomain.Root)
			continue
		}
		target[cols[0]] = cols[1]
	}
	// Scan returns false on both EOF and failure; only Err tells them apart.
	if err := scanner.Err(); err != nil {
		color.Warnf("Error reading file %s after line %d: %s\n", filePath, lineCounter, err)
		return
	}
	color.Infof("Read file: %s\n", filePath)
}
// Init records dataPath as the data directory, registers one domain entry for
// every root listed in config.Config.DNS.Domains (loading its stored IPv4 and
// IPv6 addresses from disk), and starts a background goroutine that calls
// Save on every tick of saveTicker. The goroutine has no exit condition.
func Init(dataPath string) {
	currentDataPath = dataPath

	for _, rawRoot := range config.Config.DNS.Domains {
		root := util.ParseDomain(rawRoot)
		log.Printf("Added domain root: %s\n", root)

		entry := &domain{
			Root: root,
			Ipv4: map[string]string{},
			Ipv6: map[string]string{},
		}
		domains = append(domains, entry)

		for _, fileType := range []string{IPV4_FILE, IPV6_FILE} {
			loadFile(fileType, entry)
		}
	}

	go func() {
		for range saveTicker.C {
			Save()
		}
	}()
}
// Run registers a request handler for every known domain root and then
// serves DNS over UDP on the port given by config.Config.DNS.Port.
//
// ListenAndServe is called synchronously, so Run returns only once that call
// has returned. On error the server is shut down and returned together with
// the error; otherwise the server and a nil error are returned.
func Run() (*dns.Server, error) {
	// One handler per root, each bound to its own domain entry.
	for _, root := range domains {
		dns.HandleFunc(root.Root, handleDnsRequest(root))
	}

	port := config.Config.DNS.Port
	server := &dns.Server{
		Addr: ":" + strconv.Itoa(int(port)),
		Net:  "udp",
	}

	log.Printf("Starting DNS at %d\n", port)
	if err := server.ListenAndServe(); err != nil {
		server.Shutdown()
		return server, err
	}
	return server, nil
}
// UpdateIpv6 stores ipv6 as the AAAA address of domain in the first
// registered root that domain is a subdomain of, flagging that root's IPv6
// table as changed. It returns an error when no registered root matches.
func UpdateIpv6(domain string, ipv6 string) error {
	for _, root := range domains {
		if !dns.IsSubDomain(root.Root, domain) {
			continue
		}

		log.Printf("Updating domain %s AAAA %s\n", domain, ipv6)
		if alreadyFlagged := root.IsV6Changed(); !alreadyFlagged {
			root.MarkV6Changed(true)
		}
		root.SetV6(domain, ipv6)
		return nil
	}
	return errors.New("no root found")
}
// UpdateIpv4 stores ipv4 as the A address of domain in the first registered
// root that domain is a subdomain of, flagging that root's IPv4 table as
// changed. Every root that is examined is logged. It returns an error when no
// registered root matches.
func UpdateIpv4(domain string, ipv4 string) (err error) {
	for _, root := range domains {
		log.Printf("%s sub of %s ?\n", domain, root.Root)
		if !dns.IsSubDomain(root.Root, domain) {
			continue
		}

		log.Printf("Updating domain %s A %s\n", domain, ipv4)
		if alreadyFlagged := root.IsV4Changed(); !alreadyFlagged {
			root.MarkV4Changed(true)
		}
		root.SetV4(domain, ipv4)
		return nil
	}
	return errors.New("no root found")
}
// Get returns the package's slice of registered domains. The slice is
// returned as-is, not copied.
func Get() []*domain {
	registered := domains
	return registered
}
// parseQuery answers the questions in m from currentDomain's address tables.
//
// For every A or AAAA question whose name has a stored address, a resource
// record with a TTL of 300 seconds is appended to m.Answer. Questions of any
// other type, and names without a stored address, produce no answer. A stored
// address that cannot be turned into a record is logged and skipped.
func parseQuery(m *dns.Msg, currentDomain *domain) {
	for _, q := range m.Question {
		var recordType, ip string
		switch q.Qtype {
		case dns.TypeA:
			recordType = "A"
			currentDomain.Mutv4.RLock()
			ip = currentDomain.Ipv4[q.Name]
			currentDomain.Mutv4.RUnlock()
		case dns.TypeAAAA:
			recordType = "AAAA"
			currentDomain.Mutv6.RLock()
			ip = currentDomain.Ipv6[q.Name]
			currentDomain.Mutv6.RUnlock()
		default:
			continue
		}
		if ip == "" {
			continue
		}

		// Name and address are passed as arguments, never as the format
		// string, so a '%' in either cannot be misread as a formatting verb.
		rr, err := dns.NewRR(fmt.Sprintf("%s 300 IN %s %s", q.Name, recordType, ip))
		if err != nil {
			log.Printf("Could not build %s record for %s: %s\n", recordType, q.Name, err)
			continue
		}
		m.Answer = append(m.Answer, rr)
	}
}
// handleDnsRequest returns a DNS handler bound to currentDomain.
//
// The handler builds an uncompressed reply to the request, fills in answers
// via parseQuery when the request opcode is a standard query, and writes the
// reply. Requests with any other opcode get a reply without answers. A
// failure to write the reply is logged.
func handleDnsRequest(currentDomain *domain) func(w dns.ResponseWriter, r *dns.Msg) {
	return func(w dns.ResponseWriter, r *dns.Msg) {
		m := new(dns.Msg)
		m.SetReply(r)
		m.Compress = false

		if r.Opcode == dns.OpcodeQuery {
			parseQuery(m, currentDomain)
		}

		if err := w.WriteMsg(m); err != nil {
			log.Printf("Failed to write DNS response: %s\n", err)
		}
	}
}
// writeAddressFile snapshots table under mut's read lock and writes it to
// filePath as one "name<TAB>address" line per entry.
func writeAddressFile(filePath string, mut *sync.RWMutex, table map[string]string) error {
	var data strings.Builder
	mut.RLock()
	for name, addr := range table {
		data.WriteString(name)
		data.WriteByte('\t')
		data.WriteString(addr)
		data.WriteByte('\n')
	}
	mut.RUnlock()
	return os.WriteFile(filePath, []byte(data.String()), 0644)
}

// Save writes every address table that is flagged as changed to its file
// under currentDataPath and returns the errors that occurred, if any.
//
// Calls are serialized through the saving mutex. A table's changed flag is
// cleared before the table is snapshotted, so an update that arrives while
// the file is being written sets the flag again and is picked up by the next
// Save. If the write fails the flag is restored so the save is retried.
// Each failure is logged, followed by a summary when at least one occurred.
func Save() (errs []error) {
	saving.Lock()
	defer saving.Unlock()

	for _, currentDomain := range domains {
		if currentDomain.IsV4Changed() {
			currentDomain.MarkV4Changed(false)
			filePath := path.Join(currentDataPath, currentDomain.Root+IPV4_FILE)
			if err := writeAddressFile(filePath, &currentDomain.Mutv4, currentDomain.Ipv4); err != nil {
				// Nothing was persisted, so the table is still dirty.
				currentDomain.MarkV4Changed(true)
				errs = append(errs, err)
				color.Errorf("Failed to save ipv4 data for domain %s: %s\n", currentDomain.Root, err)
			}
		}

		if currentDomain.IsV6Changed() {
			currentDomain.MarkV6Changed(false)
			filePath := path.Join(currentDataPath, currentDomain.Root+IPV6_FILE)
			if err := writeAddressFile(filePath, &currentDomain.Mutv6, currentDomain.Ipv6); err != nil {
				// Nothing was persisted, so the table is still dirty.
				currentDomain.MarkV6Changed(true)
				errs = append(errs, err)
				color.Errorf("Failed to save ipv6 data for domain %s: %s\n", currentDomain.Root, err)
			}
		}
	}

	if errLen := len(errs); errLen > 0 {
		color.Errorf("%d errors occurred while trying to save\n", errLen)
	}
	return errs
}
// Shutdown stops the periodic save ticker and then performs one final Save so
// that pending changes are written out.
func Shutdown() {
	saveTicker.Stop()
	// The returned errors are dropped here; Save logs each failure itself.
	_ = Save()
}