From 33a232d07c11b676f87e3e63f3c9332ee48803a6 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Fri, 6 Sep 2024 14:24:55 +0300 Subject: [PATCH] refactoring --- group.go | 294 +++---------------- kvas2.go | 94 ++---- netfilter-helper/interface-to-ipset.go | 257 ++++++++++++++++ netfilter-helper/ipset.go | 55 ++++ netfilter-helper/netfiler-helper.go | 21 ++ netfilter-helper/port-remap.go | 76 +++++ recoverable-iptables/recoverable-iptables.go | 124 ++++++++ 7 files changed, 608 insertions(+), 313 deletions(-) create mode 100644 netfilter-helper/interface-to-ipset.go create mode 100644 netfilter-helper/ipset.go create mode 100644 netfilter-helper/netfiler-helper.go create mode 100644 netfilter-helper/port-remap.go create mode 100644 recoverable-iptables/recoverable-iptables.go diff --git a/group.go b/group.go index 2acf4d4..44f936d 100644 --- a/group.go +++ b/group.go @@ -2,53 +2,35 @@ package main import ( "fmt" + netfilterHelper "kvas2-go/netfilter-helper" "net" - "os" - "strconv" "time" "kvas2-go/models" - - "github.com/rs/zerolog/log" - "github.com/vishvananda/netlink" - "github.com/vishvananda/netlink/nl" ) -type GroupOptions struct { - Enabled bool - ipRule *netlink.Rule - ipRoute *netlink.Route -} - type Group struct { *models.Group - ipsetName string - options GroupOptions + + Enabled bool + + ipset *netfilterHelper.IPSet + ifaceToIPSet *netfilterHelper.IfaceToIPSet } func (g *Group) HandleIPv4(names []string, address net.IP, ttl time.Duration) error { - if !g.options.Enabled { + if !g.Enabled { return nil } - ttlSeconds := uint32(ttl.Seconds()) - -DomainSearch: for _, domain := range g.Domains { if !domain.IsEnabled() { continue } for _, name := range names { if domain.IsMatch(name) { - err := netlink.IpsetAdd(g.ipsetName, &netlink.IPSetEntry{ - IP: address, - Timeout: &ttlSeconds, - Replace: true, - }) - if err != nil { - return fmt.Errorf("failed to assign address: %w", err) - } - break DomainSearch + ttlSeconds := uint32(ttl.Seconds()) + return g.ipset.Add(address, &ttlSeconds) } } } @@ -56,247 +38,65 @@ DomainSearch: return nil } -func (g *Group) Enable(a *App) error { - if g.options.Enabled { +func (g *Group) Enable() error { + if g.Enabled { return nil } + defer func() { + if !g.Enabled { + _ = g.Disable() + } + }() - var err error - - markMap := make(map[uint32]struct{}) - tableMap := map[int]struct{}{ - 0: {}, - 253: {}, - 254: {}, - 255: {}, - } - var table int - var mark uint32 - - rules, err := netlink.RuleList(nl.FAMILY_ALL) + err := g.ipset.Create() if err != nil { - return fmt.Errorf("error while getting rules: %w", err) - } - for _, rule := range rules { - markMap[rule.Mark] = struct{}{} - tableMap[rule.Table] = struct{}{} + return err } - routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) + err = g.ifaceToIPSet.Enable() if err != nil { - return fmt.Errorf("error while getting routes: %w", err) - } - for _, route := range routes { - tableMap[route.Table] = struct{}{} + return err } - for { - if _, exists := tableMap[table]; exists { - table++ - continue - } - break - } - - for { - if _, exists := markMap[mark]; exists { - mark++ - continue - } - break - } - - rule := netlink.NewRule() - rule.Mark = mark - rule.Table = table - if err != nil { - return fmt.Errorf("error while getting free table: %w", err) - } - err = netlink.RuleAdd(rule) - if err != nil { - return fmt.Errorf("error while adding rule: %w", err) - } - g.options.ipRule = rule - - iface, err := netlink.LinkByName(g.Interface) - if err != nil { - log.Warn().Str("interface", g.Interface).Msg("error while getting interface") - } - - if iface != nil { - route := &netlink.Route{ - LinkIndex: iface.Attrs().Index, - Table: rule.Table, - Dst: &net.IPNet{ - IP: []byte{0, 0, 0, 0}, - Mask: []byte{0, 0, 0, 0}, - }, - } - err = netlink.RouteAdd(route) - if err != nil { - return fmt.Errorf("error while adding route: %w", err) - } - g.options.ipRoute = route - } - - defaultTimeout := uint32(300) - err = netlink.IpsetDestroy(g.ipsetName) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to destroy ipset: %w", err) - } - err = netlink.IpsetCreate(g.ipsetName, "hash:ip", netlink.IpsetCreateOptions{ - Timeout: &defaultTimeout, - }) - if err != nil { - return fmt.Errorf("failed to create ipset: %w", err) - } - - if !a.Config.UseSoftwareRouting { - chainName := fmt.Sprintf("%sROUTING_%d", a.Config.ChainPostfix, g.ID) - - err = a.IPTables.ClearChain("mangle", chainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - for _, iptablesArgs := range [][]string{ - // Source: https://github.com/qzeleza/kvas/blob/3fdbbd1ace7b57b11bf88d8db3882d94a1d6e01c/opt/etc/ndm/ndm#L194-L206 - {"-m", "set", "!", "--match-set", g.ipsetName, "dst", "-j", "RETURN"}, - {"-j", "CONNMARK", "--restore-mark"}, - {"-m", "mark", "--mark", strconv.Itoa(int(mark)), "-j", "RETURN"}, - // This command not working - // {"--syn", "-j", "MARK", "--set-mark", strconv.Itoa(int(mark))}, - {"-m", "conntrack", "--ctstate", "NEW", "-j", "MARK", "--set-mark", strconv.Itoa(int(mark))}, - {"-j", "CONNMARK", "--save-mark"}, - } { - err = a.IPTables.AppendUnique("mangle", chainName, iptablesArgs...) - if err != nil { - return fmt.Errorf("failed to append rule: %w", err) - } - } - - err = a.IPTables.AppendUnique("mangle", "PREROUTING", "-m", "set", "--match-set", g.ipsetName, "dst", "-j", chainName) - if err != nil { - return fmt.Errorf("failed to append rule to PREROUTING: %w", err) - } - - err = a.IPTables.AppendUnique("mangle", "OUTPUT", "-m", "set", "--match-set", g.ipsetName, "dst", "-j", chainName) - if err != nil { - return fmt.Errorf("failed to append rule to OUTPUT: %w", err) - } - } else { - preroutingChainName := fmt.Sprintf("%sROUTING_%d_PREROUTING", a.Config.ChainPostfix, g.ID) - - err = a.IPTables.ClearChain("mangle", preroutingChainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - err = a.IPTables.AppendUnique("mangle", preroutingChainName, "-m", "set", "--match-set", g.ipsetName, "dst", "-j", "MARK", "--set-mark", strconv.Itoa(int(mark))) - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = a.IPTables.AppendUnique("mangle", "PREROUTING", "-j", preroutingChainName) - if err != nil { - return fmt.Errorf("failed to linking chain: %w", err) - } - } - - postroutingChainName := fmt.Sprintf("%sROUTING_%d_POSTROUTING", a.Config.ChainPostfix, g.ID) - - err = a.IPTables.ClearChain("nat", postroutingChainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - err = a.IPTables.AppendUnique("nat", postroutingChainName, "-o", g.Interface, "-j", "MASQUERADE") - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = a.IPTables.AppendUnique("nat", "POSTROUTING", "-j", postroutingChainName) - if err != nil { - return fmt.Errorf("failed to linking chain: %w", err) - } - - g.options.Enabled = true + g.Enabled = true return nil } -func (g *Group) Disable(a *App) error { - if !g.options.Enabled { +func (g *Group) Disable() []error { + var errs []error + + if !g.Enabled { return nil } - var err error - - if !a.Config.UseSoftwareRouting { - chainName := fmt.Sprintf("%sROUTING_%d", a.Config.ChainPostfix, g.ID) - - err = a.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", g.ipsetName, "dst", "-j", chainName) - if err != nil { - return fmt.Errorf("failed to delete rule to PREROUTING: %w", err) - } - - err = a.IPTables.DeleteIfExists("mangle", "OUTPUT", "-m", "set", "--match-set", g.ipsetName, "dst", "-j", chainName) - if err != nil { - return fmt.Errorf("failed to delete rule to OUTPUT: %w", err) - } - - err = a.IPTables.ClearAndDeleteChain("mangle", chainName) - if err != nil { - return fmt.Errorf("failed to delete chain: %w", err) - } - } else { - preroutingChainName := fmt.Sprintf("%sROUTING_%d_PREROUTING", a.Config.ChainPostfix, g.ID) - - err = a.IPTables.DeleteIfExists("mangle", "PREROUTING", "-j", preroutingChainName) - if err != nil { - return fmt.Errorf("failed to unlinking chain: %w", err) - } - - err = a.IPTables.ClearAndDeleteChain("mangle", preroutingChainName) - if err != nil { - return fmt.Errorf("failed to delete chain: %w", err) - } - } - - postroutingChainName := fmt.Sprintf("%sROUTING_%d_POSTROUTING", a.Config.ChainPostfix, g.ID) - - err = a.IPTables.DeleteIfExists("nat", "POSTROUTING", "-j", postroutingChainName) + err := g.ipset.Destroy() if err != nil { - return fmt.Errorf("failed to unlinking chain: %w", err) + errs = append(errs, err) } - err = a.IPTables.ClearAndDeleteChain("nat", postroutingChainName) - if err != nil { - return fmt.Errorf("failed to delete chain: %w", err) + errs2 := g.ifaceToIPSet.Disable() + if errs2 != nil { + errs = append(errs, errs2...) } - if g.options.ipRule != nil { - err = netlink.RuleDel(g.options.ipRule) - if err != nil { - return fmt.Errorf("error while deleting rule: %w", err) - } - g.options.ipRule = nil - } - - if g.options.ipRoute != nil { - err = netlink.RouteDel(g.options.ipRoute) - if err != nil { - return fmt.Errorf("error while deleting route: %w", err) - } - g.options.ipRule = nil - } - - err = netlink.IpsetDestroy(g.ipsetName) - if err != nil && !os.IsNotExist(err) { - return fmt.Errorf("failed to destroy ipset: %w", err) - } - - g.options.Enabled = false + g.Enabled = false + + return nil +} + +func (a *App) AddGroup(group *models.Group) error { + if _, exists := a.Groups[group.ID]; exists { + return ErrGroupIDConflict + } + + ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID) + + a.Groups[group.ID] = &Group{ + Group: group, + ipset: a.NetfilterHelper.IPSet(ipsetName), + ifaceToIPSet: a.NetfilterHelper.IfaceToIPSet(fmt.Sprintf("%sROUTING_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false), + } return nil } diff --git a/kvas2.go b/kvas2.go index b900c94..80f3bcd 100644 --- a/kvas2.go +++ b/kvas2.go @@ -5,14 +5,13 @@ import ( "errors" "fmt" "net" - "strconv" "sync" "time" "kvas2-go/dns-proxy" "kvas2-go/models" + "kvas2-go/netfilter-helper" - "github.com/coreos/go-iptables/iptables" "github.com/rs/zerolog/log" ) @@ -33,12 +32,13 @@ type Config struct { type App struct { Config Config - DNSProxy *dnsProxy.DNSProxy - IPTables *iptables.IPTables - Records *Records - Groups map[int]*Group + DNSProxy *dnsProxy.DNSProxy + NetfilterHelper *netfilterHelper.NetfilterHelper + Records *Records + Groups map[int]*Group - isRunning bool + isRunning bool + dnsOverrider *netfilterHelper.PortRemap } func (a *App) Listen(ctx context.Context) []error { @@ -60,6 +60,13 @@ func (a *App) Listen(ctx context.Context) []error { errs = append(errs, err) once.Do(func() { close(isError) }) } + handleErrors := func(errs2 []error) { + errsMu.Lock() + defer errsMu.Unlock() + + errs = append(errs, errs2...) + once.Do(func() { close(isError) }) + } defer func() { if r := recover(); r != nil { @@ -74,28 +81,11 @@ func (a *App) Listen(ctx context.Context) []error { newCtx, cancel := context.WithCancel(ctx) defer cancel() - chainName := fmt.Sprintf("%sDNSOVERRIDER", a.Config.ChainPostfix) - - err := a.IPTables.ClearChain("nat", chainName) - if err != nil { - handleError(fmt.Errorf("failed to clear chain: %w", err)) - return errs - } - - err = a.IPTables.AppendUnique("nat", chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(a.Config.ListenPort))) - if err != nil { - handleError(fmt.Errorf("failed to create rule: %w", err)) - return errs - } - - err = a.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", chainName) - if err != nil { - handleError(fmt.Errorf("failed to linking chain: %w", err)) - return errs - } + a.dnsOverrider = a.NetfilterHelper.PortRemap(fmt.Sprintf("%sDNSOVERRIDER", a.Config.ChainPostfix), 53, a.Config.ListenPort) + err := a.dnsOverrider.Enable() for idx, _ := range a.Groups { - err = a.Groups[idx].Enable(a) + err = a.Groups[idx].Enable() if err != nil { handleError(fmt.Errorf("failed to enable group: %w", err)) return errs @@ -113,49 +103,21 @@ func (a *App) Listen(ctx context.Context) []error { case <-isError: } + errs2 := a.dnsOverrider.Disable() + if errs2 != nil { + handleErrors(errs2) + } + for idx, _ := range a.Groups { - err = a.Groups[idx].Disable(a) - if err != nil { - handleError(fmt.Errorf("failed to disable group: %w", err)) - return errs + errs2 = a.Groups[idx].Disable() + if errs2 != nil { + handleErrors(errs2) } } - err = a.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", chainName) - if err != nil { - handleError(fmt.Errorf("failed to unlinking chain: %w", err)) - return errs - } - - err = a.IPTables.ClearAndDeleteChain("nat", chainName) - if err != nil { - handleError(fmt.Errorf("failed to delete chain: %w", err)) - return errs - } - return errs } -func (a *App) AppendGroup(group *models.Group) error { - if _, exists := a.Groups[group.ID]; exists { - return ErrGroupIDConflict - } - - a.Groups[group.ID] = &Group{ - Group: group, - ipsetName: fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID), - } - - if a.isRunning { - err := a.Groups[group.ID].Enable(a) - if err != nil { - return fmt.Errorf("failed to enable appended group: %w", err) - } - } - - return nil -} - func (a *App) ListInterfaces() ([]net.Interface, error) { interfaceNames := make([]net.Interface, 0) @@ -246,11 +208,11 @@ func New(config Config) (*App, error) { app.Records = NewRecords() - ipt, err := iptables.New() + nh, err := netfilterHelper.New() if err != nil { - return nil, fmt.Errorf("iptables init fail: %w", err) + return nil, fmt.Errorf("netfilter helper init fail: %w", err) } - app.IPTables = ipt + app.NetfilterHelper = nh app.Groups = make(map[int]*Group) diff --git a/netfilter-helper/interface-to-ipset.go b/netfilter-helper/interface-to-ipset.go new file mode 100644 index 0000000..2c3c6d1 --- /dev/null +++ b/netfilter-helper/interface-to-ipset.go @@ -0,0 +1,257 @@ +package netfilterHelper + +import ( + "fmt" + "github.com/coreos/go-iptables/iptables" + "github.com/rs/zerolog/log" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" + "net" + "strconv" +) + +type IfaceToIPSet struct { + IPTables *iptables.IPTables + ChainName string + IfaceName string + IPSetName string + SoftwareMode bool + + Enabled bool + + mark uint32 + table int + ipRule *netlink.Rule + ipRoute *netlink.Route +} + +func (r *IfaceToIPSet) ForceEnable() error { + // Release used mark and table + r.Disable() + r.mark = 0 + r.table = 0 + + // Find interface + iface, err := netlink.LinkByName(r.IfaceName) + if err != nil { + log.Warn().Str("interface", r.IfaceName).Msg("error while getting interface") + } + + // Find unused mark and table + markMap := make(map[uint32]struct{}) + tableMap := map[int]struct{}{0: {}, 253: {}, 254: {}, 255: {}} + + rules, err := netlink.RuleList(nl.FAMILY_ALL) + if err != nil { + return fmt.Errorf("error while getting rules: %w", err) + } + for _, rule := range rules { + markMap[rule.Mark] = struct{}{} + tableMap[rule.Table] = struct{}{} + } + + routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) + if err != nil { + return fmt.Errorf("error while getting routes: %w", err) + } + for _, route := range routes { + tableMap[route.Table] = struct{}{} + } + + for { + if _, exists := tableMap[r.table]; exists { + r.table++ + continue + } + break + } + + for { + if _, exists := markMap[r.mark]; exists { + r.mark++ + continue + } + break + } + + // IPTables rules + if !r.SoftwareMode { + err = r.IPTables.ClearChain("mangle", r.ChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + for _, iptablesArgs := range [][]string{ + // Source: https://github.com/qzeleza/kvas/blob/3fdbbd1ace7b57b11bf88d8db3882d94a1d6e01c/opt/etc/ndm/ndm#L194-L206 + {"-m", "set", "!", "--match-set", r.IPSetName, "dst", "-j", "RETURN"}, + {"-j", "CONNMARK", "--restore-mark"}, + {"-m", "mark", "--mark", strconv.Itoa(int(r.mark)), "-j", "RETURN"}, + // This command not working + // {"--syn", "-j", "MARK", "--set-mark", strconv.Itoa(int(mark))}, + {"-m", "conntrack", "--ctstate", "NEW", "-j", "MARK", "--set-mark", strconv.Itoa(int(r.mark))}, + {"-j", "CONNMARK", "--save-mark"}, + } { + err = r.IPTables.AppendUnique("mangle", r.ChainName, iptablesArgs...) + if err != nil { + return fmt.Errorf("failed to append rule: %w", err) + } + } + + err = r.IPTables.AppendUnique("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + return fmt.Errorf("failed to append rule to PREROUTING: %w", err) + } + + err = r.IPTables.AppendUnique("mangle", "OUTPUT", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + return fmt.Errorf("failed to append rule to OUTPUT: %w", err) + } + } else { + preroutingChainName := fmt.Sprintf("%s_PREROUTING", r.ChainName) + + err = r.IPTables.ClearChain("mangle", preroutingChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + err = r.IPTables.AppendUnique("mangle", preroutingChainName, "-m", "set", "--match-set", r.IPSetName, "dst", "-j", "MARK", "--set-mark", strconv.Itoa(int(r.mark))) + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } + + err = r.IPTables.AppendUnique("mangle", "PREROUTING", "-j", preroutingChainName) + if err != nil { + return fmt.Errorf("failed to append rule to PREROUTING: %w", err) + } + } + + postroutingChainName := fmt.Sprintf("%s_POSTROUTING", r.ChainName) + + err = r.IPTables.ClearChain("nat", postroutingChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + err = r.IPTables.AppendUnique("nat", postroutingChainName, "-o", r.IfaceName, "-j", "MASQUERADE") + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } + + err = r.IPTables.AppendUnique("nat", "POSTROUTING", "-j", postroutingChainName) + if err != nil { + return fmt.Errorf("failed to append rule to POSTROUTING: %w", err) + } + + // Mapping mark with table + rule := netlink.NewRule() + rule.Mark = r.mark + rule.Table = r.table + err = netlink.RuleAdd(rule) + if err != nil { + return fmt.Errorf("error while mapping mark with table: %w", err) + } + r.ipRule = rule + + // Mapping iface with table + if iface != nil { + route := &netlink.Route{ + LinkIndex: iface.Attrs().Index, + Table: rule.Table, + Dst: &net.IPNet{IP: []byte{0, 0, 0, 0}, Mask: []byte{0, 0, 0, 0}}, + } + err = netlink.RouteAdd(route) + if err != nil { + return fmt.Errorf("error while mapping iface with table: %w", err) + } + r.ipRoute = route + } + + return nil +} + +func (r *IfaceToIPSet) Disable() []error { + var errs []error + var err error + + if !r.SoftwareMode { + err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete rule from PREROUTING: %w", err)) + } + + err = r.IPTables.DeleteIfExists("mangle", "OUTPUT", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete rule from OUTPUT: %w", err)) + } + + err = r.IPTables.ClearAndDeleteChain("mangle", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) + } + } else { + preroutingChainName := fmt.Sprintf("%s_PREROUTING", r.ChainName) + + err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-j", preroutingChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete rule from PREROUTING: %w", err)) + } + + err = r.IPTables.ClearAndDeleteChain("mangle", preroutingChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) + } + } + + postroutingChainName := fmt.Sprintf("%s_POSTROUTING", r.ChainName) + + err = r.IPTables.DeleteIfExists("nat", "POSTROUTING", "-j", postroutingChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) + } + + err = r.IPTables.ClearAndDeleteChain("nat", postroutingChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) + } + + if r.ipRule != nil { + err = netlink.RuleDel(r.ipRule) + if err != nil { + errs = append(errs, fmt.Errorf("error while deleting rule: %w", err)) + } + r.ipRule = nil + } + + if r.ipRoute != nil { + err = netlink.RouteDel(r.ipRoute) + if err != nil { + errs = append(errs, fmt.Errorf("error while deleting route: %w", err)) + } + r.ipRule = nil + } + + return errs +} + +func (r *IfaceToIPSet) Enable() error { + if r.Enabled { + return nil + } + + err := r.ForceEnable() + if err != nil { + r.Disable() + return err + } + + return nil +} + +func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IfaceToIPSet { + return &IfaceToIPSet{ + IPTables: nh.IPTables, + ChainName: name, + IfaceName: ifaceName, + IPSetName: ipsetName, + } +} diff --git a/netfilter-helper/ipset.go b/netfilter-helper/ipset.go new file mode 100644 index 0000000..81979c5 --- /dev/null +++ b/netfilter-helper/ipset.go @@ -0,0 +1,55 @@ +package netfilterHelper + +import ( + "fmt" + "github.com/vishvananda/netlink" + "net" + "os" +) + +type IPSet struct { + SetName string +} + +func (r *IPSet) Add(addr net.IP, timeout *uint32) error { + err := netlink.IpsetAdd(r.SetName, &netlink.IPSetEntry{ + IP: addr, + Timeout: timeout, + Replace: true, + }) + if err != nil { + return fmt.Errorf("failed to add address: %w", err) + } + return nil +} + +func (r *IPSet) Create() error { + err := r.Destroy() + if err != nil { + return err + } + + defaultTimeout := uint32(300) + err = netlink.IpsetCreate(r.SetName, "hash:ip", netlink.IpsetCreateOptions{ + Timeout: &defaultTimeout, + }) + if err != nil { + return fmt.Errorf("failed to create ipset: %w", err) + } + + return nil +} + +func (r *IPSet) Destroy() error { + err := netlink.IpsetDestroy(r.SetName) + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("failed to destroy ipset: %w", err) + } + return nil +} + +func (nh *NetfilterHelper) IPSet(name string) *IPSet { + return &IPSet{ + SetName: name, + } +} diff --git a/netfilter-helper/netfiler-helper.go b/netfilter-helper/netfiler-helper.go new file mode 100644 index 0000000..c3cb9a0 --- /dev/null +++ b/netfilter-helper/netfiler-helper.go @@ -0,0 +1,21 @@ +package netfilterHelper + +import ( + "fmt" + "github.com/coreos/go-iptables/iptables" +) + +type NetfilterHelper struct { + IPTables *iptables.IPTables +} + +func New() (*NetfilterHelper, error) { + ipt, err := iptables.New() + if err != nil { + return nil, fmt.Errorf("iptables init fail: %w", err) + } + + return &NetfilterHelper{ + IPTables: ipt, + }, nil +} diff --git a/netfilter-helper/port-remap.go b/netfilter-helper/port-remap.go new file mode 100644 index 0000000..5d8729d --- /dev/null +++ b/netfilter-helper/port-remap.go @@ -0,0 +1,76 @@ +package netfilterHelper + +import ( + "fmt" + "github.com/coreos/go-iptables/iptables" + "strconv" +) + +type PortRemap struct { + IPTables *iptables.IPTables + ChainName string + From uint16 + To uint16 + + Enabled bool +} + +func (r *PortRemap) ForceEnable() error { + err := r.IPTables.ClearChain("nat", r.ChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "--dport", strconv.Itoa(int(r.From)), "-j", "REDIRECT", "--to-port", strconv.Itoa(int(r.To))) + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } + + err = r.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", r.ChainName) + if err != nil { + return fmt.Errorf("failed to linking chain: %w", err) + } + + r.Enabled = true + return nil +} + +func (r *PortRemap) Disable() []error { + var errs []error + + err := r.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) + } + + err = r.IPTables.ClearAndDeleteChain("nat", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) + } + + r.Enabled = false + return errs +} + +func (r *PortRemap) Enable() error { + if r.Enabled { + return nil + } + + err := r.ForceEnable() + if err != nil { + r.Disable() + return err + } + + return nil +} + +func (nh *NetfilterHelper) PortRemap(name string, from, to uint16) *PortRemap { + return &PortRemap{ + IPTables: nh.IPTables, + ChainName: name, + From: from, + To: to, + } +} diff --git a/recoverable-iptables/recoverable-iptables.go b/recoverable-iptables/recoverable-iptables.go new file mode 100644 index 0000000..4d6f19f --- /dev/null +++ b/recoverable-iptables/recoverable-iptables.go @@ -0,0 +1,124 @@ +package recoverableIPTables + +import ( + "github.com/coreos/go-iptables/iptables" + "reflect" +) + +type IPTablesRule struct { + Position int + RuleSpec []string +} + +type IPTables struct { + ipt *iptables.IPTables + cache map[string]map[string][]IPTablesRule +} + +/* + * Chain + */ + +func (r *IPTables) clearChain(table, chain string) { + if r.cache[table] == nil { + r.cache[table] = make(map[string][]IPTablesRule) + } + r.cache[table][chain] = nil +} + +func (r *IPTables) delChain(table, chain string) { + if r.cache[table] == nil { + return + } + delete(r.cache[table], chain) +} + +/* + * Rule + */ + +func (r *IPTables) delRule(table, chain string, rulespec ...string) { + if r.cache[table] == nil { + r.cache[table] = make(map[string][]IPTablesRule) + } + for idx, rule := range r.cache[table][chain] { + if !reflect.DeepEqual(rulespec, rule.RuleSpec) { + continue + } + copy(r.cache[table][chain][idx:], r.cache[table][chain][idx+1:]) + r.cache[table][chain] = r.cache[table][chain][:len(r.cache[table][chain])-1] + break + } +} + +func (r *IPTables) addRuleUnique(table, chain string, position int, rulespec ...string) { + if r.cache[table] == nil { + r.cache[table] = make(map[string][]IPTablesRule) + } + for _, rule := range r.cache[table][chain] { + if reflect.DeepEqual(rulespec, rule.RuleSpec) { + return + } + } + r.cache[table][chain] = append(r.cache[table][chain], IPTablesRule{ + Position: position, + RuleSpec: rulespec, + }) +} + +/* + * Mappings + */ + +func (r *IPTables) ClearChain(table, chain string) error { + err := r.ipt.ClearChain(table, chain) + if err != nil { + return err + } + r.clearChain(table, chain) + return nil +} + +func (r *IPTables) ClearAndDeleteChain(table, chain string) error { + err := r.ipt.ClearAndDeleteChain(table, chain) + if err != nil { + return err + } + r.delChain(table, chain) + return nil +} + +func (r *IPTables) AppendUnique(table, chain string, rulespec ...string) error { + err := r.ipt.AppendUnique(table, chain, rulespec...) + if err != nil { + return err + } + r.addRuleUnique(table, chain, 0, rulespec...) + return nil +} + +func (r *IPTables) InsertUnique(table, chain string, pos int, rulespec ...string) error { + err := r.ipt.InsertUnique(table, chain, pos, rulespec...) + if err != nil { + return err + } + r.addRuleUnique(table, chain, pos, rulespec...) + return nil +} + +func (r *IPTables) DeleteIfExists(table, chain string, rulespec ...string) error { + err := r.ipt.DeleteIfExists(table, chain, rulespec...) + if err != nil { + return err + } + r.delRule(table, chain, rulespec...) + return nil +} + +func New() (*IPTables, error) { + ipt, err := iptables.New() + return &IPTables{ + ipt: ipt, + cache: make(map[string]map[string][]IPTablesRule), + }, err +}