diff --git a/.gitignore b/.gitignore index 96cfb51..bbbb2b7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ gertdns gertdns.exe conf.toml auth.toml +*.v4.csv +*.v6.csv diff --git a/README.md b/README.md index e3e149c..dc3e4be 100644 --- a/README.md +++ b/README.md @@ -53,3 +53,28 @@ Default: `conf.toml` Will define what file should be used to define users that can log in. Type: `string` Default: `auth.toml` + +### --data-path +Will define where stored data is put (i.e. IP addresses for subdomains). All records will be saved here every second if they have been changed and when the application gets shut down. +Type: `string` +Default: `./` + +## Routes +### `/` +If in debug mode, will output all registered records, otherwise prints `"Working"`. + +### `/update/{domain}/{type}` +Updates a given record. +#### URL parts +`domain` (`string`): defines the subdomain that is to be modified +`type` (`"v4"` | `"v6"`): specifies whether an IPv4 or IPv6 record is to be changed. +#### query parameters +`ipv4` (`string`) (only if `type` is `"v4"`): specifies the IPv4 address to be applied. +`ipv6` (`string`) (only if `type` is `"v6"`): specifies the IPv6 address to be applied. +`user` (`string`): username as specified in _auth file_. +`password` (`string`): password as specified in _auth file_. + +#### examples +/update/**example.example**/**v4**?ipv4=**127.0.0.1**&user=**username**&password=**password** + +/update/**example.example**/**v6**?ipv6=**::1**&user=**username**&password=**password** diff --git a/auth/auth.go b/auth/auth.go index 74554e4..245059f 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -1,8 +1,6 @@ package auth import ( - "log" - "github.com/raja/argon2pw" ) @@ -73,8 +71,6 @@ func Init(authFilePath string) error { } for name, user := range users { - log.Printf("%s\n", name) - log.Printf("%+v\n", user) parsedUser, err := user.Tidy() if err != nil { return err diff --git a/dns/dns.go b/dns/dns.go index bf18788..f5e7807 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -1,35 +1,141 @@ package dns import ( + "bufio" "errors" "fmt" "log" + "os" + "path" "strconv" + "strings" "sync" + "time" "github.com/MarekWojt/gertdns/config" + "github.com/MarekWojt/gertdns/util" + "github.com/gookit/color" "github.com/miekg/dns" ) +var domains []*domain = make([]*domain, 0) +var saveTicker *time.Ticker = time.NewTicker(time.Second) +var saving sync.Mutex = sync.Mutex{} +var currentDataPath string + +const ( + IPV4_FILE = "v4.csv" + IPV6_FILE = "v6.csv" +) + type domain struct { - Root string - Mutv4 sync.RWMutex - Mutv6 sync.RWMutex - Ipv4 map[string]string - Ipv6 map[string]string + Root string + Mutv4 sync.RWMutex + Mutv6 sync.RWMutex + Mutv4Changed sync.RWMutex + Mutv6Changed sync.RWMutex + Ipv4 map[string]string + Ipv4Changed bool + Ipv6 map[string]string + Ipv6Changed bool } -var domains []*domain = make([]*domain, 0) +func (currentDomain *domain) IsV4Changed() bool { + currentDomain.Mutv4Changed.RLock() + result := currentDomain.Ipv4Changed + currentDomain.Mutv4Changed.RUnlock() -func Init() { + return result +} + +func (currentDomain *domain) IsV6Changed() bool { + currentDomain.Mutv6Changed.RLock() + result := currentDomain.Ipv6Changed + currentDomain.Mutv6Changed.RUnlock() + + return result +} + +func (currentDomain *domain) MarkV6Changed(changed bool) { + currentDomain.Mutv6Changed.Lock() + currentDomain.Ipv6Changed = changed + currentDomain.Mutv6Changed.Unlock() +} + +func (currentDomain *domain) MarkV4Changed(changed bool) { + currentDomain.Mutv4Changed.Lock() + currentDomain.Ipv4Changed = changed + currentDomain.Mutv4Changed.Unlock() +} + +func (currentDomain *domain) SetV4(domain string, ipv4 string) { + currentDomain.Mutv4.Lock() + currentDomain.Ipv4[domain] = ipv4 + currentDomain.Mutv4.Unlock() +} + +func (currentDomain *domain) SetV6(domain string, ipv6 string) { + currentDomain.Mutv6.Lock() + currentDomain.Ipv6[domain] = ipv6 + currentDomain.Mutv6.Unlock() +} + +func loadFile(ty string, currentDomain *domain) { + if ty != IPV4_FILE && ty != IPV6_FILE { + panic("type passed to loadFile must be either IPV4_FILE or IPV6_FILE") + } + + filePath := path.Join(currentDataPath, currentDomain.Root+ty) + f, err := os.Open(filePath) + if err != nil { + color.Warnf("Could not load file for domain %s: %s\n", currentDomain.Root, err) + } else { + log.Printf("Reading file: %s", filePath) + scanner := bufio.NewScanner(f) + + lineCounter := 0 + for scanner.Scan() { + lineCounter++ + currentLine := scanner.Text() + cols := strings.Split(currentLine, "\t") + if len(cols) < 2 { + color.Warnf("Error reading line %d of ipv4 addresses for domain %s: too few columns\n", lineCounter, currentDomain.Root) + continue + } + + if ty == IPV4_FILE { + currentDomain.Ipv4[cols[0]] = cols[1] + } else if ty == IPV6_FILE { + currentDomain.Ipv6[cols[0]] = cols[1] + } + } + color.Infof("Read file: %s\n", filePath) + } + f.Close() +} + +func Init(dataPath string) { + currentDataPath = dataPath for _, currentDomain := range config.Config.DNS.Domains { + currentDomain = util.ParseDomain(currentDomain) log.Printf("Added domain root: %s\n", currentDomain) - domains = append(domains, &domain{ + + domainObj := &domain{ Root: currentDomain, Ipv4: make(map[string]string), Ipv6: make(map[string]string), - }) + } + domains = append(domains, domainObj) + loadFile(IPV4_FILE, domainObj) + loadFile(IPV6_FILE, domainObj) } + + go func() { + for { + <-saveTicker.C + Save() + } + }() } func Run() (*dns.Server, error) { @@ -54,9 +160,12 @@ func UpdateIpv6(domain string, ipv6 string) error { for _, currentDomain := range domains { if dns.IsSubDomain(currentDomain.Root, domain) { log.Printf("Updating domain %s AAAA %s\n", domain, ipv6) - currentDomain.Mutv6.Lock() - currentDomain.Ipv6[domain] = ipv6 - currentDomain.Mutv6.Unlock() + + if !currentDomain.IsV6Changed() { + currentDomain.MarkV6Changed(true) + } + + currentDomain.SetV6(domain, ipv6) return nil } } @@ -66,11 +175,15 @@ func UpdateIpv6(domain string, ipv6 string) error { func UpdateIpv4(domain string, ipv4 string) (err error) { for _, currentDomain := range domains { + log.Printf("%s sub of %s ?\n", domain, currentDomain.Root) if dns.IsSubDomain(currentDomain.Root, domain) { log.Printf("Updating domain %s A %s\n", domain, ipv4) - currentDomain.Mutv4.Lock() - currentDomain.Ipv4[domain] = ipv4 - currentDomain.Mutv4.Unlock() + + if !currentDomain.IsV4Changed() { + currentDomain.MarkV4Changed(true) + } + + currentDomain.SetV4(domain, ipv4) return nil } } @@ -125,3 +238,54 @@ func handleDnsRequest(currentDomain *domain) func(w dns.ResponseWriter, r *dns.M w.WriteMsg(m) } } + +func Save() (errs []error) { + saving.Lock() + for _, domain := range domains { + if domain.IsV4Changed() { + ipv4Data := "" + domain.Mutv4.RLock() + for key, val := range domain.Ipv4 { + ipv4Data += key + "\t" + val + "\n" + } + domain.Mutv4.RUnlock() + err := os.WriteFile(path.Join(currentDataPath, domain.Root+IPV4_FILE), []byte(ipv4Data), 0644) + if err != nil { + errs = append(errs, err) + color.Errorf("Failed to save ipv4 data for domain %s: %s\n", domain.Root, err) + } else { + // did successfully save, so mark as saved + domain.MarkV4Changed(false) + } + } + + if domain.IsV6Changed() { + ipv6Data := "" + domain.Mutv6.RLock() + for key, val := range domain.Ipv6 { + ipv6Data += key + "\t" + val + "\n" + } + domain.Mutv6.RUnlock() + err := os.WriteFile(path.Join(currentDataPath, domain.Root+IPV6_FILE), []byte(ipv6Data), 0644) + if err != nil { + errs = append(errs, err) + color.Errorf("Failed to save ipv6 data for domain %s: %s\n", domain.Root, err) + } else { + // did successfully save, so mark as saved + domain.MarkV6Changed(false) + } + } + } + saving.Unlock() + + errLen := len(errs) + if errLen > 0 { + color.Errorf("%d errors occurred while trying to save\n", errLen) + } + return +} + +func Shutdown() { + saveTicker.Stop() + Save() +} diff --git a/main.go b/main.go index ad7a2c5..3cb4db8 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,9 @@ package main import ( "flag" "log" + "os" + "os/signal" + "syscall" "github.com/MarekWojt/gertdns/auth" "github.com/MarekWojt/gertdns/config" @@ -14,6 +17,7 @@ import ( var ( configFile = flag.String("config-file", "conf.toml", "Path to configuration file") authFile = flag.String("auth-file", "auth.toml", "Path to authentication file") + dataPath = flag.String("data-path", "./", "Where to save data") enableDebugMode = flag.Bool("enable-debug-mode", false, "Enables debug mode, will output a list of all registered records on the index page of the HTTP server") ) @@ -30,7 +34,7 @@ func main() { log.Fatalf("Failed to load configuration: %s\n ", err.Error()) } - dns.Init() + dns.Init(*dataPath) web.Init(*enableDebugMode) err = auth.Init(*authFile) if err != nil { @@ -69,8 +73,17 @@ func main() { webChan <- err }() + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGTERM) + go func() { + <-c + dns.Shutdown() + os.Exit(0) + }() + currentDnsResult := <-dnsChan defer currentDnsResult.server.Shutdown() + defer dns.Shutdown() <-webChan <-webChan } diff --git a/util/util.go b/util/util.go new file mode 100644 index 0000000..678c9ee --- /dev/null +++ b/util/util.go @@ -0,0 +1,11 @@ +package util + +import "strings" + +func ParseDomain(domain string) string { + if !strings.HasSuffix(domain, ".") { + return domain + "." + } + + return domain +} diff --git a/web/web.go b/web/web.go index a430772..8ac0aaf 100644 --- a/web/web.go +++ b/web/web.go @@ -8,6 +8,7 @@ import ( "github.com/MarekWojt/gertdns/auth" "github.com/MarekWojt/gertdns/config" "github.com/MarekWojt/gertdns/dns" + "github.com/MarekWojt/gertdns/util" "github.com/fasthttp/router" "github.com/valyala/fasthttp" ) @@ -85,6 +86,7 @@ func index(ctx *fasthttp.RequestCtx) { func updateV4(ctx *fasthttp.RequestCtx) { domain := ctx.UserValue("domain").(string) + domain = util.ParseDomain(domain) ipv4 := string(ctx.QueryArgs().PeekBytes(ipv4Param)) if ipv4 == "" { ctx.WriteString("Missing ipv4 query parameter") @@ -104,6 +106,7 @@ func updateV4(ctx *fasthttp.RequestCtx) { func updateV6(ctx *fasthttp.RequestCtx) { domain := ctx.UserValue("domain").(string) + domain = util.ParseDomain(domain) ipv6 := string(ctx.QueryArgs().PeekBytes(ipv6Param)) if ipv6 == "" { ctx.WriteString("Missing ipv6 query parameter") @@ -129,6 +132,7 @@ func authenticatedRequest(request func(ctx *fasthttp.RequestCtx)) func(ctx *fast ctx.SetStatusCode(fasthttp.StatusBadRequest) return } + domain = util.ParseDomain(domain) user := string(ctx.QueryArgs().PeekBytes(userParam)) if user == "" {