From e780b58df1fd2c564eee349792da1294b068b8d2 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Tue, 27 Aug 2024 01:44:17 +0300 Subject: [PATCH] reversive records listing --- kvas2.go | 52 ++++++++++++++++++++++------- records.go | 97 ++++++++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 120 insertions(+), 29 deletions(-) diff --git a/kvas2.go b/kvas2.go index 08479a3..62286d1 100644 --- a/kvas2.go +++ b/kvas2.go @@ -76,11 +76,48 @@ func (a *App) Listen(ctx context.Context) []error { return errs } +func (a *App) processARecord(aRecord dnsProxy.Address) { + ttlDuration := time.Duration(aRecord.TTL) * time.Second + if ttlDuration < a.Config.MinimalTTL { + ttlDuration = a.Config.MinimalTTL + } + + a.Records.PutARecord(aRecord.Name.String(), aRecord.Address, ttlDuration) + + cNames := append([]string{aRecord.Name.String()}, a.Records.GetCNameRecords(aRecord.Name.String(), true, true)...) + fmt.Printf("Relates CNames:\n") + for idx, cName := range cNames { + fmt.Printf("|- #%d: %s\n", idx, cName) + } + + for _, group := range a.Groups { + for _, domain := range group.Domains { + if !domain.IsEnabled() { + continue + } + for _, cName := range cNames { + if domain.IsMatch(cName) { + fmt.Printf("|- Matched %s (%s) for %s in %s group!\n", cName, aRecord.Name, domain.Domain, group.Name) + } + } + } + } +} + +func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { + ttlDuration := time.Duration(cNameRecord.TTL) * time.Second + if ttlDuration < a.Config.MinimalTTL { + ttlDuration = a.Config.MinimalTTL + } + + a.Records.PutCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration) +} + func (a *App) handleRecord(msg *dnsProxy.Message) { printKnownRecords := func() { for _, q := range msg.QD { fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) - for idx, addr := range a.Records.GetARecords(q.QName.String(), true) { + for idx, addr := range a.Records.GetARecords(q.QName.String(), true, false) { fmt.Printf("%04x: #%d: %s\n", msg.ID, idx, addr.String()) } } @@ -89,18 +126,10 @@ func (a *App) handleRecord(msg *dnsProxy.Message) { switch v := rr.(type) { case dnsProxy.Address: fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) - ttlDuration := time.Duration(v.TTL) * time.Second - if ttlDuration < a.Config.MinimalTTL { - ttlDuration = a.Config.MinimalTTL - } - a.Records.PutARecord(v.Name.String(), v.Address, ttlDuration) + a.processARecord(v) case dnsProxy.CName: fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) - ttlDuration := time.Duration(v.TTL) * time.Second - if ttlDuration < a.Config.MinimalTTL { - ttlDuration = a.Config.MinimalTTL - } - a.Records.PutCNameRecord(v.Name.String(), v.CName.String(), ttlDuration) + a.processCNameRecord(v) default: fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) } @@ -120,6 +149,7 @@ func (a *App) handleRecord(msg *dnsProxy.Message) { parseResponseRecord(a) } printKnownRecords() + fmt.Println() } func New(config Config) (*App, error) { diff --git a/records.go b/records.go index 0328af9..8297b0c 100644 --- a/records.go +++ b/records.go @@ -3,7 +3,6 @@ package main import ( "bytes" "net" - "slices" "sync" "time" ) @@ -71,7 +70,7 @@ type Records struct { Records map[string]*Record } -func (r *Records) getCNames(domainName string, recursive bool, excludeDomains ...string) []string { +func (r *Records) getCNames(domainName string, recursive bool, reversive bool) []string { record, ok := r.Records[domainName] if !ok { return nil @@ -81,43 +80,105 @@ func (r *Records) getCNames(domainName string, recursive bool, excludeDomains .. return nil } - cNameList := make([]string, len(record.CNameRecords)) - for idx, cnameRecord := range record.CNameRecords { - cNameList[idx] = cnameRecord.CName + excludedFromCNameList := map[string]struct{}{ + domainName: {}, + } + + cNameList := make([]string, 0) + for _, cnameRecord := range record.CNameRecords { + if _, exists := excludedFromCNameList[cnameRecord.CName]; !exists { + cNameList = append(cNameList, cnameRecord.CName) + excludedFromCNameList[cnameRecord.CName] = struct{}{} + } } if recursive { - origCNameLen := len(cNameList) - for i := 0; i < origCNameLen; i++ { - if slices.Contains(excludeDomains, cNameList[i]) { - continue - } + excludedFromProcess := map[string]struct{}{ + domainName: {}, + } - excludeDomains = append(excludeDomains, cNameList...) - parentList := r.getCNames(cNameList[i], true, excludeDomains...) - if parentList != nil { - cNameList = append(cNameList, parentList...) + processingList := cNameList + for len(processingList) > 0 { + newProcessingList := []string{} + for _, cname := range processingList { + if _, exists := excludedFromProcess[cname]; exists { + continue + } + + record, ok := r.Records[cname] + if !ok { + continue + } + if record.Cleanup() { + delete(r.Records, cname) + continue + } + + for _, cNameRecord := range record.CNameRecords { + if _, exists := excludedFromCNameList[cNameRecord.CName]; !exists { + cNameList = append(cNameList, cNameRecord.CName) + excludedFromCNameList[cNameRecord.CName] = struct{}{} + } + newProcessingList = append(newProcessingList, cNameRecord.CName) + } } + processingList = newProcessingList + } + } + + if reversive { + excludedFromProcess := make(map[string]struct{}) + processingList := []string{domainName} + for len(processingList) > 0 { + nextProcessingList := make([]string, 0) + for _, target := range processingList { + if _, exists := excludedFromProcess[target]; exists { + continue + } + + for cname, record := range r.Records { + if record.Cleanup() { + delete(r.Records, cname) + continue + } + + for _, cnameRecord := range record.CNameRecords { + if cnameRecord.CName != target { + continue + } + + if _, exists := excludedFromCNameList[record.Name]; !exists { + cNameList = append(cNameList, record.Name) + excludedFromCNameList[record.Name] = struct{}{} + } + nextProcessingList = append(nextProcessingList, record.Name) + break + } + } + + excludedFromProcess[target] = struct{}{} + } + processingList = nextProcessingList } } return cNameList } -func (r *Records) GetCNameRecords(domainName string, recursive bool) []string { +func (r *Records) GetCNameRecords(domainName string, recursive bool, reversive bool) []string { r.mutex.RLock() defer r.mutex.RUnlock() - return r.getCNames(domainName, recursive) + return r.getCNames(domainName, recursive, reversive) } -func (r *Records) GetARecords(domainName string, recursive bool) []net.IP { +func (r *Records) GetARecords(domainName string, recursive bool, reversive bool) []net.IP { r.mutex.RLock() defer r.mutex.RUnlock() cNameList := []string{domainName} if recursive { - cNameList = append(cNameList, r.getCNames(domainName, true)...) + cNameList = append(cNameList, r.getCNames(domainName, true, reversive)...) } aRecords := make([]net.IP, 0)