From 5bc0c3b2b47a666fc371e8b2c3b661d67b8469b0 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Tue, 11 Feb 2025 02:50:13 +0300 Subject: [PATCH] fixes --- {dns-mitm => dns-mitm-proxy}/mitm.go | 40 ++--- kvas2.go | 145 +++++++--------- netfilter-helper/interface-to-ipset.go | 4 +- records.go | 228 ------------------------- records/records.go | 167 ++++++++++++++++++ records/records_test.go | 109 ++++++++++++ 6 files changed, 355 insertions(+), 338 deletions(-) rename {dns-mitm => dns-mitm-proxy}/mitm.go (80%) delete mode 100644 records.go create mode 100644 records/records.go create mode 100644 records/records_test.go diff --git a/dns-mitm/mitm.go b/dns-mitm-proxy/mitm.go similarity index 80% rename from dns-mitm/mitm.go rename to dns-mitm-proxy/mitm.go index d785990..d8b0a1e 100644 --- a/dns-mitm/mitm.go +++ b/dns-mitm-proxy/mitm.go @@ -1,19 +1,17 @@ -package dnsMitm +package dnsMitmProxy import ( "context" "encoding/binary" "fmt" "net" - "strconv" "time" "github.com/miekg/dns" "github.com/rs/zerolog/log" ) -type DNSMITM struct { - ListenPort uint16 +type DNSMITMProxy struct { TargetDNSServerAddress string TargetDNSServerPort uint16 @@ -21,7 +19,7 @@ type DNSMITM struct { ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error) } -func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) { +func (p DNSMITMProxy) requestDNS(req []byte, network string) ([]byte, error) { serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort)) if err != nil { return nil, fmt.Errorf("failed to dial DNS server: %w", err) @@ -65,7 +63,7 @@ func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) { return resp[:n], nil } -func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) { +func (p DNSMITMProxy) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) { var reqMsg dns.Msg if p.RequestHook != nil || p.ResponseHook != nil { err := reqMsg.Unpack(req) @@ -97,14 +95,14 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([] resp, err := p.requestDNS(req, network) if err != nil { - return nil, fmt.Errorf("failed to send request") + return nil, fmt.Errorf("failed to send request: %w", err) } if p.ResponseHook != nil { var respMsg dns.Msg err = respMsg.Unpack(resp) if err != nil { - return nil, fmt.Errorf("failed to parse response") + return nil, fmt.Errorf("failed to parse response: %w", err) } modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network) @@ -123,12 +121,7 @@ func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([] return resp, nil } -func (p DNSMITM) ListenTCP(ctx context.Context) error { - addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(p.ListenPort))) - if err != nil { - return fmt.Errorf("failed to resolve tcp address: %v", err) - } - +func (p DNSMITMProxy) ListenTCP(ctx context.Context, addr *net.TCPAddr) error { listener, err := net.ListenTCP("tcp", addr) if err != nil { return fmt.Errorf("failed to listen tcp port: %v", err) @@ -184,12 +177,7 @@ func (p DNSMITM) ListenTCP(ctx context.Context) error { } } -func (p DNSMITM) ListenUDP(ctx context.Context) error { - addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(p.ListenPort))) - if err != nil { - return fmt.Errorf("failed to resolve udp address: %v", err) - } - +func (p DNSMITMProxy) ListenUDP(ctx context.Context, addr *net.UDPAddr) error { conn, err := net.ListenUDP("udp", addr) if err != nil { return fmt.Errorf("failed to listen udp port: %v", err) @@ -226,14 +214,8 @@ func (p DNSMITM) ListenUDP(ctx context.Context) error { } } -func New(listenPort uint16, targetDNSServerAddress string, targetDNSServerPort ...uint16) *DNSMITM { - dnsMitm := &DNSMITM{ - ListenPort: listenPort, - TargetDNSServerAddress: targetDNSServerAddress, - TargetDNSServerPort: 53, +func New() *DNSMITMProxy { + return &DNSMITMProxy{ + TargetDNSServerPort: 53, } - if len(targetDNSServerPort) > 0 { - dnsMitm.TargetDNSServerPort = targetDNSServerPort[0] - } - return dnsMitm } diff --git a/kvas2.go b/kvas2.go index 12faab6..73e6ab2 100644 --- a/kvas2.go +++ b/kvas2.go @@ -6,19 +6,20 @@ import ( "fmt" "net" "os" + "strconv" "strings" "time" - "kvas2-go/dns-mitm" + "kvas2-go/dns-mitm-proxy" "kvas2-go/models" "kvas2-go/netfilter-helper" + "kvas2-go/records" "github.com/google/uuid" "github.com/miekg/dns" "github.com/rs/zerolog/log" "github.com/vishvananda/netlink" "github.com/vishvananda/netlink/nl" - "golang.org/x/sys/unix" ) var ( @@ -39,10 +40,10 @@ type Config struct { type App struct { Config Config - DNSMITM *dnsMitm.DNSMITM + DNSMITM *dnsMitmProxy.DNSMITMProxy NetfilterHelper4 *netfilterHelper.NetfilterHelper NetfilterHelper6 *netfilterHelper.NetfilterHelper - Records *Records + Records *records.Records Groups map[uuid.UUID]*Group Link netlink.Link @@ -54,7 +55,7 @@ type App struct { func (a *App) handleLink(event netlink.LinkUpdate) { switch event.Change { - case unix.IFF_UP: + case 0x00000001: log.Debug(). Str("interface", event.Link.Attrs().Name). Str("operstatestr", event.Attrs().OperState.String()). @@ -94,7 +95,6 @@ func (a *App) start(ctx context.Context) (err error) { newCtx, cancel := context.WithCancel(ctx) defer cancel() - // TODO: Chan err errChan := make(chan error) /* @@ -102,16 +102,28 @@ func (a *App) start(ctx context.Context) (err error) { */ go func() { - err := a.DNSMITM.ListenUDP(newCtx) + addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort))) + if err != nil { + errChan <- fmt.Errorf("failed to resolve udp address: %v", err) + return + } + err = a.DNSMITM.ListenUDP(newCtx, addr) if err != nil { errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err) + return } }() go func() { - err := a.DNSMITM.ListenTCP(newCtx) + addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort))) + if err != nil { + errChan <- fmt.Errorf("failed to resolve tcp address: %v", err) + return + } + err = a.DNSMITM.ListenTCP(newCtx, addr) if err != nil { errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err) + return } }() @@ -125,20 +137,14 @@ func (a *App) start(ctx context.Context) (err error) { if err != nil { return fmt.Errorf("failed to override DNS (IPv4): %v", err) } - defer func() { - // TODO: Handle error - _ = a.dnsOverrider4.Disable() - }() + defer func() { _ = a.dnsOverrider4.Disable() }() a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList) err = a.dnsOverrider6.Enable() if err != nil { return fmt.Errorf("failed to override DNS (IPv6): %v", err) } - defer func() { - // TODO: Handle error - _ = a.dnsOverrider6.Disable() - }() + defer func() { _ = a.dnsOverrider6.Disable() }() /* Groups @@ -152,7 +158,6 @@ func (a *App) start(ctx context.Context) (err error) { } defer func() { for _, group := range a.Groups { - // TODO: Handle error _ = group.Disable() } }() @@ -170,13 +175,16 @@ func (a *App) start(ctx context.Context) (err error) { return fmt.Errorf("error while serve UNIX socket: %v", err) } defer func() { - // TODO: Handle error _ = socket.Close() _ = os.Remove(socketPath) }() go func() { for { + if newCtx.Err() != nil { + return + } + conn, err := socket.Accept() if err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { @@ -186,10 +194,7 @@ func (a *App) start(ctx context.Context) (err error) { } go func(conn net.Conn) { - defer func() { - // TODO: Handle error - _ = conn.Close() - }() + defer func() { _ = conn.Close() }() buf := make([]byte, 1024) n, err := conn.Read(buf) @@ -206,6 +211,12 @@ func (a *App) start(ctx context.Context) (err error) { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } } + if a.dnsOverrider6.Enabled { + err = a.dnsOverrider6.PutIPTable(args[2]) + if err != nil { + log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") + } + } for _, group := range a.Groups { if group.ifaceToIPSetNAT.Enabled { err := group.ifaceToIPSetNAT.PutIPTable(args[2]) @@ -240,7 +251,6 @@ func (a *App) start(ctx context.Context) (err error) { case event := <-linkUpdateChannel: a.handleLink(event) case err := <-errChan: - close(errChan) return err case <-ctx.Done(): return nil @@ -288,22 +298,17 @@ func (a *App) AddGroup(group *models.Group) error { Group: group, iptables: a.NetfilterHelper4.IPTables, ipset: ipset, - ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false), + ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false), } + grp.ifaceToIPSetNAT.SoftwareMode = a.Config.UseSoftwareRouting a.Groups[grp.ID] = grp return a.SyncGroup(grp) } func (a *App) SyncGroup(group *Group) error { - processedDomains := make(map[string]struct{}) - newIpsetAddressesMap := make(map[string]time.Duration) now := time.Now() - oldIpsetAddresses, err := group.ListIPv4() - if err != nil { - return fmt.Errorf("failed to get old ipset list: %w", err) - } - + addresses := make(map[string]time.Duration) knownDomains := a.Records.ListKnownDomains() for _, domain := range group.Rules { if !domain.IsEnabled() { @@ -315,26 +320,24 @@ func (a *App) SyncGroup(group *Group) error { continue } - cnames := a.Records.GetCNameRecords(domainName, true) - if len(cnames) == 0 { - continue - } - for _, cname := range cnames { - processedDomains[cname] = struct{}{} - } - - addresses := a.Records.GetARecords(domainName) - for _, address := range addresses { + domainAddresses := a.Records.GetARecords(domainName) + for _, address := range domainAddresses { ttl := now.Sub(address.Deadline) - if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL { - newIpsetAddressesMap[string(address.Address)] = ttl + if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL { + addresses[string(address.Address)] = ttl } } } } - for addr, ttl := range newIpsetAddressesMap { - if _, exists := oldIpsetAddresses[addr]; exists { + currentAddresses, err := group.ListIPv4() + if err != nil { + return fmt.Errorf("failed to get old ipset list: %w", err) + } + + for addr, ttl := range addresses { + // TODO: Check TTL + if _, exists := currentAddresses[addr]; exists { continue } ip := net.IP(addr) @@ -344,11 +347,16 @@ func (a *App) SyncGroup(group *Group) error { Str("address", ip.String()). Err(err). Msg("failed to add address") + } else { + log.Trace(). + Str("address", ip.String()). + Err(err). + Msg("add address") } } - for addr := range oldIpsetAddresses { - if _, exists := newIpsetAddressesMap[addr]; exists { + for addr := range currentAddresses { + if _, ok := addresses[addr]; ok { continue } ip := net.IP(addr) @@ -362,7 +370,7 @@ func (a *App) SyncGroup(group *Group) error { log.Trace(). Str("address", ip.String()). Err(err). - Msg("add address") + Msg("del address") } } @@ -400,9 +408,9 @@ func (a *App) processARecord(aRecord dns.A) { ttlDuration = a.Config.MinimalTTL } - a.Records.AddARecord(aRecord.Hdr.Name, aRecord.A, ttlDuration) + a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration) - names := a.Records.GetCNameRecords(aRecord.Hdr.Name, true) + names := a.Records.GetAliases(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1]) for _, group := range a.Groups { Rule: for _, domain := range group.Rules { @@ -413,6 +421,7 @@ func (a *App) processARecord(aRecord dns.A) { if !domain.IsMatch(name) { continue } + // TODO: Check already existed err := group.AddIPv4(aRecord.A, ttlDuration) if err != nil { log.Error(). @@ -445,12 +454,12 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) { ttlDuration = a.Config.MinimalTTL } - a.Records.AddCNameRecord(cNameRecord.Hdr.Name, cNameRecord.Target, ttlDuration) + a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration) // TODO: Optimization now := time.Now() - aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name) - names := a.Records.GetCNameRecords(cNameRecord.Hdr.Name, true) + aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1]) + names := a.Records.GetAliases(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1]) for _, group := range a.Groups { Rule: for _, domain := range group.Rules { @@ -496,12 +505,6 @@ func (a *App) handleMessage(msg dns.Msg) { for _, rr := range msg.Answer { a.handleRecord(rr) } - for _, rr := range msg.Ns { - a.handleRecord(rr) - } - for _, rr := range msg.Extra { - a.handleRecord(rr) - } } func New(config Config) (*App, error) { @@ -511,14 +514,10 @@ func New(config Config) (*App, error) { app.Config = config - app.DNSMITM = dnsMitm.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress) + app.DNSMITM = dnsMitmProxy.New() + app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress + app.DNSMITM.TargetDNSServerPort = 53 app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) { - log.Debug(). - Str("network", network). - Str("clientAddr", clientAddr.String()). - Str("name", reqMsg.Question[0].Name). - Msg("received DNS request") - // TODO: Need to understand why it not works in proxy mode if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR { respMsg := &dns.Msg{ @@ -530,10 +529,6 @@ func New(config Config) (*App, error) { }, Question: reqMsg.Question, } - log.Debug(). - Str("network", network). - Str("clientAddr", clientAddr.String()). - Msg("sending DNS response") return nil, respMsg, nil } @@ -551,20 +546,12 @@ func New(config Config) (*App, error) { } respMsg.Answer = respMsg.Answer[:idx] - if len(respMsg.Answer) != 0 { - log.Debug(). - Str("network", network). - Str("clientAddr", clientAddr.String()). - Str("respMsg", respMsg.Answer[0].Header().Name). - Msg("sending DNS response") - } - app.handleMessage(respMsg) return &respMsg, nil } - app.Records = NewRecords() + app.Records = records.New() app.Groups = make(map[uuid.UUID]*Group, 0) link, err := netlink.LinkByName(app.Config.LinkName) diff --git a/netfilter-helper/interface-to-ipset.go b/netfilter-helper/interface-to-ipset.go index 4432abb..bfd5827 100644 --- a/netfilter-helper/interface-to-ipset.go +++ b/netfilter-helper/interface-to-ipset.go @@ -179,7 +179,7 @@ func (r *IfaceToIPSet) ForceEnable() error { // IPTables rules err = r.PutIPTable("all") if err != nil { - return nil + return err } // Mapping mark with table @@ -194,7 +194,7 @@ func (r *IfaceToIPSet) ForceEnable() error { err = r.IfaceHandle() if err != nil { - return nil + return err } r.Enabled = true diff --git a/records.go b/records.go deleted file mode 100644 index ee2a76c..0000000 --- a/records.go +++ /dev/null @@ -1,228 +0,0 @@ -package main - -import ( - "bytes" - "net" - "sync" - "time" -) - -type ARecord struct { - Address net.IP - Deadline time.Time -} - -func NewARecord(addr net.IP, deadline time.Time) *ARecord { - return &ARecord{ - Address: addr, - Deadline: deadline, - } -} - -type CNameRecord struct { - Alias string - Deadline time.Time -} - -func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord { - return &CNameRecord{ - Alias: domainName, - Deadline: deadline, - } -} - -type Records struct { - mutex sync.RWMutex - ARecords map[string][]*ARecord - CNameRecords map[string]*CNameRecord -} - -func (r *Records) cleanupARecords(now time.Time) { - for name, aRecords := range r.ARecords { - i := 0 - for _, aRecord := range aRecords { - if now.After(aRecord.Deadline) { - continue - } - aRecords[i] = aRecord - i++ - } - aRecords = aRecords[:i] - if i == 0 { - delete(r.ARecords, name) - } - } -} - -func (r *Records) cleanupCNameRecords(now time.Time) { - for name, record := range r.CNameRecords { - if now.After(record.Deadline) { - delete(r.CNameRecords, name) - } - } -} - -func (r *Records) getAliasedDomain(now time.Time, domainName string) string { - processedDomains := make(map[string]struct{}) - for { - if _, processed := processedDomains[domainName]; processed { - // Loop detected! - return "" - } else { - processedDomains[domainName] = struct{}{} - } - - cname, ok := r.CNameRecords[domainName] - if !ok { - break - } - if now.After(cname.Deadline) { - delete(r.CNameRecords, domainName) - break - } - domainName = cname.Alias - } - return domainName -} - -func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord { - aRecords, ok := r.ARecords[domainName] - if !ok { - return nil - } - - i := 0 - for _, aRecord := range aRecords { - if now.After(aRecord.Deadline) { - continue - } - aRecords[i] = aRecord - i++ - } - aRecords = aRecords[:i] - if i == 0 { - delete(r.ARecords, domainName) - return nil - } - - return aRecords -} - -func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string { - processedDomains := make(map[string]struct{}) - cNameList := make([]string, 0) - if fromEnd { - domainName = r.getAliasedDomain(now, domainName) - cNameList = append(cNameList, domainName) - } - r.cleanupCNameRecords(now) - for { - if _, processed := processedDomains[domainName]; processed { - // Loop detected! - return nil - } else { - processedDomains[domainName] = struct{}{} - } - - found := false - for aliasFrom, aliasTo := range r.CNameRecords { - if aliasTo.Alias == domainName { - cNameList = append(cNameList, aliasFrom) - domainName = aliasFrom - found = true - break - } - } - if !found { - break - } - } - return cNameList -} - -func (r *Records) Cleanup() { - r.mutex.Lock() - defer r.mutex.Unlock() - now := time.Now() - r.cleanupARecords(now) - r.cleanupCNameRecords(now) -} - -func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string { - r.mutex.RLock() - defer r.mutex.RUnlock() - now := time.Now() - - return r.getActualCNames(now, domainName, fromEnd) -} - -func (r *Records) GetARecords(domainName string) []*ARecord { - r.mutex.Lock() - defer r.mutex.Unlock() - now := time.Now() - - return r.getActualARecords(now, r.getAliasedDomain(now, domainName)) -} - -func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) { - if domainName == cName { - // Can't assing to yourself - return - } - - r.mutex.Lock() - defer r.mutex.Unlock() - now := time.Now() - - delete(r.ARecords, domainName) - r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl)) -} - -func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) { - r.mutex.Lock() - defer r.mutex.Unlock() - now := time.Now() - - delete(r.CNameRecords, domainName) - if _, ok := r.ARecords[domainName]; !ok { - r.ARecords[domainName] = make([]*ARecord, 0) - } - for _, aRecord := range r.ARecords[domainName] { - if bytes.Compare(aRecord.Address, addr) == 0 { - aRecord.Deadline = now.Add(ttl) - return - } - } - r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl))) -} - -func (r *Records) ListKnownDomains() []string { - r.mutex.Lock() - defer r.mutex.Unlock() - now := time.Now() - r.cleanupARecords(now) - r.cleanupCNameRecords(now) - - domains := map[string]struct{}{} - for name, _ := range r.ARecords { - domains[name] = struct{}{} - } - for name, _ := range r.CNameRecords { - domains[name] = struct{}{} - } - - domainsList := make([]string, len(domains)) - i := 0 - for name, _ := range domains { - domainsList[i] = name - i++ - } - return domainsList -} - -func NewRecords() *Records { - return &Records{ - ARecords: make(map[string][]*ARecord), - CNameRecords: make(map[string]*CNameRecord), - } -} diff --git a/records/records.go b/records/records.go new file mode 100644 index 0000000..7c6ad97 --- /dev/null +++ b/records/records.go @@ -0,0 +1,167 @@ +package records + +import ( + "bytes" + "net" + "sync" + "time" +) + +type ARecord struct { + Address net.IP + Deadline time.Time +} + +type CNameRecord struct { + Alias string + Deadline time.Time +} + +type Records struct { + mux sync.RWMutex + records map[string]interface{} +} + +func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) { + if domainName == alias { + return + } + + r.mux.Lock() + r.records[domainName] = &CNameRecord{ + Alias: alias, + Deadline: time.Now().Add(ttl), + } + r.mux.Unlock() +} + +func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) { + r.mux.Lock() + defer r.mux.Unlock() + + deadline := time.Now().Add(ttl) + + aRecords, _ := r.records[domainName].([]*ARecord) + for _, aRecord := range aRecords { + if bytes.Compare(aRecord.Address, addr) != 0 { + continue + } + aRecord.Deadline = deadline + return + } + + r.records[domainName] = append(aRecords, &ARecord{ + Address: addr, + Deadline: deadline, + }) +} + +func (r *Records) GetAliases(domainName string) []string { + r.mux.Lock() + defer r.mux.Unlock() + r.cleanupRecords() + + domains := make(map[string]struct{}) + domains[domainName] = struct{}{} + + for { + var addedNew bool + for name, aRecord := range r.records { + if _, ok := domains[name]; ok { + continue + } + cname, ok := aRecord.(*CNameRecord) + if !ok { + continue + } + if _, ok = domains[cname.Alias]; !ok { + continue + } + + domains[name] = struct{}{} + addedNew = true + } + if !addedNew { + break + } + } + + domainList := make([]string, len(domains)) + idx := 0 + for name, _ := range domains { + domainList[idx] = name + idx++ + } + + return domainList +} + +func (r *Records) GetARecords(domainName string) []*ARecord { + r.mux.Lock() + defer r.mux.Unlock() + r.cleanupRecords() + + loopDetect := make(map[string]struct{}) + loopDetect[domainName] = struct{}{} + for { + switch v := r.records[domainName].(type) { + case *CNameRecord: + if _, ok := loopDetect[v.Alias]; ok { + return nil + } + domainName = v.Alias + loopDetect[v.Alias] = struct{}{} + case []*ARecord: + return v + default: + return nil + } + } +} + +func (r *Records) ListKnownDomains() []string { + r.mux.Lock() + defer r.mux.Unlock() + r.cleanupRecords() + + domainsList := make([]string, len(r.records)) + i := 0 + for name, _ := range r.records { + domainsList[i] = name + i++ + } + return domainsList +} + +func (r *Records) cleanupRecords() { + now := time.Now() + for name, records := range r.records { + switch v := records.(type) { + case []*ARecord: + idx := 0 + for _, aRecord := range v { + if now.After(aRecord.Deadline) { + continue + } + v[idx] = aRecord + idx++ + } + if idx == 0 { + delete(r.records, name) + break + } + r.records[name] = v[:idx] + case *CNameRecord: + if !now.After(v.Deadline) { + continue + } + delete(r.records, name) + } + } +} + +func New() *Records { + return &Records{ + records: make(map[string]interface{}), + } +} diff --git a/records/records_test.go b/records/records_test.go new file mode 100644 index 0000000..90d782e --- /dev/null +++ b/records/records_test.go @@ -0,0 +1,109 @@ +package records + +import ( + "bytes" + "slices" + "testing" + "time" +) + +func TestLoop(t *testing.T) { + r := New() + r.AddCNameRecord("1", "2", time.Minute) + r.AddCNameRecord("2", "1", time.Minute) + if r.GetARecords("1") != nil { + t.Fatal("loop detected") + } + if r.GetARecords("2") != nil { + t.Fatal("loop detected") + } +} + +func TestCName(t *testing.T) { + r := New() + r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute) + r.AddCNameRecord("gateway.example.com", "example.com", time.Minute) + records := r.GetARecords("gateway.example.com") + if records == nil { + t.Fatal("no records") + } + if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 { + t.Fatal("cname mismatch") + } +} + +func TestA(t *testing.T) { + r := New() + r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute) + records := r.GetARecords("example.com") + if records == nil { + t.Fatal("no records") + } + if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 { + t.Fatal("cname mismatch") + } +} + +func TestDeprecated(t *testing.T) { + r := New() + r.AddARecord("example.com", []byte{1, 2, 3, 4}, -time.Minute) + records := r.GetARecords("example.com") + if records != nil { + t.Fatal("deprecated records") + } +} + +func TestNotExistedA(t *testing.T) { + r := New() + records := r.GetARecords("example.com") + if records != nil { + t.Fatal("not existed records") + } +} + +func TestNotExistedCNameAlias(t *testing.T) { + r := New() + r.AddCNameRecord("gateway.example.com", "example.com", time.Minute) + records := r.GetARecords("gateway.example.com") + if records != nil { + t.Fatal("not existed records") + } +} + +func TestReplacing(t *testing.T) { + r := New() + r.AddCNameRecord("gateway.example.com", "example.com", time.Minute) + r.AddARecord("gateway.example.com", []byte{1, 2, 3, 4}, time.Minute) + records := r.GetARecords("gateway.example.com") + if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 { + t.Fatal("mismatch") + } +} + +func TestAliases(t *testing.T) { + r := New() + r.AddARecord("1", []byte{1, 2, 3, 4}, time.Minute) + r.AddCNameRecord("2", "1", time.Minute) + r.AddCNameRecord("3", "2", time.Minute) + r.AddCNameRecord("4", "2", time.Minute) + r.AddCNameRecord("5", "1", time.Minute) + aliases := r.GetAliases("1") + if aliases == nil { + t.Fatal("no aliases") + } + if !slices.Contains(aliases, "1") { + t.Fatal("no 1") + } + if !slices.Contains(aliases, "2") { + t.Fatal("no 2") + } + if !slices.Contains(aliases, "3") { + t.Fatal("no 3") + } + if !slices.Contains(aliases, "4") { + t.Fatal("no 4") + } + if !slices.Contains(aliases, "5") { + t.Fatal("no 5") + } +}