From 3a178def2940f9020183b641479ecd649b0dbabc Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Sat, 14 Sep 2024 16:13:59 +0300 Subject: [PATCH] refactoring records --- group.go | 26 +--- kvas2.go | 62 ++++++++- records.go | 383 +++++++++++++++++++++++------------------------------ 3 files changed, 228 insertions(+), 243 deletions(-) diff --git a/group.go b/group.go index 982f3e7..232723c 100644 --- a/group.go +++ b/group.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "net" "time" @@ -21,16 +20,12 @@ type Group struct { ifaceToIPSet *netfilterHelper.IfaceToIPSet } -func (g *Group) HandleIPv4(names []string, address net.IP, ttl time.Duration) error { - if !g.Enabled { - return nil - } - +func (g *Group) HandleIPv4(relatedDomains []string, address net.IP, ttl time.Duration) error { for _, domain := range g.Domains { if !domain.IsEnabled() { continue } - for _, name := range names { + for _, name := range relatedDomains { if domain.IsMatch(name) { ttlSeconds := uint32(ttl.Seconds()) return g.ipset.Add(address, &ttlSeconds) @@ -91,20 +86,3 @@ func (g *Group) Disable() []error { return errs } - -func (a *App) AddGroup(group *models.Group) error { - if _, exists := a.Groups[group.ID]; exists { - return ErrGroupIDConflict - } - - ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID) - - a.Groups[group.ID] = &Group{ - Group: group, - iptables: a.NetfilterHelper.IPTables, - ipset: a.NetfilterHelper.IPSet(ipsetName), - ifaceToIPSet: a.NetfilterHelper.IfaceToIPSet(fmt.Sprintf("%sROUTING_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false), - } - - return nil -} diff --git a/kvas2.go b/kvas2.go index 4897869..aa45414 100644 --- a/kvas2.go +++ b/kvas2.go @@ -212,6 +212,61 @@ Loop: return errs } +func (a *App) AddGroup(group *models.Group) error { + if _, exists := a.Groups[group.ID]; exists { + return ErrGroupIDConflict + } + + ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID) + + grp := &Group{ + Group: group, + iptables: a.NetfilterHelper4.IPTables, + ipset: a.NetfilterHelper4.IPSet(ipsetName), + ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false), + } + a.Groups[group.ID] = grp + + domains := a.Records.ListKnownDomains() + processedDomains := make(map[string]struct{}) + for _, domainName := range domains { + if _, exists := processedDomains[domainName]; exists { + continue + } + + for _, domain := range group.Domains { + if !domain.IsMatch(domainName) { + continue + } + + cnames := a.Records.GetCNameRecords(domainName, true) + for _, cname := range cnames { + processedDomains[cname] = struct{}{} + } + + if len(cnames) == 0 { + break + } + + addresses := a.Records.GetARecords(domainName) + for _, address := range addresses { + err := grp.HandleIPv4(cnames, address.Address, time.Now().Sub(address.Deadline)) + if err != nil { + log.Error(). + Str("name", domainName). + Str("address", address.Address.String()). + Int("group", group.ID). + Err(err). + Msg("failed to handle address") + } + } + break + } + } + + return nil +} + func (a *App) ListInterfaces() ([]net.Interface, error) { interfaceNames := make([]net.Interface, 0) @@ -243,9 +298,10 @@ func (a *App) processARecord(aRecord dnsProxy.Address) { ttlDuration = a.Config.MinimalTTL } - a.Records.PutARecord(aRecord.Name.String(), aRecord.Address, ttlDuration) + a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration) - names := append([]string{aRecord.Name.String()}, a.Records.GetCNameRecords(aRecord.Name.String(), true, true)...) + // TODO: Optimize + names := a.Records.GetCNameRecords(aRecord.Name.String(), true) for _, group := range a.Groups { err := group.HandleIPv4(names, aRecord.Address, ttlDuration) if err != nil { @@ -265,7 +321,7 @@ func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { ttlDuration = a.Config.MinimalTTL } - a.Records.PutCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration) + a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration) } func (a *App) handleRecord(rr dnsProxy.ResourceRecord) { diff --git a/records.go b/records.go index 31c4cc6..9cb52ee 100644 --- a/records.go +++ b/records.go @@ -20,258 +20,209 @@ func NewARecord(addr net.IP, deadline time.Time) *ARecord { } type CNameRecord struct { - CName string + Alias string Deadline time.Time } func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord { return &CNameRecord{ - CName: domainName, + Alias: domainName, Deadline: deadline, } } -type Record struct { - Name string - ARecords []*ARecord - CNameRecords []*CNameRecord -} - -func (r *Record) Cleanup() bool { - i := 0 - for _, record := range r.ARecords { - if time.Now().Before(record.Deadline) { - r.ARecords[i] = record - i++ - } - } - r.ARecords = r.ARecords[:i] - - i = 0 - for _, record := range r.CNameRecords { - if time.Now().Before(record.Deadline) { - r.CNameRecords[i] = record - i++ - } - } - r.CNameRecords = r.CNameRecords[:i] - - return len(r.ARecords) == 0 && len(r.CNameRecords) == 0 -} - -func NewRecord(domainName string) *Record { - return &Record{ - Name: domainName, - ARecords: make([]*ARecord, 0), - CNameRecords: make([]*CNameRecord, 0), - } -} - type Records struct { - mutex sync.RWMutex - Records map[string]*Record + mutex sync.RWMutex + ARecords map[string][]*ARecord + CNameRecords map[string]*CNameRecord } -func (r *Records) getCNames(domainName string, recursive bool, reversive bool) []string { - record, ok := r.Records[domainName] +func (r *Records) cleanupARecords(now time.Time) { + for name, aRecords := range r.ARecords { + i := 0 + for _, aRecord := range aRecords { + if aRecord.Deadline.After(now) { + continue + } + aRecords[i] = aRecord + i++ + } + aRecords = aRecords[:i] + if i == 0 { + delete(r.ARecords, name) + } + } +} + +func (r *Records) cleanupCNameRecords(now time.Time) { + for name, record := range r.CNameRecords { + if record.Deadline.After(now) { + delete(r.CNameRecords, name) + } + } +} + +func (r *Records) getAliasedDomain(now time.Time, domainName string) string { + processedDomains := make(map[string]struct{}) + for { + if _, processed := processedDomains[domainName]; processed { + // Loop detected! + return "" + } else { + processedDomains[domainName] = struct{}{} + } + + cname, ok := r.CNameRecords[domainName] + if !ok { + break + } + if cname.Deadline.After(now) { + delete(r.CNameRecords, domainName) + break + } + domainName = cname.Alias + } + return domainName +} + +func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord { + aRecords, ok := r.ARecords[domainName] if !ok { return nil } - if record.Cleanup() { - delete(r.Records, domainName) + + i := 0 + for _, aRecord := range aRecords { + if aRecord.Deadline.After(now) { + continue + } + aRecords[i] = aRecord + i++ + } + aRecords = aRecords[:i] + if i == 0 { + delete(r.ARecords, domainName) return nil } - excludedFromCNameList := map[string]struct{}{ - domainName: {}, - } - - cNameList := make([]string, 0) - for _, cnameRecord := range record.CNameRecords { - if _, exists := excludedFromCNameList[cnameRecord.CName]; !exists { - cNameList = append(cNameList, cnameRecord.CName) - excludedFromCNameList[cnameRecord.CName] = struct{}{} - } - } - - if recursive { - excludedFromProcess := map[string]struct{}{ - domainName: {}, - } - - processingList := cNameList - for len(processingList) > 0 { - newProcessingList := []string{} - for _, cname := range processingList { - if _, exists := excludedFromProcess[cname]; exists { - continue - } - - record, ok := r.Records[cname] - if !ok { - continue - } - if record.Cleanup() { - delete(r.Records, cname) - continue - } - - for _, cNameRecord := range record.CNameRecords { - if _, exists := excludedFromCNameList[cNameRecord.CName]; !exists { - cNameList = append(cNameList, cNameRecord.CName) - excludedFromCNameList[cNameRecord.CName] = struct{}{} - } - newProcessingList = append(newProcessingList, cNameRecord.CName) - } - } - processingList = newProcessingList - } - } - - if reversive { - excludedFromProcess := make(map[string]struct{}) - processingList := []string{domainName} - for len(processingList) > 0 { - nextProcessingList := make([]string, 0) - for _, target := range processingList { - if _, exists := excludedFromProcess[target]; exists { - continue - } - - for cname, record := range r.Records { - if record.Cleanup() { - delete(r.Records, cname) - continue - } - - for _, cnameRecord := range record.CNameRecords { - if cnameRecord.CName != target { - continue - } - - if _, exists := excludedFromCNameList[record.Name]; !exists { - cNameList = append(cNameList, record.Name) - excludedFromCNameList[record.Name] = struct{}{} - } - nextProcessingList = append(nextProcessingList, record.Name) - break - } - } - - excludedFromProcess[target] = struct{}{} - } - processingList = nextProcessingList - } - } - - return cNameList -} - -func (r *Records) GetCNameRecords(domainName string, recursive bool, reversive bool) []string { - r.mutex.RLock() - defer r.mutex.RUnlock() - - return r.getCNames(domainName, recursive, reversive) -} - -func (r *Records) GetARecords(domainName string, recursive bool, reversive bool) []net.IP { - r.mutex.RLock() - defer r.mutex.RUnlock() - - cNameList := []string{domainName} - if recursive { - cNameList = append(cNameList, r.getCNames(domainName, true, reversive)...) - } - - aRecords := make([]net.IP, 0) - for _, cName := range cNameList { - record, ok := r.Records[cName] - if !ok { - continue - } - if record.Cleanup() { - delete(r.Records, cName) - continue - } - - for _, aRecord := range record.ARecords { - aRecords = append(aRecords, aRecord.Address) - } - } - return aRecords } -func (r *Records) PutCNameRecord(domainName string, cName string, ttl time.Duration) { - r.mutex.Lock() - defer r.mutex.Unlock() - - record, ok := r.Records[domainName] - if !ok { - record = NewRecord(domainName) - r.Records[domainName] = record +func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string { + processedDomains := make(map[string]struct{}) + cNameList := make([]string, 0) + if fromEnd { + domainName = r.getAliasedDomain(now, domainName) + cNameList = append(cNameList, domainName) } - record.Cleanup() + r.cleanupCNameRecords(now) + for { + if _, processed := processedDomains[domainName]; processed { + // Loop detected! + return nil + } else { + processedDomains[domainName] = struct{}{} + } - for _, cNameRecord := range record.CNameRecords { - if cNameRecord.CName == cName { - cNameRecord.Deadline = time.Now().Add(ttl) - return + found := false + for aliasFrom, aliasTo := range r.CNameRecords { + if aliasTo.Alias == domainName { + cNameList = append(cNameList, aliasFrom) + domainName = aliasFrom + found = true + break + } + } + if !found { + break } } - - record.CNameRecords = append(record.CNameRecords, NewCNameRecord(cName, time.Now().Add(ttl))) -} - -func (r *Records) PutARecord(domainName string, addr net.IP, ttl time.Duration) { - r.mutex.Lock() - defer r.mutex.Unlock() - - record, ok := r.Records[domainName] - if !ok { - record = NewRecord(domainName) - r.Records[domainName] = record - } - record.Cleanup() - - for _, aRecord := range record.ARecords { - if bytes.Compare(aRecord.Address, addr) == 0 { - aRecord.Deadline = time.Now().Add(ttl) - return - } - } - record.ARecords = append(record.ARecords, NewARecord(addr, time.Now().Add(ttl))) -} - -func (r *Records) ListKnownDomains() []string { - r.mutex.Lock() - defer r.mutex.Unlock() - - domains := make([]string, 0) - for name, record := range r.Records { - if record.Cleanup() { - delete(r.Records, name) - continue - } - domains = append(domains, name) - } - - return domains + return cNameList } func (r *Records) Cleanup() { r.mutex.Lock() defer r.mutex.Unlock() + now := time.Now() + r.cleanupARecords(now) + r.cleanupCNameRecords(now) +} - for domainName, record := range r.Records { - if record.Cleanup() { - delete(r.Records, domainName) +func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string { + r.mutex.RLock() + defer r.mutex.RUnlock() + now := time.Now() + + return r.getActualCNames(now, domainName, fromEnd) +} + +func (r *Records) GetARecords(domainName string) []*ARecord { + r.mutex.Lock() + defer r.mutex.Unlock() + now := time.Now() + + return r.getActualARecords(now, r.getAliasedDomain(now, domainName)) +} + +func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) { + if domainName == cName { + // Can't assing to yourself + return + } + + r.mutex.Lock() + defer r.mutex.Unlock() + now := time.Now() + + delete(r.ARecords, domainName) + r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl)) +} + +func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) { + r.mutex.Lock() + defer r.mutex.Unlock() + now := time.Now() + + delete(r.CNameRecords, domainName) + if _, ok := r.ARecords[domainName]; !ok { + r.ARecords[domainName] = make([]*ARecord, 0) + } + for _, aRecord := range r.ARecords[domainName] { + if bytes.Compare(aRecord.Address, addr) == 0 { + aRecord.Deadline = now.Add(ttl) + return } } + r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl))) +} + +func (r *Records) ListKnownDomains() []string { + r.mutex.Lock() + defer r.mutex.Unlock() + now := time.Now() + r.cleanupARecords(now) + r.cleanupCNameRecords(now) + + domains := map[string]struct{}{} + for name, _ := range r.ARecords { + domains[name] = struct{}{} + } + for name, _ := range r.CNameRecords { + domains[name] = struct{}{} + } + + domainsList := make([]string, len(domains)) + i := 0 + for name, _ := range domains { + domainsList[i] = name + i++ + } + return domainsList } func NewRecords() *Records { return &Records{ - Records: make(map[string]*Record), + ARecords: make(map[string][]*ARecord), + CNameRecords: make(map[string]*CNameRecord), } }