refactor PortRemap

This commit is contained in:
Vladimir Avtsenov 2025-02-11 15:53:58 +03:00
parent b537007c9a
commit ff6ab7b859
2 changed files with 54 additions and 43 deletions

View File

@ -204,17 +204,13 @@ func (a *App) start(ctx context.Context) (err error) {
args := strings.Split(string(buf[:n]), ":") args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" { if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event") log.Debug().Str("table", args[2]).Msg("netfilter.d event")
if a.dnsOverrider4.Enabled { err = a.dnsOverrider4.NetfilerDHook(args[2])
err := a.dnsOverrider4.PutIPTable(args[2]) if err != nil {
if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
} }
if a.dnsOverrider6.Enabled { err = a.dnsOverrider6.NetfilerDHook(args[2])
err = a.dnsOverrider6.PutIPTable(args[2]) if err != nil {
if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
} }
for _, group := range a.Groups { for _, group := range a.Groups {
err := group.ipsetToLink.NetfilerDHook(args[2]) err := group.ipsetToLink.NetfilerDHook(args[2])

View File

@ -16,34 +16,32 @@ type PortRemap struct {
From uint16 From uint16
To uint16 To uint16
Enabled bool enabled bool
} }
func (r *PortRemap) PutIPTable(table string) error { func (r *PortRemap) insertIPTablesRules(table string) error {
if table == "all" || table == "nat" { if table == "" || table == "nat" {
err := r.IPTables.ClearChain("nat", r.ChainName) err := r.IPTables.NewChain("nat", r.ChainName)
if err != nil { if err != nil {
return fmt.Errorf("failed to clear chain: %w", err) // If not "AlreadyExists"
if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) {
return fmt.Errorf("failed to create chain: %w", err)
}
} }
for _, addr := range r.Addresses { for _, addr := range r.Addresses {
var addrIP net.IP if !((r.IPTables.Proto() == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (r.IPTables.Proto() == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len)) {
iptablesProtocol := r.IPTables.Proto()
if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) {
addrIP = addr.IP
}
if addrIP == nil {
continue continue
} }
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) for _, iptablesArgs := range [][]string{
if err != nil { {"-p", "tcp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
return fmt.Errorf("failed to create rule: %w", err) {"-p", "udp", "-d", addr.IP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)},
} } {
err = r.IPTables.AppendUnique("nat", r.ChainName, iptablesArgs...)
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) if err != nil {
if err != nil { return fmt.Errorf("failed to append rule: %w", err)
return fmt.Errorf("failed to create rule: %w", err) }
} }
} }
@ -56,17 +54,7 @@ func (r *PortRemap) PutIPTable(table string) error {
return nil return nil
} }
func (r *PortRemap) ForceEnable() error { func (r *PortRemap) deleteIPTablesRules() []error {
err := r.PutIPTable("all")
if err != nil {
return err
}
r.Enabled = true
return nil
}
func (r *PortRemap) Disable() []error {
var errs []error var errs []error
err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName) err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName)
@ -79,16 +67,30 @@ func (r *PortRemap) Disable() []error {
errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) errs = append(errs, fmt.Errorf("failed to delete chain: %w", err))
} }
r.Enabled = false
return errs return errs
} }
func (r *PortRemap) enable() error {
err := r.insertIPTablesRules("")
if err != nil {
return err
}
r.enabled = true
return nil
}
func (r *PortRemap) Enable() error { func (r *PortRemap) Enable() error {
if r.Enabled { if r.enabled {
return nil return nil
} }
err := r.ForceEnable() err := r.IPTables.ClearChain("nat", r.ChainName)
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
}
err = r.enable()
if err != nil { if err != nil {
r.Disable() r.Disable()
return err return err
@ -97,6 +99,19 @@ func (r *PortRemap) Enable() error {
return nil return nil
} }
func (r *PortRemap) Disable() []error {
errs := r.deleteIPTablesRules()
r.enabled = false
return errs
}
func (r *PortRemap) NetfilerDHook(table string) error {
if !r.enabled {
return nil
}
return r.insertIPTablesRules(table)
}
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap { func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap {
return &PortRemap{ return &PortRemap{
IPTables: nh.IPTables, IPTables: nh.IPTables,