diff --git a/group.go b/group.go index 7e77ee5..1648bbd 100644 --- a/group.go +++ b/group.go @@ -16,9 +16,9 @@ type Group struct { Enabled bool - iptables *iptables.IPTables - ipset *netfilterHelper.IPSet - ifaceToIPSetNAT *netfilterHelper.IfaceToIPSet + iptables *iptables.IPTables + ipset *netfilterHelper.IPSet + ipsetToLink *netfilterHelper.IPSetToLink } func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error { @@ -51,7 +51,7 @@ func (g *Group) Enable() error { } } - err := g.ifaceToIPSetNAT.Enable() + err := g.ipsetToLink.Enable() if err != nil { return err } @@ -68,7 +68,7 @@ func (g *Group) Disable() []error { return nil } - err := g.ifaceToIPSetNAT.Disable() + err := g.ipsetToLink.Disable() if err != nil { errs = append(errs, err...) } diff --git a/kvas2.go b/kvas2.go index 73e6ab2..780147c 100644 --- a/kvas2.go +++ b/kvas2.go @@ -34,7 +34,6 @@ type Config struct { LinkName string TargetDNSServerAddress string ListenDNSPort uint16 - UseSoftwareRouting bool } type App struct { @@ -69,7 +68,7 @@ func (a *App) handleLink(event netlink.LinkUpdate) { continue } - err := group.ifaceToIPSetNAT.IfaceHandle() + err := group.ipsetToLink.LinkUpdateHook() if err != nil { log.Error().Str("group", group.ID.String()).Err(err).Msg("error while handling interface up") } @@ -218,11 +217,9 @@ func (a *App) start(ctx context.Context) (err error) { } } for _, group := range a.Groups { - if group.ifaceToIPSetNAT.Enabled { - err := group.ifaceToIPSetNAT.PutIPTable(args[2]) - if err != nil { - log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") - } + err := group.ipsetToLink.NetfilerDHook(args[2]) + if err != nil { + log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } } } @@ -295,12 +292,11 @@ func (a *App) AddGroup(group *models.Group) error { } grp := &Group{ - Group: group, - iptables: a.NetfilterHelper4.IPTables, - ipset: ipset, - ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false), + Group: group, + iptables: a.NetfilterHelper4.IPTables, + ipset: ipset, + ipsetToLink: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false), } - grp.ifaceToIPSetNAT.SoftwareMode = a.Config.UseSoftwareRouting a.Groups[grp.ID] = grp return a.SyncGroup(grp) } diff --git a/netfilter-helper/interface-to-ipset.go b/netfilter-helper/interface-to-ipset.go deleted file mode 100644 index d14badb..0000000 --- a/netfilter-helper/interface-to-ipset.go +++ /dev/null @@ -1,295 +0,0 @@ -package netfilterHelper - -import ( - "fmt" - "github.com/coreos/go-iptables/iptables" - "github.com/rs/zerolog/log" - "github.com/vishvananda/netlink" - "github.com/vishvananda/netlink/nl" - "net" - "strconv" -) - -type IfaceToIPSet struct { - IPTables *iptables.IPTables - ChainName string - IfaceName string - IPSetName string - SoftwareMode bool - - Enabled bool - - mark uint32 - table int - ipRule *netlink.Rule - ipRoute *netlink.Route -} - -func (r *IfaceToIPSet) PutIPTable(table string) error { - var err error - - if !r.SoftwareMode { - if table == "all" || table == "mangle" { - err = r.IPTables.ClearChain("mangle", r.ChainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - for _, iptablesArgs := range [][]string{ - // Source: https://github.com/qzeleza/kvas/blob/3fdbbd1ace7b57b11bf88d8db3882d94a1d6e01c/opt/etc/ndm/ndm#L194-L206 - {"-m", "set", "!", "--match-set", r.IPSetName, "dst", "-j", "RETURN"}, - {"-j", "CONNMARK", "--restore-mark"}, - {"-m", "mark", "--mark", strconv.Itoa(int(r.mark)), "-j", "RETURN"}, - // This command not working - // {"--syn", "-j", "MARK", "--set-mark", strconv.Itoa(int(mark))}, - {"-m", "conntrack", "--ctstate", "NEW", "-j", "MARK", "--set-mark", strconv.Itoa(int(r.mark))}, - {"-j", "CONNMARK", "--save-mark"}, - } { - err = r.IPTables.AppendUnique("mangle", r.ChainName, iptablesArgs...) - if err != nil { - return fmt.Errorf("failed to append rule: %w", err) - } - } - - err = r.IPTables.AppendUnique("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) - if err != nil { - return fmt.Errorf("failed to append rule to PREROUTING: %w", err) - } - - err = r.IPTables.AppendUnique("mangle", "OUTPUT", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) - if err != nil { - return fmt.Errorf("failed to append rule to OUTPUT: %w", err) - } - } - } else { - if table == "all" || table == "mangle" { - preroutingChainName := fmt.Sprintf("%s_PRR", r.ChainName) - - err = r.IPTables.ClearChain("mangle", preroutingChainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - err = r.IPTables.AppendUnique("mangle", preroutingChainName, "-m", "set", "--match-set", r.IPSetName, "dst", "-j", "MARK", "--set-mark", strconv.Itoa(int(r.mark))) - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = r.IPTables.AppendUnique("mangle", preroutingChainName, "-m", "set", "--match-set", r.IPSetName, "dst", "-j", "CONNMARK", "--save-mark") - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = r.IPTables.AppendUnique("mangle", "PREROUTING", "-j", preroutingChainName) - if err != nil { - return fmt.Errorf("failed to append rule to PREROUTING: %w", err) - } - } - } - - if table == "all" || table == "nat" { - postroutingChainName := fmt.Sprintf("%s_POR", r.ChainName) - - err = r.IPTables.ClearChain("nat", postroutingChainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - err = r.IPTables.AppendUnique("nat", postroutingChainName, "-m", "mark", "--mark", strconv.Itoa(int(r.mark)), "-j", "MASQUERADE") - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = r.IPTables.AppendUnique("nat", "POSTROUTING", "-j", postroutingChainName) - if err != nil { - return fmt.Errorf("failed to append rule to POSTROUTING: %w", err) - } - } - - return nil -} - -func (r *IfaceToIPSet) IfaceHandle() error { - // Find interface - iface, err := netlink.LinkByName(r.IfaceName) - if err != nil { - log.Warn().Str("interface", r.IfaceName).Err(err).Msg("error while getting interface") - } - - // Mapping iface with table - if iface != nil { - route := &netlink.Route{ - LinkIndex: iface.Attrs().Index, - Table: r.table, - Dst: &net.IPNet{IP: []byte{0, 0, 0, 0}, Mask: []byte{0, 0, 0, 0}}, - } - // Delete rule if exists - err = netlink.RouteDel(route) - if err != nil { - log.Warn().Str("interface", r.IfaceName).Err(err).Msg("error while deleting route") - } - err = netlink.RouteAdd(route) - if err != nil { - return fmt.Errorf("error while mapping iface with table: %w", err) - } - r.ipRoute = route - } - - return nil -} - -func (r *IfaceToIPSet) ForceEnable() error { - // Release used mark and table - r.Disable() - r.mark = 0 - r.table = 0 - - // Find unused mark and table - markMap := make(map[uint32]struct{}) - tableMap := map[int]struct{}{0: {}, 253: {}, 254: {}, 255: {}} - - rules, err := netlink.RuleList(nl.FAMILY_ALL) - if err != nil { - return fmt.Errorf("error while getting rules: %w", err) - } - for _, rule := range rules { - markMap[rule.Mark] = struct{}{} - tableMap[rule.Table] = struct{}{} - } - - routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) - if err != nil { - return fmt.Errorf("error while getting routes: %w", err) - } - for _, route := range routes { - tableMap[route.Table] = struct{}{} - } - - for { - if _, exists := tableMap[r.table]; exists { - r.table++ - continue - } - break - } - - for { - if _, exists := markMap[r.mark]; exists { - r.mark++ - continue - } - break - } - - // IPTables rules - err = r.PutIPTable("all") - if err != nil { - return err - } - - // Mapping mark with table - rule := netlink.NewRule() - rule.Mark = r.mark - rule.Table = r.table - err = netlink.RuleAdd(rule) - if err != nil { - return fmt.Errorf("error while mapping mark with table: %w", err) - } - r.ipRule = rule - - err = r.IfaceHandle() - if err != nil { - return err - } - - r.Enabled = true - return nil -} - -func (r *IfaceToIPSet) Disable() []error { - var errs []error - var err error - - if !r.SoftwareMode { - err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to delete rule from PREROUTING: %w", err)) - } - - err = r.IPTables.DeleteIfExists("mangle", "OUTPUT", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to delete rule from OUTPUT: %w", err)) - } - - err = r.IPTables.ClearAndDeleteChain("mangle", r.ChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) - } - } else { - preroutingChainName := fmt.Sprintf("%s_PRR", r.ChainName) - - err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-j", preroutingChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to delete rule from PREROUTING: %w", err)) - } - - err = r.IPTables.ClearAndDeleteChain("mangle", preroutingChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) - } - } - - postroutingChainName := fmt.Sprintf("%s_POR", r.ChainName) - - err = r.IPTables.DeleteIfExists("nat", "POSTROUTING", "-j", postroutingChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) - } - - err = r.IPTables.ClearAndDeleteChain("nat", postroutingChainName) - if err != nil { - errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) - } - - if r.ipRule != nil { - err = netlink.RuleDel(r.ipRule) - if err != nil { - errs = append(errs, fmt.Errorf("error while deleting rule: %w", err)) - } - r.ipRule = nil - } - - if r.ipRoute != nil { - err = netlink.RouteDel(r.ipRoute) - if err != nil { - errs = append(errs, fmt.Errorf("error while deleting route: %w", err)) - } - r.ipRule = nil - } - - r.Enabled = false - return errs -} - -func (r *IfaceToIPSet) Enable() error { - if r.Enabled { - return nil - } - - err := r.ForceEnable() - if err != nil { - r.Disable() - return err - } - - return nil -} - -func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IfaceToIPSet { - return &IfaceToIPSet{ - IPTables: nh.IPTables, - ChainName: name, - IfaceName: ifaceName, - IPSetName: ipsetName, - } -} diff --git a/netfilter-helper/ipset-to-link.go b/netfilter-helper/ipset-to-link.go new file mode 100644 index 0000000..a877711 --- /dev/null +++ b/netfilter-helper/ipset-to-link.go @@ -0,0 +1,290 @@ +package netfilterHelper + +import ( + "fmt" + "net" + "strconv" + + "github.com/coreos/go-iptables/iptables" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" +) + +type IPSetToLink struct { + IPTables *iptables.IPTables + ChainName string + IfaceName string + IPSetName string + + enabled bool + mark uint32 + table int + ipRule *netlink.Rule + ipRoute *netlink.Route +} + +func (r *IPSetToLink) insertIPTablesRules(table string) error { + var err error + + if table == "" || table == "mangle" { + err = r.IPTables.NewChain("mangle", r.ChainName) + if err != nil { + // If not "AlreadyExists" + if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) { + return fmt.Errorf("failed to create chain: %w", err) + } + } + + for _, iptablesArgs := range [][]string{ + {"-j", "MARK", "--set-mark", strconv.Itoa(int(r.mark))}, + {"-j", "CONNMARK", "--save-mark"}, + } { + err = r.IPTables.AppendUnique("mangle", r.ChainName, iptablesArgs...) + if err != nil { + return fmt.Errorf("failed to append rule: %w", err) + } + } + + err = r.IPTables.InsertUnique("mangle", "PREROUTING", 1, "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + return fmt.Errorf("failed to append rule to PREROUTING: %w", err) + } + } + + if table == "" || table == "nat" { + err = r.IPTables.NewChain("nat", r.ChainName) + if err != nil { + // If not "AlreadyExists" + if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) { + return fmt.Errorf("failed to create chain: %w", err) + } + } + + err = r.IPTables.AppendUnique("nat", r.ChainName, "-j", "MASQUERADE") + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } + + err = r.IPTables.AppendUnique("nat", "POSTROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + return fmt.Errorf("failed to append rule to POSTROUTING: %w", err) + } + } + + return nil +} + +func (r *IPSetToLink) deleteIPTablesRules() []error { + var errs []error + var err error + + err = r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) + } + + err = r.IPTables.ClearAndDeleteChain("mangle", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) + } + + err = r.IPTables.DeleteIfExists("nat", "POSTROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) + } + + err = r.IPTables.ClearAndDeleteChain("nat", r.ChainName) + if err != nil { + errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) + } + + return errs +} + +func (r *IPSetToLink) insertIPRule() error { + rule := netlink.NewRule() + rule.Mark = r.mark + rule.Table = r.table + _ = netlink.RuleDel(rule) + err := netlink.RuleAdd(rule) + if err != nil { + return fmt.Errorf("error while mapping mark with table: %w", err) + } + r.ipRule = rule + + return nil +} + +func (r *IPSetToLink) deleteIPRule() []error { + if r.ipRule == nil { + return nil + } + + err := netlink.RuleDel(r.ipRule) + if err != nil { + return []error{fmt.Errorf("error while deleting rule: %w", err)} + } + r.ipRule = nil + return nil +} + +func (r *IPSetToLink) insertIPRoute() error { + // Find interface + iface, err := netlink.LinkByName(r.IfaceName) + if err != nil { + return fmt.Errorf("error while getting interface: %w", err) + } + + // Mapping iface with table + route := &netlink.Route{ + LinkIndex: iface.Attrs().Index, + Table: r.table, + Dst: &net.IPNet{IP: []byte{0, 0, 0, 0}, Mask: []byte{0, 0, 0, 0}}, + } + // Delete rule if exists + _ = netlink.RouteDel(route) + err = netlink.RouteAdd(route) + if err != nil { + return fmt.Errorf("error while mapping iface with table: %w", err) + } + r.ipRoute = route + + return nil +} + +func (r *IPSetToLink) deleteIPRoute() []error { + if r.ipRoute == nil { + return nil + } + + err := netlink.RouteDel(r.ipRoute) + if err != nil { + return []error{fmt.Errorf("error while deleting route: %w", err)} + } + r.ipRoute = nil + return nil +} + +func (r *IPSetToLink) getUnusedMarkAndTable() (mark uint32, table int, err error) { + // Find unused mark and table + markMap := make(map[uint32]struct{}) + tableMap := map[int]struct{}{0: {}, 253: {}, 254: {}, 255: {}} + + rules, err := netlink.RuleList(nl.FAMILY_ALL) + if err != nil { + return 0, 0, fmt.Errorf("error while getting rules: %w", err) + } + for _, rule := range rules { + markMap[rule.Mark] = struct{}{} + tableMap[rule.Table] = struct{}{} + } + + routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) + if err != nil { + return 0, 0, fmt.Errorf("error while getting routes: %w", err) + } + for _, route := range routes { + tableMap[route.Table] = struct{}{} + } + + for table = 0; table < 0x7ffffffe; table++ { + if _, exists := tableMap[table]; !exists { + break + } + } + + for mark = 0; mark < 0xfffffffe; mark++ { + if _, exists := markMap[mark]; !exists { + break + } + } + + return mark, table, nil +} + +func (r *IPSetToLink) enable() error { + // Release used mark and table + r.Disable() + + var err error + r.mark, r.table, err = r.getUnusedMarkAndTable() + if err != nil { + return err + } + + err = r.IPTables.ClearChain("mangle", r.ChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + err = r.IPTables.ClearChain("nat", r.ChainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + // IPTables rules + err = r.insertIPTablesRules("") + if err != nil { + return err + } + + err = r.insertIPRule() + if err != nil { + return err + } + + err = r.insertIPRoute() + if err != nil { + return err + } + + r.enabled = true + return nil +} + +func (r *IPSetToLink) Enable() error { + if r.enabled { + return nil + } + + err := r.enable() + if err != nil { + r.Disable() + return err + } + + return nil +} + +func (r *IPSetToLink) Disable() []error { + var errs []error + errs = append(errs, r.deleteIPRoute()...) + errs = append(errs, r.deleteIPRule()...) + errs = append(errs, r.deleteIPTablesRules()...) + + r.enabled = false + return errs +} + +func (r *IPSetToLink) NetfilerDHook(table string) error { + if !r.enabled { + return nil + } + return r.insertIPTablesRules(table) +} + +func (r *IPSetToLink) LinkUpdateHook() error { + if !r.enabled { + return nil + } + return r.insertIPRoute() +} + +func (nh *NetfilterHelper) IfaceToIPSet(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink { + return &IPSetToLink{ + IPTables: nh.IPTables, + ChainName: name, + IfaceName: ifaceName, + IPSetName: ipsetName, + } +}