refactoring records and ttl cleanup

This commit is contained in:
Vladimir Avtsenov 2024-08-25 00:18:15 +03:00
parent 7abea45c5c
commit 14e0054c1b
2 changed files with 55 additions and 70 deletions

12
main.go
View File

@ -22,15 +22,19 @@ func main() {
proxy.MsgHandler = func(msg *dnsProxy.Message) { proxy.MsgHandler = func(msg *dnsProxy.Message) {
for _, q := range msg.QD { for _, q := range msg.QD {
fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String()) fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String())
fmt.Printf("%x: DBG (Before) Known addresses for: %s\n", msg.ID, q.QName.String())
for idx, addr := range records.GetARecords(q.QName.String(), true) {
fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String())
}
} }
for _, a := range msg.AN { for _, a := range msg.AN {
switch v := a.(type) { switch v := a.(type) {
case dnsProxy.Address: case dnsProxy.Address:
fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
records.PutIPv4Address(v.Name.String(), v.Address, int64(v.TTL)) records.PutARecord(v.Name.String(), v.Address, int64(v.TTL))
case dnsProxy.CName: case dnsProxy.CName:
fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
records.PutCName(v.Name.String(), v.CName.String(), int64(v.TTL)) records.PutCNameRecord(v.Name.String(), v.CName.String(), int64(v.TTL))
default: default:
fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
} }
@ -43,8 +47,8 @@ func main() {
} }
for _, q := range msg.QD { for _, q := range msg.QD {
fmt.Printf("%x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) fmt.Printf("%x: DBG (After) Known addresses for: %s\n", msg.ID, q.QName.String())
for idx, addr := range records.GetIPv4Addresses(q.QName.String()) { for idx, addr := range records.GetARecords(q.QName.String(), true) {
fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String()) fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String())
} }
} }

View File

@ -1,7 +1,6 @@
package ruleComposer package ruleComposer
import ( import (
"bytes"
"net" "net"
"sync" "sync"
"time" "time"
@ -9,106 +8,88 @@ import (
type Records struct { type Records struct {
mutex sync.RWMutex mutex sync.RWMutex
ipv4Addresses map[string]map[string]time.Time aRecords map[string]map[string]time.Time
cNames map[string]map[string]time.Time cnameRecords map[string]map[string]time.Time
} }
func (r *Records) getCNames(domainName string) []string { func (r *Records) getCNames(domainName string, recursive bool) []string {
_, ok := r.cNames[domainName] cNameList := make([]string, 0)
if !ok { for cname, ttl := range r.cnameRecords[domainName] {
return nil if time.Now().Sub(ttl).Nanoseconds() > 0 {
delete(r.cnameRecords[domainName], cname)
continue
} }
cNameList := make([]string, 0, len(r.cNames[domainName]))
for cname, ttl := range r.cNames[domainName] {
if time.Now().Sub(ttl).Nanoseconds() < 0 {
cNameList = append(cNameList, cname) cNameList = append(cNameList, cname)
} }
}
if recursive {
origCNameLen := len(cNameList) origCNameLen := len(cNameList)
for i := 0; i < origCNameLen; i++ { for i := 0; i < origCNameLen; i++ {
parentList := r.getCNames(cNameList[i]) parentList := r.getCNames(cNameList[i], true)
if parentList != nil { if parentList != nil {
cNameList = append(cNameList, parentList...) cNameList = append(cNameList, parentList...)
} }
} }
}
return cNameList return cNameList
} }
func (r *Records) GetIPv4Addresses(domainName string) []net.IP { func (r *Records) GetCNameRecords(domainName string, recursive bool) []string {
r.mutex.RLock() r.mutex.RLock()
defer r.mutex.RUnlock() defer r.mutex.RUnlock()
cNameList := append([]string{domainName}, r.getCNames(domainName)...) return r.getCNames(domainName, recursive)
ipAddresses := make([]net.IP, 0) }
func (r *Records) GetARecords(domainName string, recursive bool) []net.IP {
r.mutex.RLock()
defer r.mutex.RUnlock()
cNameList := []string{domainName}
if recursive {
cNameList = append(cNameList, r.getCNames(domainName, true)...)
}
aRecords := make([]net.IP, 0)
for _, cName := range cNameList { for _, cName := range cNameList {
addresses, ok := r.ipv4Addresses[cName] for addr, ttl := range r.aRecords[cName] {
if !ok { if time.Now().Sub(ttl).Nanoseconds() > 0 {
delete(r.aRecords[cName], addr)
continue continue
} }
aRecords = append(aRecords, []byte(addr))
addressesNetIP := make([]net.IP, 0, len(addresses))
for addr, ttl := range addresses {
if time.Now().Sub(ttl).Nanoseconds() < 0 {
addressesNetIP = append(addressesNetIP, []byte(addr))
} }
} }
ipAddresses = append(ipAddresses, addressesNetIP...) return aRecords
}
return ipAddresses
} }
func (r *Records) PutCName(domainName string, cName string, ttl int64) { func (r *Records) PutCNameRecord(domainName string, cName string, ttl int64) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.cNames[domainName] == nil { if r.cnameRecords[domainName] == nil {
r.cNames[domainName] = make(map[string]time.Time) r.cnameRecords[domainName] = make(map[string]time.Time)
} }
skipPut := false r.cnameRecords[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl))
for name, _ := range r.cNames[domainName] {
if name == cName {
r.cNames[domainName][name] = time.Now().Add(time.Second * time.Duration(ttl))
skipPut = true
break
}
}
if !skipPut {
r.cNames[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl))
}
} }
func (r *Records) PutIPv4Address(domainName string, addr net.IP, ttl int64) { func (r *Records) PutARecord(domainName string, addr net.IP, ttl int64) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
if r.ipv4Addresses[domainName] == nil { if r.aRecords[domainName] == nil {
r.ipv4Addresses[domainName] = make(map[string]time.Time) r.aRecords[domainName] = make(map[string]time.Time)
} }
skipPut := false r.aRecords[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl))
for address, _ := range r.ipv4Addresses[domainName] {
if bytes.Compare([]byte(address), addr) == 0 {
r.ipv4Addresses[domainName][address] = time.Now().Add(time.Second * time.Duration(ttl))
skipPut = true
break
}
}
if !skipPut {
r.ipv4Addresses[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl))
}
} }
func NewRecords() *Records { func NewRecords() *Records {
return &Records{ return &Records{
ipv4Addresses: make(map[string]map[string]time.Time), aRecords: make(map[string]map[string]time.Time),
cNames: make(map[string]map[string]time.Time), cnameRecords: make(map[string]map[string]time.Time),
} }
} }