Compare commits

...

5 Commits

Author SHA1 Message Date
ff6ab7b859 refactor PortRemap 2025-02-11 15:53:58 +03:00
b537007c9a smol refactor 2025-02-11 15:53:09 +03:00
0f4820f499 fix name 2025-02-11 15:37:38 +03:00
5fd28ae005 refactor IPSet 2025-02-11 15:29:26 +03:00
066eeb0ab7 refactor CleanIPTables 2025-02-11 15:22:08 +03:00
6 changed files with 79 additions and 78 deletions

View File

@ -21,17 +21,17 @@ type Group struct {
ipsetToLink *netfilterHelper.IPSetToLink ipsetToLink *netfilterHelper.IPSetToLink
} }
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error { func (g *Group) AddIP(address net.IP, ttl time.Duration) error {
ttlSeconds := uint32(ttl.Seconds()) ttlSeconds := uint32(ttl.Seconds())
return g.ipset.AddIP(address, &ttlSeconds) return g.ipset.AddIP(address, &ttlSeconds)
} }
func (g *Group) DelIPv4(address net.IP) error { func (g *Group) DelIP(address net.IP) error {
return g.ipset.Del(address) return g.ipset.DelIP(address)
} }
func (g *Group) ListIPv4() (map[string]*uint32, error) { func (g *Group) ListIP() (map[string]*uint32, error) {
return g.ipset.List() return g.ipset.ListIPs()
} }
func (g *Group) Enable() error { func (g *Group) Enable() error {

View File

@ -204,17 +204,13 @@ func (a *App) start(ctx context.Context) (err error) {
args := strings.Split(string(buf[:n]), ":") args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" { if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event") log.Debug().Str("table", args[2]).Msg("netfilter.d event")
if a.dnsOverrider4.Enabled { err = a.dnsOverrider4.NetfilerDHook(args[2])
err := a.dnsOverrider4.PutIPTable(args[2]) if err != nil {
if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
} }
if a.dnsOverrider6.Enabled { err = a.dnsOverrider6.NetfilerDHook(args[2])
err = a.dnsOverrider6.PutIPTable(args[2]) if err != nil {
if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
} }
for _, group := range a.Groups { for _, group := range a.Groups {
err := group.ipsetToLink.NetfilerDHook(args[2]) err := group.ipsetToLink.NetfilerDHook(args[2])
@ -295,7 +291,7 @@ func (a *App) AddGroup(group *models.Group) error {
Group: group, Group: group,
iptables: a.NetfilterHelper4.IPTables, iptables: a.NetfilterHelper4.IPTables,
ipset: ipset, ipset: ipset,
ipsetToLink: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false), ipsetToLink: a.NetfilterHelper4.IPSetToLink(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
} }
a.Groups[grp.ID] = grp a.Groups[grp.ID] = grp
return a.SyncGroup(grp) return a.SyncGroup(grp)
@ -326,7 +322,7 @@ func (a *App) SyncGroup(group *Group) error {
} }
} }
currentAddresses, err := group.ListIPv4() currentAddresses, err := group.ListIP()
if err != nil { if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err) return fmt.Errorf("failed to get old ipset list: %w", err)
} }
@ -337,7 +333,7 @@ func (a *App) SyncGroup(group *Group) error {
continue continue
} }
ip := net.IP(addr) ip := net.IP(addr)
err = group.AddIPv4(ip, ttl) err = group.AddIP(ip, ttl)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", ip.String()). Str("address", ip.String()).
@ -356,7 +352,7 @@ func (a *App) SyncGroup(group *Group) error {
continue continue
} }
ip := net.IP(addr) ip := net.IP(addr)
err = group.DelIPv4(ip) err = group.DelIP(ip)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", ip.String()). Str("address", ip.String()).
@ -418,7 +414,7 @@ func (a *App) processARecord(aRecord dns.A) {
continue continue
} }
// TODO: Check already existed // TODO: Check already existed
err := group.AddIPv4(aRecord.A, ttlDuration) err := group.AddIP(aRecord.A, ttlDuration)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", aRecord.A.String()). Str("address", aRecord.A.String()).
@ -467,7 +463,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
continue continue
} }
for _, aRecord := range aRecords { for _, aRecord := range aRecords {
err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline)) err := group.AddIP(aRecord.Address, now.Sub(aRecord.Deadline))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", aRecord.Address.String()). Str("address", aRecord.Address.String()).
@ -561,7 +557,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err) return nil, fmt.Errorf("netfilter helper init fail: %w", err)
} }
app.NetfilterHelper4 = nh4 app.NetfilterHelper4 = nh4
err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix) err = app.NetfilterHelper4.CleanIPTables(app.Config.ChainPrefix)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err) return nil, fmt.Errorf("failed to clear iptables: %w", err)
} }
@ -571,7 +567,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("netfilter helper init fail: %w", err) return nil, fmt.Errorf("netfilter helper init fail: %w", err)
} }
app.NetfilterHelper6 = nh6 app.NetfilterHelper6 = nh6
err = app.NetfilterHelper6.ClearIPTables(app.Config.ChainPrefix) err = app.NetfilterHelper6.CleanIPTables(app.Config.ChainPrefix)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to clear iptables: %w", err) return nil, fmt.Errorf("failed to clear iptables: %w", err)
} }

View File

@ -76,9 +76,8 @@ func (r *IPSetToLink) insertIPTablesRules(table string) error {
func (r *IPSetToLink) deleteIPTablesRules() []error { func (r *IPSetToLink) deleteIPTablesRules() []error {
var errs []error var errs []error
var err error
err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) err := r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName)
if err != nil { if err != nil {
errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err))
} }
@ -280,7 +279,7 @@ func (r *IPSetToLink) LinkUpdateHook() error {
return r.insertIPRoute() return r.insertIPRoute()
} }
func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink { func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink {
return &IPSetToLink{ return &IPSetToLink{
IPTables: nh.IPTables, IPTables: nh.IPTables,
ChainName: name, ChainName: name,

View File

@ -2,9 +2,10 @@ package netfilterHelper
import ( import (
"fmt" "fmt"
"github.com/vishvananda/netlink"
"net" "net"
"os" "os"
"github.com/vishvananda/netlink"
) )
type IPSet struct { type IPSet struct {
@ -23,7 +24,7 @@ func (r *IPSet) AddIP(addr net.IP, timeout *uint32) error {
return nil return nil
} }
func (r *IPSet) Del(addr net.IP) error { func (r *IPSet) DelIP(addr net.IP) error {
err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{ err := netlink.IpsetDel(r.SetName, &netlink.IPSetEntry{
IP: addr, IP: addr,
}) })
@ -33,7 +34,7 @@ func (r *IPSet) Del(addr net.IP) error {
return nil return nil
} }
func (r *IPSet) List() (map[string]*uint32, error) { func (r *IPSet) ListIPs() (map[string]*uint32, error) {
list, err := netlink.IpsetList(r.SetName) list, err := netlink.IpsetList(r.SetName)
if err != nil { if err != nil {
return nil, err return nil, err
@ -62,9 +63,8 @@ func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) {
return nil, err return nil, err
} }
defaultTimeout := uint32(300)
err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{ err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{
Timeout: &defaultTimeout, Timeout: func(i uint32) *uint32 { return &i }(300),
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create ipset: %w", err) return nil, fmt.Errorf("failed to create ipset: %w", err)

View File

@ -5,11 +5,9 @@ import (
"strings" "strings"
) )
func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error { func (nh *NetfilterHelper) CleanIPTables(chainPrefix string) error {
jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix) jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix)
tableList := []string{"nat", "mangle", "filter"} for _, table := range []string{"nat", "mangle", "filter"} {
for _, table := range tableList {
chainListToDelete := make([]string, 0) chainListToDelete := make([]string, 0)
chains, err := nh.IPTables.ListChains(table) chains, err := nh.IPTables.ListChains(table)
@ -29,15 +27,8 @@ func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
} }
for _, rule := range rules { for _, rule := range rules {
ruleSlice := strings.Split(rule, " ") if strings.Contains(rule, jumpToChainPrefix) {
if len(ruleSlice) < 2 || ruleSlice[0] != "-A" || ruleSlice[1] != chain { err = nh.IPTables.Delete(table, chain, rule)
// TODO: Warn
continue
}
ruleSlice = ruleSlice[2:]
if strings.Contains(strings.Join(ruleSlice, " "), jumpToChainPrefix) {
err := nh.IPTables.Delete(table, chain, ruleSlice...)
if err != nil { if err != nil {
return fmt.Errorf("rule deletion error: %w", err) return fmt.Errorf("rule deletion error: %w", err)
} }
@ -46,7 +37,7 @@ func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
} }
for _, chain := range chainListToDelete { for _, chain := range chainListToDelete {
err := nh.IPTables.ClearAndDeleteChain(table, chain) err = nh.IPTables.ClearAndDeleteChain(table, chain)
if err != nil { if err != nil {
return fmt.Errorf("deleting chain error: %w", err) return fmt.Errorf("deleting chain error: %w", err)
} }

View File

@ -16,34 +16,32 @@ type PortRemap struct {
From uint16 From uint16
To uint16 To uint16
Enabled bool enabled bool
} }
func (r *PortRemap) PutIPTable(table string) error { func (r *PortRemap) insertIPTablesRules(table string) error {
if table == "all" || table == "nat" { if table == "" || table == "nat" {
err := r.IPTables.ClearChain("nat", r.ChainName) err := r.IPTables.NewChain("nat", r.ChainName)
if err != nil { if err != nil {
return fmt.Errorf("failed to clear chain: %w", err) // If not "AlreadyExists"
if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) {
return fmt.Errorf("failed to create chain: %w", err)
}
} }
for _, addr := range r.Addresses { for _, addr := range r.Addresses {
var addrIP net.IP if !((r.IPTables.Proto() == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (r.IPTables.Proto() == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len)) {
iptablesProtocol := r.IPTables.Proto()
if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) {
addrIP = addr.IP
}
if addrIP == nil {
continue continue
} }
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) for _, iptablesArgs := range [][]string{
if err != nil { {"-p", "tcp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
return fmt.Errorf("failed to create rule: %w", err) {"-p", "udp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
} } {
err = r.IPTables.AppendUnique("nat", r.ChainName, iptablesArgs...)
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) if err != nil {
if err != nil { return fmt.Errorf("failed to append rule: %w", err)
return fmt.Errorf("failed to create rule: %w", err) }
} }
} }
@ -56,17 +54,7 @@ func (r *PortRemap) PutIPTable(table string) error {
return nil return nil
} }
func (r *PortRemap) ForceEnable() error { func (r *PortRemap) deleteIPTablesRules() []error {
err := r.PutIPTable("all")
if err != nil {
return err
}
r.Enabled = true
return nil
}
func (r *PortRemap) Disable() []error {
var errs []error var errs []error
err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName) err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName)
@ -79,16 +67,30 @@ func (r *PortRemap) Disable() []error {
errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) errs = append(errs, fmt.Errorf("failed to delete chain: %w", err))
} }
r.Enabled = false
return errs return errs
} }
func (r *PortRemap) enable() error {
err := r.insertIPTablesRules("")
if err != nil {
return err
}
r.enabled = true
return nil
}
func (r *PortRemap) Enable() error { func (r *PortRemap) Enable() error {
if r.Enabled { if r.enabled {
return nil return nil
} }
err := r.ForceEnable() err := r.IPTables.ClearChain("nat", r.ChainName)
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
}
err = r.enable()
if err != nil { if err != nil {
r.Disable() r.Disable()
return err return err
@ -97,6 +99,19 @@ func (r *PortRemap) Enable() error {
return nil return nil
} }
func (r *PortRemap) Disable() []error {
errs := r.deleteIPTablesRules()
r.enabled = false
return errs
}
func (r *PortRemap) NetfilerDHook(table string) error {
if !r.enabled {
return nil
}
return r.insertIPTablesRules(table)
}
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap { func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap {
return &PortRemap{ return &PortRemap{
IPTables: nh.IPTables, IPTables: nh.IPTables,