refactoring records

This commit is contained in:
Vladimir Avtsenov 2024-09-14 16:13:59 +03:00
parent 9d667e3982
commit 3a178def29
3 changed files with 228 additions and 243 deletions

View File

@ -1,7 +1,6 @@
package main
import (
"fmt"
"net"
"time"
@ -21,16 +20,12 @@ type Group struct {
ifaceToIPSet *netfilterHelper.IfaceToIPSet
}
func (g *Group) HandleIPv4(names []string, address net.IP, ttl time.Duration) error {
if !g.Enabled {
return nil
}
func (g *Group) HandleIPv4(relatedDomains []string, address net.IP, ttl time.Duration) error {
for _, domain := range g.Domains {
if !domain.IsEnabled() {
continue
}
for _, name := range names {
for _, name := range relatedDomains {
if domain.IsMatch(name) {
ttlSeconds := uint32(ttl.Seconds())
return g.ipset.Add(address, &ttlSeconds)
@ -91,20 +86,3 @@ func (g *Group) Disable() []error {
return errs
}
func (a *App) AddGroup(group *models.Group) error {
if _, exists := a.Groups[group.ID]; exists {
return ErrGroupIDConflict
}
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID)
a.Groups[group.ID] = &Group{
Group: group,
iptables: a.NetfilterHelper.IPTables,
ipset: a.NetfilterHelper.IPSet(ipsetName),
ifaceToIPSet: a.NetfilterHelper.IfaceToIPSet(fmt.Sprintf("%sROUTING_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false),
}
return nil
}

View File

@ -212,6 +212,61 @@ Loop:
return errs
}
func (a *App) AddGroup(group *models.Group) error {
if _, exists := a.Groups[group.ID]; exists {
return ErrGroupIDConflict
}
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID)
grp := &Group{
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: a.NetfilterHelper4.IPSet(ipsetName),
ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false),
}
a.Groups[group.ID] = grp
domains := a.Records.ListKnownDomains()
processedDomains := make(map[string]struct{})
for _, domainName := range domains {
if _, exists := processedDomains[domainName]; exists {
continue
}
for _, domain := range group.Domains {
if !domain.IsMatch(domainName) {
continue
}
cnames := a.Records.GetCNameRecords(domainName, true)
for _, cname := range cnames {
processedDomains[cname] = struct{}{}
}
if len(cnames) == 0 {
break
}
addresses := a.Records.GetARecords(domainName)
for _, address := range addresses {
err := grp.HandleIPv4(cnames, address.Address, time.Now().Sub(address.Deadline))
if err != nil {
log.Error().
Str("name", domainName).
Str("address", address.Address.String()).
Int("group", group.ID).
Err(err).
Msg("failed to handle address")
}
}
break
}
}
return nil
}
func (a *App) ListInterfaces() ([]net.Interface, error) {
interfaceNames := make([]net.Interface, 0)
@ -243,9 +298,10 @@ func (a *App) processARecord(aRecord dnsProxy.Address) {
ttlDuration = a.Config.MinimalTTL
}
a.Records.PutARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
names := append([]string{aRecord.Name.String()}, a.Records.GetCNameRecords(aRecord.Name.String(), true, true)...)
// TODO: Optimize
names := a.Records.GetCNameRecords(aRecord.Name.String(), true)
for _, group := range a.Groups {
err := group.HandleIPv4(names, aRecord.Address, ttlDuration)
if err != nil {
@ -265,7 +321,7 @@ func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
ttlDuration = a.Config.MinimalTTL
}
a.Records.PutCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration)
a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration)
}
func (a *App) handleRecord(rr dnsProxy.ResourceRecord) {

View File

@ -20,258 +20,209 @@ func NewARecord(addr net.IP, deadline time.Time) *ARecord {
}
type CNameRecord struct {
CName string
Alias string
Deadline time.Time
}
func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord {
return &CNameRecord{
CName: domainName,
Alias: domainName,
Deadline: deadline,
}
}
type Record struct {
Name string
ARecords []*ARecord
CNameRecords []*CNameRecord
}
func (r *Record) Cleanup() bool {
i := 0
for _, record := range r.ARecords {
if time.Now().Before(record.Deadline) {
r.ARecords[i] = record
i++
}
}
r.ARecords = r.ARecords[:i]
i = 0
for _, record := range r.CNameRecords {
if time.Now().Before(record.Deadline) {
r.CNameRecords[i] = record
i++
}
}
r.CNameRecords = r.CNameRecords[:i]
return len(r.ARecords) == 0 && len(r.CNameRecords) == 0
}
func NewRecord(domainName string) *Record {
return &Record{
Name: domainName,
ARecords: make([]*ARecord, 0),
CNameRecords: make([]*CNameRecord, 0),
}
}
type Records struct {
mutex sync.RWMutex
Records map[string]*Record
mutex sync.RWMutex
ARecords map[string][]*ARecord
CNameRecords map[string]*CNameRecord
}
func (r *Records) getCNames(domainName string, recursive bool, reversive bool) []string {
record, ok := r.Records[domainName]
func (r *Records) cleanupARecords(now time.Time) {
for name, aRecords := range r.ARecords {
i := 0
for _, aRecord := range aRecords {
if aRecord.Deadline.After(now) {
continue
}
aRecords[i] = aRecord
i++
}
aRecords = aRecords[:i]
if i == 0 {
delete(r.ARecords, name)
}
}
}
func (r *Records) cleanupCNameRecords(now time.Time) {
for name, record := range r.CNameRecords {
if record.Deadline.After(now) {
delete(r.CNameRecords, name)
}
}
}
func (r *Records) getAliasedDomain(now time.Time, domainName string) string {
processedDomains := make(map[string]struct{})
for {
if _, processed := processedDomains[domainName]; processed {
// Loop detected!
return ""
} else {
processedDomains[domainName] = struct{}{}
}
cname, ok := r.CNameRecords[domainName]
if !ok {
break
}
if cname.Deadline.After(now) {
delete(r.CNameRecords, domainName)
break
}
domainName = cname.Alias
}
return domainName
}
func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord {
aRecords, ok := r.ARecords[domainName]
if !ok {
return nil
}
if record.Cleanup() {
delete(r.Records, domainName)
i := 0
for _, aRecord := range aRecords {
if aRecord.Deadline.After(now) {
continue
}
aRecords[i] = aRecord
i++
}
aRecords = aRecords[:i]
if i == 0 {
delete(r.ARecords, domainName)
return nil
}
excludedFromCNameList := map[string]struct{}{
domainName: {},
}
cNameList := make([]string, 0)
for _, cnameRecord := range record.CNameRecords {
if _, exists := excludedFromCNameList[cnameRecord.CName]; !exists {
cNameList = append(cNameList, cnameRecord.CName)
excludedFromCNameList[cnameRecord.CName] = struct{}{}
}
}
if recursive {
excludedFromProcess := map[string]struct{}{
domainName: {},
}
processingList := cNameList
for len(processingList) > 0 {
newProcessingList := []string{}
for _, cname := range processingList {
if _, exists := excludedFromProcess[cname]; exists {
continue
}
record, ok := r.Records[cname]
if !ok {
continue
}
if record.Cleanup() {
delete(r.Records, cname)
continue
}
for _, cNameRecord := range record.CNameRecords {
if _, exists := excludedFromCNameList[cNameRecord.CName]; !exists {
cNameList = append(cNameList, cNameRecord.CName)
excludedFromCNameList[cNameRecord.CName] = struct{}{}
}
newProcessingList = append(newProcessingList, cNameRecord.CName)
}
}
processingList = newProcessingList
}
}
if reversive {
excludedFromProcess := make(map[string]struct{})
processingList := []string{domainName}
for len(processingList) > 0 {
nextProcessingList := make([]string, 0)
for _, target := range processingList {
if _, exists := excludedFromProcess[target]; exists {
continue
}
for cname, record := range r.Records {
if record.Cleanup() {
delete(r.Records, cname)
continue
}
for _, cnameRecord := range record.CNameRecords {
if cnameRecord.CName != target {
continue
}
if _, exists := excludedFromCNameList[record.Name]; !exists {
cNameList = append(cNameList, record.Name)
excludedFromCNameList[record.Name] = struct{}{}
}
nextProcessingList = append(nextProcessingList, record.Name)
break
}
}
excludedFromProcess[target] = struct{}{}
}
processingList = nextProcessingList
}
}
return cNameList
}
func (r *Records) GetCNameRecords(domainName string, recursive bool, reversive bool) []string {
r.mutex.RLock()
defer r.mutex.RUnlock()
return r.getCNames(domainName, recursive, reversive)
}
func (r *Records) GetARecords(domainName string, recursive bool, reversive bool) []net.IP {
r.mutex.RLock()
defer r.mutex.RUnlock()
cNameList := []string{domainName}
if recursive {
cNameList = append(cNameList, r.getCNames(domainName, true, reversive)...)
}
aRecords := make([]net.IP, 0)
for _, cName := range cNameList {
record, ok := r.Records[cName]
if !ok {
continue
}
if record.Cleanup() {
delete(r.Records, cName)
continue
}
for _, aRecord := range record.ARecords {
aRecords = append(aRecords, aRecord.Address)
}
}
return aRecords
}
func (r *Records) PutCNameRecord(domainName string, cName string, ttl time.Duration) {
r.mutex.Lock()
defer r.mutex.Unlock()
record, ok := r.Records[domainName]
if !ok {
record = NewRecord(domainName)
r.Records[domainName] = record
func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string {
processedDomains := make(map[string]struct{})
cNameList := make([]string, 0)
if fromEnd {
domainName = r.getAliasedDomain(now, domainName)
cNameList = append(cNameList, domainName)
}
record.Cleanup()
r.cleanupCNameRecords(now)
for {
if _, processed := processedDomains[domainName]; processed {
// Loop detected!
return nil
} else {
processedDomains[domainName] = struct{}{}
}
for _, cNameRecord := range record.CNameRecords {
if cNameRecord.CName == cName {
cNameRecord.Deadline = time.Now().Add(ttl)
return
found := false
for aliasFrom, aliasTo := range r.CNameRecords {
if aliasTo.Alias == domainName {
cNameList = append(cNameList, aliasFrom)
domainName = aliasFrom
found = true
break
}
}
if !found {
break
}
}
record.CNameRecords = append(record.CNameRecords, NewCNameRecord(cName, time.Now().Add(ttl)))
}
func (r *Records) PutARecord(domainName string, addr net.IP, ttl time.Duration) {
r.mutex.Lock()
defer r.mutex.Unlock()
record, ok := r.Records[domainName]
if !ok {
record = NewRecord(domainName)
r.Records[domainName] = record
}
record.Cleanup()
for _, aRecord := range record.ARecords {
if bytes.Compare(aRecord.Address, addr) == 0 {
aRecord.Deadline = time.Now().Add(ttl)
return
}
}
record.ARecords = append(record.ARecords, NewARecord(addr, time.Now().Add(ttl)))
}
func (r *Records) ListKnownDomains() []string {
r.mutex.Lock()
defer r.mutex.Unlock()
domains := make([]string, 0)
for name, record := range r.Records {
if record.Cleanup() {
delete(r.Records, name)
continue
}
domains = append(domains, name)
}
return domains
return cNameList
}
func (r *Records) Cleanup() {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
r.cleanupARecords(now)
r.cleanupCNameRecords(now)
}
for domainName, record := range r.Records {
if record.Cleanup() {
delete(r.Records, domainName)
func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string {
r.mutex.RLock()
defer r.mutex.RUnlock()
now := time.Now()
return r.getActualCNames(now, domainName, fromEnd)
}
func (r *Records) GetARecords(domainName string) []*ARecord {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
return r.getActualARecords(now, r.getAliasedDomain(now, domainName))
}
func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) {
if domainName == cName {
// Can't assing to yourself
return
}
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
delete(r.ARecords, domainName)
r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl))
}
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
delete(r.CNameRecords, domainName)
if _, ok := r.ARecords[domainName]; !ok {
r.ARecords[domainName] = make([]*ARecord, 0)
}
for _, aRecord := range r.ARecords[domainName] {
if bytes.Compare(aRecord.Address, addr) == 0 {
aRecord.Deadline = now.Add(ttl)
return
}
}
r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl)))
}
func (r *Records) ListKnownDomains() []string {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
r.cleanupARecords(now)
r.cleanupCNameRecords(now)
domains := map[string]struct{}{}
for name, _ := range r.ARecords {
domains[name] = struct{}{}
}
for name, _ := range r.CNameRecords {
domains[name] = struct{}{}
}
domainsList := make([]string, len(domains))
i := 0
for name, _ := range domains {
domainsList[i] = name
i++
}
return domainsList
}
func NewRecords() *Records {
return &Records{
Records: make(map[string]*Record),
ARecords: make(map[string][]*ARecord),
CNameRecords: make(map[string]*CNameRecord),
}
}