From f5c77f719c7bbb474917eee271160d78a70ef348 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Sat, 14 Sep 2024 18:20:44 +0300 Subject: [PATCH] refactoring ipset --- group.go | 35 +++++-------- kvas2.go | 101 ++++++++++++++++++++++++++++---------- netfilter-helper/ipset.go | 45 ++++++++++++----- 3 files changed, 117 insertions(+), 64 deletions(-) diff --git a/group.go b/group.go index 232723c..324c0c2 100644 --- a/group.go +++ b/group.go @@ -20,20 +20,17 @@ type Group struct { ifaceToIPSet *netfilterHelper.IfaceToIPSet } -func (g *Group) HandleIPv4(relatedDomains []string, address net.IP, ttl time.Duration) error { - for _, domain := range g.Domains { - if !domain.IsEnabled() { - continue - } - for _, name := range relatedDomains { - if domain.IsMatch(name) { - ttlSeconds := uint32(ttl.Seconds()) - return g.ipset.Add(address, &ttlSeconds) - } - } - } +func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error { + ttlSeconds := uint32(ttl.Seconds()) + return g.ipset.Add(address, &ttlSeconds) +} - return nil +func (g *Group) DelIPv4(address net.IP) error { + return g.ipset.Del(address) +} + +func (g *Group) ListIPv4() (map[string]*uint32, error) { + return g.ipset.List() } func (g *Group) Enable() error { @@ -50,12 +47,7 @@ func (g *Group) Enable() error { g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT") } - err := g.ipset.Create() - if err != nil { - return err - } - - err = g.ifaceToIPSet.Enable() + err := g.ifaceToIPSet.Enable() if err != nil { return err } @@ -77,11 +69,6 @@ func (g *Group) Disable() []error { errs = append(errs, errs2...) } - err := g.ipset.Destroy() - if err != nil { - errs = append(errs, err) - } - g.Enabled = false return errs diff --git a/kvas2.go b/kvas2.go index aa45414..e5774ac 100644 --- a/kvas2.go +++ b/kvas2.go @@ -218,49 +218,85 @@ func (a *App) AddGroup(group *models.Group) error { } ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID) + ipset, err := a.NetfilterHelper4.IPSet(ipsetName) + if err != nil { + return fmt.Errorf("failed to initialize ipset: %w", err) + } grp := &Group{ Group: group, iptables: a.NetfilterHelper4.IPTables, - ipset: a.NetfilterHelper4.IPSet(ipsetName), + ipset: ipset, ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false), } a.Groups[group.ID] = grp + return a.SyncGroup(grp) +} - domains := a.Records.ListKnownDomains() +func (a *App) SyncGroup(group *Group) error { processedDomains := make(map[string]struct{}) - for _, domainName := range domains { - if _, exists := processedDomains[domainName]; exists { + newIpsetAddressesMap := make(map[string]time.Duration) + now := time.Now() + + oldIpsetAddresses, err := group.ListIPv4() + if err != nil { + return fmt.Errorf("failed to get old ipset list: %w", err) + } + + knownDomains := a.Records.ListKnownDomains() + for _, domain := range group.Domains { + if !domain.IsEnabled() { continue } - for _, domain := range group.Domains { + for _, domainName := range knownDomains { if !domain.IsMatch(domainName) { continue } cnames := a.Records.GetCNameRecords(domainName, true) + if len(cnames) == 0 { + continue + } for _, cname := range cnames { processedDomains[cname] = struct{}{} } - if len(cnames) == 0 { - break - } - addresses := a.Records.GetARecords(domainName) for _, address := range addresses { - err := grp.HandleIPv4(cnames, address.Address, time.Now().Sub(address.Deadline)) - if err != nil { - log.Error(). - Str("name", domainName). - Str("address", address.Address.String()). - Int("group", group.ID). - Err(err). - Msg("failed to handle address") + ttl := now.Sub(address.Deadline) + if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL { + newIpsetAddressesMap[string(address.Address)] = ttl } } - break + } + } + + for addr, ttl := range newIpsetAddressesMap { + if _, exists := oldIpsetAddresses[addr]; exists { + continue + } + ip := net.IP(addr) + err = group.AddIPv4(ip, ttl) + if err != nil { + log.Error(). + Str("address", ip.String()). + Err(err). + Msg("failed to add address") + } + } + + for addr, _ := range oldIpsetAddresses { + if _, exists := newIpsetAddressesMap[addr]; exists { + continue + } + ip := net.IP(addr) + err = group.DelIPv4(ip) + if err != nil { + log.Error(). + Str("address", ip.String()). + Err(err). + Msg("failed to delete address") } } @@ -300,22 +336,32 @@ func (a *App) processARecord(aRecord dnsProxy.Address) { a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration) - // TODO: Optimize names := a.Records.GetCNameRecords(aRecord.Name.String(), true) for _, group := range a.Groups { - err := group.HandleIPv4(names, aRecord.Address, ttlDuration) - if err != nil { - log.Error(). - Str("name", aRecord.Name.String()). - Str("address", aRecord.Address.String()). - Int("group", group.ID). - Err(err). - Msg("failed to handle address") + for _, domain := range group.Domains { + for _, name := range names { + if !domain.IsMatch(name) { + continue + } + err := group.AddIPv4(aRecord.Address, ttlDuration) + if err != nil { + log.Error(). + Str("address", aRecord.Address.String()). + Err(err). + Msg("failed to add address") + } + } } } } func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { + log.Trace(). + Str("name", cNameRecord.Name.String()). + Str("cname", cNameRecord.CName.String()). + Int("ttl", int(cNameRecord.TTL)). + Msg("processing cname record") + ttlDuration := time.Duration(cNameRecord.TTL) * time.Second if ttlDuration < a.Config.MinimalTTL { ttlDuration = a.Config.MinimalTTL @@ -327,6 +373,7 @@ func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { func (a *App) handleRecord(rr dnsProxy.ResourceRecord) { switch v := rr.(type) { case dnsProxy.Address: + // TODO: Optimize equals domain A records a.processARecord(v) case dnsProxy.CName: a.processCNameRecord(v) diff --git a/netfilter-helper/ipset.go b/netfilter-helper/ipset.go index 81979c5..01e798d 100644 --- a/netfilter-helper/ipset.go +++ b/netfilter-helper/ipset.go @@ -23,23 +23,28 @@ func (r *IPSet) Add(addr net.IP, timeout *uint32) error { return nil } -func (r *IPSet) Create() error { - err := r.Destroy() - if err != nil { - return err - } - - defaultTimeout := uint32(300) - err = netlink.IpsetCreate(r.SetName, "hash:ip", netlink.IpsetCreateOptions{ - Timeout: &defaultTimeout, +func (r *IPSet) Del(addr net.IP) error { + err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{ + IP: addr, }) if err != nil { - return fmt.Errorf("failed to create ipset: %w", err) + return fmt.Errorf("failed to delete address: %w", err) } - return nil } +func (r *IPSet) List() (map[string]*uint32, error) { + list, err := netlink.IpsetList(r.SetName) + if err != nil { + return nil, err + } + addresses := make(map[string]*uint32) + for _, entry := range list.Entries { + addresses[string(entry.IP)] = entry.Timeout + } + return addresses, nil +} + func (r *IPSet) Destroy() error { err := netlink.IpsetDestroy(r.SetName) if err != nil && !os.IsNotExist(err) { @@ -48,8 +53,22 @@ func (r *IPSet) Destroy() error { return nil } -func (nh *NetfilterHelper) IPSet(name string) *IPSet { - return &IPSet{ +func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) { + ipset := &IPSet{ SetName: name, } + err := ipset.Destroy() + if err != nil { + return nil, err + } + + defaultTimeout := uint32(300) + err = netlink.IpsetCreate(ipset.SetName, "hash:ip", netlink.IpsetCreateOptions{ + Timeout: &defaultTimeout, + }) + if err != nil { + return nil, fmt.Errorf("failed to create ipset: %w", err) + } + + return ipset, nil }