diff --git a/records.go b/records.go index 4eee4c1..0328af9 100644 --- a/records.go +++ b/records.go @@ -1,35 +1,100 @@ package main import ( + "bytes" "net" + "slices" "sync" "time" ) -type Records struct { - mutex sync.RWMutex - aRecords map[string]map[string]time.Time - cnameRecords map[string]map[string]time.Time +type ARecord struct { + Address net.IP + Deadline time.Time } -func (r *Records) getCNames(domainName string, recursive bool) []string { - cNameList := make([]string, 0) - for cname, ttl := range r.cnameRecords[domainName] { - if time.Now().Sub(ttl).Nanoseconds() > 0 { - delete(r.cnameRecords[domainName], cname) - if len(r.cnameRecords[domainName]) == 0 { - delete(r.cnameRecords, domainName) - return nil - } - continue +func NewARecord(addr net.IP, deadline time.Time) *ARecord { + return &ARecord{ + Address: addr, + Deadline: deadline, + } +} + +type CNameRecord struct { + CName string + Deadline time.Time +} + +func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord { + return &CNameRecord{ + CName: domainName, + Deadline: deadline, + } +} + +type Record struct { + Name string + ARecords []*ARecord + CNameRecords []*CNameRecord +} + +func (r *Record) Cleanup() bool { + newARecords := make([]*ARecord, 0) + for _, record := range r.ARecords { + if time.Now().Sub(record.Deadline).Nanoseconds() <= 0 { + newARecords = append(newARecords, record) } - cNameList = append(cNameList, cname) + } + r.ARecords = newARecords + + newCNameRecords := make([]*CNameRecord, 0) + for _, record := range r.CNameRecords { + if time.Now().Sub(record.Deadline).Nanoseconds() <= 0 { + newCNameRecords = append(newCNameRecords, record) + } + } + r.CNameRecords = newCNameRecords + + return len(newARecords) == 0 && len(newCNameRecords) == 0 +} + +func NewRecord(domainName string) *Record { + return &Record{ + Name: domainName, + ARecords: make([]*ARecord, 0), + CNameRecords: make([]*CNameRecord, 0), + } +} + +type Records struct { + mutex sync.RWMutex + Records map[string]*Record +} + +func (r *Records) getCNames(domainName string, recursive bool, excludeDomains ...string) []string { + record, ok := r.Records[domainName] + if !ok { + return nil + } + if record.Cleanup() { + delete(r.Records, domainName) + return nil + } + + cNameList := make([]string, len(record.CNameRecords)) + for idx, cnameRecord := range record.CNameRecords { + cNameList[idx] = cnameRecord.CName } if recursive { origCNameLen := len(cNameList) for i := 0; i < origCNameLen; i++ { - parentList := r.getCNames(cNameList[i], true) + if slices.Contains(excludeDomains, cNameList[i]) { + continue + } + + excludeDomains = append(excludeDomains, cNameList...) + parentList := r.getCNames(cNameList[i], true, excludeDomains...) if parentList != nil { cNameList = append(cNameList, parentList...) } @@ -57,16 +122,17 @@ func (r *Records) GetARecords(domainName string, recursive bool) []net.IP { aRecords := make([]net.IP, 0) for _, cName := range cNameList { - for addr, ttl := range r.aRecords[cName] { - if time.Now().Sub(ttl).Nanoseconds() > 0 { - delete(r.aRecords[cName], addr) - if len(r.aRecords[cName]) == 0 { - delete(r.aRecords, cName) - break - } - continue - } - aRecords = append(aRecords, []byte(addr)) + record, ok := r.Records[cName] + if !ok { + continue + } + if record.Cleanup() { + delete(r.Records, cName) + continue + } + + for _, aRecord := range record.ARecords { + aRecords = append(aRecords, aRecord.Address) } } @@ -77,86 +143,56 @@ func (r *Records) PutCNameRecord(domainName string, cName string, ttl time.Durat r.mutex.Lock() defer r.mutex.Unlock() - if r.cnameRecords[domainName] == nil { - r.cnameRecords[domainName] = make(map[string]time.Time) + record, ok := r.Records[domainName] + if !ok { + record = NewRecord(domainName) + r.Records[domainName] = record + } + record.Cleanup() + + for _, cNameRecord := range record.CNameRecords { + if cNameRecord.CName == cName { + cNameRecord.Deadline = time.Now().Add(ttl) + return + } } - r.cnameRecords[domainName][cName] = time.Now().Add(ttl) + record.CNameRecords = append(record.CNameRecords, NewCNameRecord(cName, time.Now().Add(ttl))) } func (r *Records) PutARecord(domainName string, addr net.IP, ttl time.Duration) { r.mutex.Lock() defer r.mutex.Unlock() - if r.aRecords[domainName] == nil { - r.aRecords[domainName] = make(map[string]time.Time) + record, ok := r.Records[domainName] + if !ok { + record = NewRecord(domainName) + r.Records[domainName] = record } + record.Cleanup() - r.aRecords[domainName][string(addr)] = time.Now().Add(ttl) -} - -func (r *Records) ListKnownCNameRecords() []string { - r.mutex.RLock() - defer r.mutex.RUnlock() - - domains := make([]string, len(r.cnameRecords)) - counter := 0 - for domain, _ := range r.cnameRecords { - domains[counter] = domain - counter++ + for _, aRecord := range record.ARecords { + if bytes.Compare(aRecord.Address, addr) == 0 { + aRecord.Deadline = time.Now().Add(ttl) + return + } } - - return domains -} - -func (r *Records) ListKnownARecords() []string { - r.mutex.RLock() - defer r.mutex.RUnlock() - - domains := make([]string, len(r.aRecords)) - counter := 0 - for domain, _ := range r.aRecords { - domains[counter] = domain - counter++ - } - - return domains + record.ARecords = append(record.ARecords, NewARecord(addr, time.Now().Add(ttl))) } func (r *Records) Cleanup() { r.mutex.Lock() defer r.mutex.Unlock() - for domainName, _ := range r.aRecords { - for aRecord, ttl := range r.aRecords[domainName] { - if time.Now().Sub(ttl).Nanoseconds() <= 0 { - continue - } - delete(r.aRecords[domainName], aRecord) - if len(r.aRecords[domainName]) == 0 { - delete(r.aRecords, domainName) - break - } - } - } - - for domainName, _ := range r.cnameRecords { - for cname, ttl := range r.cnameRecords[domainName] { - if time.Now().Sub(ttl).Nanoseconds() <= 0 { - continue - } - delete(r.cnameRecords[domainName], cname) - if len(r.cnameRecords[domainName]) == 0 { - delete(r.cnameRecords, domainName) - break - } + for domainName, record := range r.Records { + if record.Cleanup() { + delete(r.Records, domainName) } } } func NewRecords() *Records { return &Records{ - aRecords: make(map[string]map[string]time.Time), - cnameRecords: make(map[string]map[string]time.Time), + Records: make(map[string]*Record), } }