refactoring ipset

This commit is contained in:
Vladimir Avtsenov 2024-09-14 18:20:44 +03:00
parent 3a178def29
commit f5c77f719c
3 changed files with 117 additions and 64 deletions

View File

@ -20,20 +20,17 @@ type Group struct {
ifaceToIPSet *netfilterHelper.IfaceToIPSet
}
func (g *Group) HandleIPv4(relatedDomains []string, address net.IP, ttl time.Duration) error {
for _, domain := range g.Domains {
if !domain.IsEnabled() {
continue
}
for _, name := range relatedDomains {
if domain.IsMatch(name) {
ttlSeconds := uint32(ttl.Seconds())
return g.ipset.Add(address, &ttlSeconds)
}
}
}
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
ttlSeconds := uint32(ttl.Seconds())
return g.ipset.Add(address, &ttlSeconds)
}
return nil
func (g *Group) DelIPv4(address net.IP) error {
return g.ipset.Del(address)
}
func (g *Group) ListIPv4() (map[string]*uint32, error) {
return g.ipset.List()
}
func (g *Group) Enable() error {
@ -50,12 +47,7 @@ func (g *Group) Enable() error {
g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
}
err := g.ipset.Create()
if err != nil {
return err
}
err = g.ifaceToIPSet.Enable()
err := g.ifaceToIPSet.Enable()
if err != nil {
return err
}
@ -77,11 +69,6 @@ func (g *Group) Disable() []error {
errs = append(errs, errs2...)
}
err := g.ipset.Destroy()
if err != nil {
errs = append(errs, err)
}
g.Enabled = false
return errs

101
kvas2.go
View File

@ -218,49 +218,85 @@ func (a *App) AddGroup(group *models.Group) error {
}
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID)
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
if err != nil {
return fmt.Errorf("failed to initialize ipset: %w", err)
}
grp := &Group{
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: a.NetfilterHelper4.IPSet(ipsetName),
ipset: ipset,
ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false),
}
a.Groups[group.ID] = grp
return a.SyncGroup(grp)
}
domains := a.Records.ListKnownDomains()
func (a *App) SyncGroup(group *Group) error {
processedDomains := make(map[string]struct{})
for _, domainName := range domains {
if _, exists := processedDomains[domainName]; exists {
newIpsetAddressesMap := make(map[string]time.Duration)
now := time.Now()
oldIpsetAddresses, err := group.ListIPv4()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
knownDomains := a.Records.ListKnownDomains()
for _, domain := range group.Domains {
if !domain.IsEnabled() {
continue
}
for _, domain := range group.Domains {
for _, domainName := range knownDomains {
if !domain.IsMatch(domainName) {
continue
}
cnames := a.Records.GetCNameRecords(domainName, true)
if len(cnames) == 0 {
continue
}
for _, cname := range cnames {
processedDomains[cname] = struct{}{}
}
if len(cnames) == 0 {
break
}
addresses := a.Records.GetARecords(domainName)
for _, address := range addresses {
err := grp.HandleIPv4(cnames, address.Address, time.Now().Sub(address.Deadline))
if err != nil {
log.Error().
Str("name", domainName).
Str("address", address.Address.String()).
Int("group", group.ID).
Err(err).
Msg("failed to handle address")
ttl := now.Sub(address.Deadline)
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
newIpsetAddressesMap[string(address.Address)] = ttl
}
}
break
}
}
for addr, ttl := range newIpsetAddressesMap {
if _, exists := oldIpsetAddresses[addr]; exists {
continue
}
ip := net.IP(addr)
err = group.AddIPv4(ip, ttl)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to add address")
}
}
for addr, _ := range oldIpsetAddresses {
if _, exists := newIpsetAddressesMap[addr]; exists {
continue
}
ip := net.IP(addr)
err = group.DelIPv4(ip)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to delete address")
}
}
@ -300,22 +336,32 @@ func (a *App) processARecord(aRecord dnsProxy.Address) {
a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
// TODO: Optimize
names := a.Records.GetCNameRecords(aRecord.Name.String(), true)
for _, group := range a.Groups {
err := group.HandleIPv4(names, aRecord.Address, ttlDuration)
if err != nil {
log.Error().
Str("name", aRecord.Name.String()).
Str("address", aRecord.Address.String()).
Int("group", group.ID).
Err(err).
Msg("failed to handle address")
for _, domain := range group.Domains {
for _, name := range names {
if !domain.IsMatch(name) {
continue
}
err := group.AddIPv4(aRecord.Address, ttlDuration)
if err != nil {
log.Error().
Str("address", aRecord.Address.String()).
Err(err).
Msg("failed to add address")
}
}
}
}
}
func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
log.Trace().
Str("name", cNameRecord.Name.String()).
Str("cname", cNameRecord.CName.String()).
Int("ttl", int(cNameRecord.TTL)).
Msg("processing cname record")
ttlDuration := time.Duration(cNameRecord.TTL) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
@ -327,6 +373,7 @@ func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
func (a *App) handleRecord(rr dnsProxy.ResourceRecord) {
switch v := rr.(type) {
case dnsProxy.Address:
// TODO: Optimize equals domain A records
a.processARecord(v)
case dnsProxy.CName:
a.processCNameRecord(v)

View File

@ -23,23 +23,28 @@ func (r *IPSet) Add(addr net.IP, timeout *uint32) error {
return nil
}
func (r *IPSet) Create() error {
err := r.Destroy()
if err != nil {
return err
}
defaultTimeout := uint32(300)
err = netlink.IpsetCreate(r.SetName, "hash:ip", netlink.IpsetCreateOptions{
Timeout: &defaultTimeout,
func (r *IPSet) Del(addr net.IP) error {
err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{
IP: addr,
})
if err != nil {
return fmt.Errorf("failed to create ipset: %w", err)
return fmt.Errorf("failed to delete address: %w", err)
}
return nil
}
func (r *IPSet) List() (map[string]*uint32, error) {
list, err := netlink.IpsetList(r.SetName)
if err != nil {
return nil, err
}
addresses := make(map[string]*uint32)
for _, entry := range list.Entries {
addresses[string(entry.IP)] = entry.Timeout
}
return addresses, nil
}
func (r *IPSet) Destroy() error {
err := netlink.IpsetDestroy(r.SetName)
if err != nil && !os.IsNotExist(err) {
@ -48,8 +53,22 @@ func (r *IPSet) Destroy() error {
return nil
}
func (nh *NetfilterHelper) IPSet(name string) *IPSet {
return &IPSet{
func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) {
ipset := &IPSet{
SetName: name,
}
err := ipset.Destroy()
if err != nil {
return nil, err
}
defaultTimeout := uint32(300)
err = netlink.IpsetCreate(ipset.SetName, "hash:ip", netlink.IpsetCreateOptions{
Timeout: &defaultTimeout,
})
if err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err)
}
return ipset, nil
}