From 14e0054c1bdfbfa7723b214ccc24326e2631c5d0 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Sun, 25 Aug 2024 00:18:15 +0300 Subject: [PATCH] refactoring records and ttl cleanup --- main.go | 12 +++-- rule-composer/records.go | 113 ++++++++++++++++----------------------- 2 files changed, 55 insertions(+), 70 deletions(-) diff --git a/main.go b/main.go index 65b02b8..42b562c 100644 --- a/main.go +++ b/main.go @@ -22,15 +22,19 @@ func main() { proxy.MsgHandler = func(msg *dnsProxy.Message) { for _, q := range msg.QD { fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String()) + fmt.Printf("%x: DBG (Before) Known addresses for: %s\n", msg.ID, q.QName.String()) + for idx, addr := range records.GetARecords(q.QName.String(), true) { + fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String()) + } } for _, a := range msg.AN { switch v := a.(type) { case dnsProxy.Address: fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) - records.PutIPv4Address(v.Name.String(), v.Address, int64(v.TTL)) + records.PutARecord(v.Name.String(), v.Address, int64(v.TTL)) case dnsProxy.CName: fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) - records.PutCName(v.Name.String(), v.CName.String(), int64(v.TTL)) + records.PutCNameRecord(v.Name.String(), v.CName.String(), int64(v.TTL)) default: fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) } @@ -43,8 +47,8 @@ func main() { } for _, q := range msg.QD { - fmt.Printf("%x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) - for idx, addr := range records.GetIPv4Addresses(q.QName.String()) { + fmt.Printf("%x: DBG (After) Known addresses for: %s\n", msg.ID, q.QName.String()) + for idx, addr := range records.GetARecords(q.QName.String(), true) { fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String()) } } diff --git a/rule-composer/records.go b/rule-composer/records.go index 86ba1c2..152c6ac 100644 --- a/rule-composer/records.go +++ b/rule-composer/records.go @@ -1,114 +1,95 @@ package ruleComposer import ( - "bytes" "net" "sync" "time" ) type Records struct { - mutex sync.RWMutex - ipv4Addresses map[string]map[string]time.Time - cNames map[string]map[string]time.Time + mutex sync.RWMutex + aRecords map[string]map[string]time.Time + cnameRecords map[string]map[string]time.Time } -func (r *Records) getCNames(domainName string) []string { - _, ok := r.cNames[domainName] - if !ok { - return nil - } - - cNameList := make([]string, 0, len(r.cNames[domainName])) - for cname, ttl := range r.cNames[domainName] { - if time.Now().Sub(ttl).Nanoseconds() < 0 { - cNameList = append(cNameList, cname) +func (r *Records) getCNames(domainName string, recursive bool) []string { + cNameList := make([]string, 0) + for cname, ttl := range r.cnameRecords[domainName] { + if time.Now().Sub(ttl).Nanoseconds() > 0 { + delete(r.cnameRecords[domainName], cname) + continue } + cNameList = append(cNameList, cname) } - origCNameLen := len(cNameList) - for i := 0; i < origCNameLen; i++ { - parentList := r.getCNames(cNameList[i]) - if parentList != nil { - cNameList = append(cNameList, parentList...) + if recursive { + origCNameLen := len(cNameList) + for i := 0; i < origCNameLen; i++ { + parentList := r.getCNames(cNameList[i], true) + if parentList != nil { + cNameList = append(cNameList, parentList...) + } } } return cNameList } -func (r *Records) GetIPv4Addresses(domainName string) []net.IP { +func (r *Records) GetCNameRecords(domainName string, recursive bool) []string { r.mutex.RLock() defer r.mutex.RUnlock() - cNameList := append([]string{domainName}, r.getCNames(domainName)...) - ipAddresses := make([]net.IP, 0) + return r.getCNames(domainName, recursive) +} + +func (r *Records) GetARecords(domainName string, recursive bool) []net.IP { + r.mutex.RLock() + defer r.mutex.RUnlock() + + cNameList := []string{domainName} + if recursive { + cNameList = append(cNameList, r.getCNames(domainName, true)...) + } + + aRecords := make([]net.IP, 0) for _, cName := range cNameList { - addresses, ok := r.ipv4Addresses[cName] - if !ok { - continue - } - - addressesNetIP := make([]net.IP, 0, len(addresses)) - for addr, ttl := range addresses { - if time.Now().Sub(ttl).Nanoseconds() < 0 { - addressesNetIP = append(addressesNetIP, []byte(addr)) + for addr, ttl := range r.aRecords[cName] { + if time.Now().Sub(ttl).Nanoseconds() > 0 { + delete(r.aRecords[cName], addr) + continue } + aRecords = append(aRecords, []byte(addr)) } - - ipAddresses = append(ipAddresses, addressesNetIP...) } - return ipAddresses + return aRecords } -func (r *Records) PutCName(domainName string, cName string, ttl int64) { +func (r *Records) PutCNameRecord(domainName string, cName string, ttl int64) { r.mutex.Lock() defer r.mutex.Unlock() - if r.cNames[domainName] == nil { - r.cNames[domainName] = make(map[string]time.Time) + if r.cnameRecords[domainName] == nil { + r.cnameRecords[domainName] = make(map[string]time.Time) } - skipPut := false - for name, _ := range r.cNames[domainName] { - if name == cName { - r.cNames[domainName][name] = time.Now().Add(time.Second * time.Duration(ttl)) - skipPut = true - break - } - } - - if !skipPut { - r.cNames[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl)) - } + r.cnameRecords[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl)) } -func (r *Records) PutIPv4Address(domainName string, addr net.IP, ttl int64) { +func (r *Records) PutARecord(domainName string, addr net.IP, ttl int64) { r.mutex.Lock() defer r.mutex.Unlock() - if r.ipv4Addresses[domainName] == nil { - r.ipv4Addresses[domainName] = make(map[string]time.Time) + if r.aRecords[domainName] == nil { + r.aRecords[domainName] = make(map[string]time.Time) } - skipPut := false - for address, _ := range r.ipv4Addresses[domainName] { - if bytes.Compare([]byte(address), addr) == 0 { - r.ipv4Addresses[domainName][address] = time.Now().Add(time.Second * time.Duration(ttl)) - skipPut = true - break - } - } - - if !skipPut { - r.ipv4Addresses[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl)) - } + r.aRecords[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl)) } func NewRecords() *Records { return &Records{ - ipv4Addresses: make(map[string]map[string]time.Time), - cNames: make(map[string]map[string]time.Time), + aRecords: make(map[string]map[string]time.Time), + cnameRecords: make(map[string]map[string]time.Time), } }