From ff6ab7b85976f5d3b9c37fa7926d44d7a76039b1 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Tue, 11 Feb 2025 15:53:58 +0300 Subject: [PATCH] refactor PortRemap --- kvas2.go | 16 +++---- netfilter-helper/port-remap.go | 81 ++++++++++++++++++++-------------- 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/kvas2.go b/kvas2.go index feda0c6..2cc0853 100644 --- a/kvas2.go +++ b/kvas2.go @@ -204,17 +204,13 @@ func (a *App) start(ctx context.Context) (err error) { args := strings.Split(string(buf[:n]), ":") if len(args) == 3 && args[0] == "netfilter.d" { log.Debug().Str("table", args[2]).Msg("netfilter.d event") - if a.dnsOverrider4.Enabled { - err := a.dnsOverrider4.PutIPTable(args[2]) - if err != nil { - log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") - } + err = a.dnsOverrider4.NetfilerDHook(args[2]) + if err != nil { + log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } - if a.dnsOverrider6.Enabled { - err = a.dnsOverrider6.PutIPTable(args[2]) - if err != nil { - log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") - } + err = a.dnsOverrider6.NetfilerDHook(args[2]) + if err != nil { + log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } for _, group := range a.Groups { err := group.ipsetToLink.NetfilerDHook(args[2]) diff --git a/netfilter-helper/port-remap.go b/netfilter-helper/port-remap.go index b0a2787..5c20e18 100644 --- a/netfilter-helper/port-remap.go +++ b/netfilter-helper/port-remap.go @@ -16,34 +16,32 @@ type PortRemap struct { From uint16 To uint16 - Enabled bool + enabled bool } -func (r *PortRemap) PutIPTable(table string) error { - if table == "all" || table == "nat" { - err := r.IPTables.ClearChain("nat", r.ChainName) +func (r *PortRemap) insertIPTablesRules(table string) error { + if table == "" || table == "nat" { + err := r.IPTables.NewChain("nat", r.ChainName) if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) + // If not "AlreadyExists" + if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) { + return fmt.Errorf("failed to create chain: %w", err) + } } for _, addr := range r.Addresses { - var addrIP net.IP - iptablesProtocol := r.IPTables.Proto() - if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) { - addrIP = addr.IP - } - if addrIP == nil { + if !((r.IPTables.Proto() == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (r.IPTables.Proto() == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len)) { continue } - err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) + for _, iptablesArgs := range [][]string{ + {"-p", "tcp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)}, + {"-p", "udp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)}, + } { + err = r.IPTables.AppendUnique("nat", r.ChainName, iptablesArgs...) + if err != nil { + return fmt.Errorf("failed to append rule: %w", err) + } } } @@ -56,17 +54,7 @@ func (r *PortRemap) PutIPTable(table string) error { return nil } -func (r *PortRemap) ForceEnable() error { - err := r.PutIPTable("all") - if err != nil { - return err - } - - r.Enabled = true - return nil -} - -func (r *PortRemap) Disable() []error { +func (r *PortRemap) deleteIPTablesRules() []error { var errs []error err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName) @@ -79,16 +67,30 @@ func (r *PortRemap) Disable() []error { errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) } - r.Enabled = false return errs } +func (r *PortRemap) enable() error { + err := r.insertIPTablesRules("") + if err != nil { + return err + } + + r.enabled = true + return nil +} + func (r *PortRemap) Enable() error { - if r.Enabled { + if r.enabled { return nil } - err := r.ForceEnable() + err := r.IPTables.ClearChain("nat", r.ChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + err = r.enable() if err != nil { r.Disable() return err @@ -97,6 +99,19 @@ func (r *PortRemap) Enable() error { return nil } +func (r *PortRemap) Disable() []error { + errs := r.deleteIPTablesRules() + r.enabled = false + return errs +} + +func (r *PortRemap) NetfilerDHook(table string) error { + if !r.enabled { + return nil + } + return r.insertIPTablesRules(table) +} + func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap { return &PortRemap{ IPTables: nh.IPTables,