package netfilterHelper import ( "fmt" "net" "strconv" "github.com/coreos/go-iptables/iptables" "github.com/rs/zerolog/log" "github.com/vishvananda/netlink" "github.com/vishvananda/netlink/nl" ) type IPSetToLink struct { IPTables *iptables.IPTables ChainName string IfaceName string IPSetName string enabled bool mark uint32 table int ipRule *netlink.Rule ipRoute *netlink.Route } func (r *IPSetToLink) insertIPTablesRules(table string) error { var err error if table == "" || table == "mangle" { err = r.IPTables.NewChain("mangle", r.ChainName) if err != nil { // If not "AlreadyExists" if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) { return fmt.Errorf("failed to create chain: %w", err) } } for _, iptablesArgs := range [][]string{ {"-j", "MARK", "--set-mark", strconv.Itoa(int(r.mark))}, {"-j", "CONNMARK", "--save-mark"}, } { err = r.IPTables.AppendUnique("mangle", r.ChainName, iptablesArgs...) if err != nil { return fmt.Errorf("failed to append rule: %w", err) } } err = r.IPTables.InsertUnique("mangle", "PREROUTING", 1, "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) if err != nil { return fmt.Errorf("failed to append rule to PREROUTING: %w", err) } } if table == "" || table == "nat" { err = r.IPTables.NewChain("nat", r.ChainName) if err != nil { // If not "AlreadyExists" if eerr, eok := err.(*iptables.Error); !(eok && eerr.ExitStatus() == 1) { return fmt.Errorf("failed to create chain: %w", err) } } err = r.IPTables.AppendUnique("nat", r.ChainName, "-j", "MASQUERADE") if err != nil { return fmt.Errorf("failed to create rule: %w", err) } err = r.IPTables.AppendUnique("nat", "POSTROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) if err != nil { return fmt.Errorf("failed to append rule to POSTROUTING: %w", err) } } return nil } func (r *IPSetToLink) deleteIPTablesRules() []error { var errs []error err := r.IPTables.DeleteIfExists("mangle", "PREROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) if err != nil { errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) } err = r.IPTables.ClearAndDeleteChain("mangle", r.ChainName) if err != nil { errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) } err = r.IPTables.DeleteIfExists("nat", "POSTROUTING", "-m", "set", "--match-set", r.IPSetName, "dst", "-j", r.ChainName) if err != nil { errs = append(errs, fmt.Errorf("failed to unlinking chain: %w", err)) } err = r.IPTables.ClearAndDeleteChain("nat", r.ChainName) if err != nil { errs = append(errs, fmt.Errorf("failed to delete chain: %w", err)) } return errs } func (r *IPSetToLink) insertIPRule() error { rule := netlink.NewRule() rule.Mark = r.mark rule.Table = r.table _ = netlink.RuleDel(rule) err := netlink.RuleAdd(rule) if err != nil { return fmt.Errorf("error while mapping mark with table: %w", err) } r.ipRule = rule log.Trace().Int("table", r.table).Int("mark", int(r.mark)).Msg("using ip table and mark") return nil } func (r *IPSetToLink) deleteIPRule() []error { if r.ipRule == nil { return nil } err := netlink.RuleDel(r.ipRule) if err != nil { return []error{fmt.Errorf("error while deleting rule: %w", err)} } r.ipRule = nil return nil } func (r *IPSetToLink) insertIPRoute() error { // Find interface iface, err := netlink.LinkByName(r.IfaceName) if err != nil { // TODO: Нормально отлавливать ошибку if err.Error() == "Link not found" { log.Debug().Str("iface", r.IfaceName).Msg("interface not found (waiting for it to exist)") return nil } return fmt.Errorf("error while getting interface: %w", err) } // Mapping iface with table route := &netlink.Route{ LinkIndex: iface.Attrs().Index, Table: r.table, Dst: &net.IPNet{IP: []byte{0, 0, 0, 0}, Mask: []byte{0, 0, 0, 0}}, } // Delete rule if exists err = netlink.RouteAdd(route) if err != nil { // TODO: Нормально отлавливать ошибку if err.Error() == "file exists" { return nil } return fmt.Errorf("error while mapping iface with table: %w", err) } r.ipRoute = route return nil } func (r *IPSetToLink) deleteIPRoute() []error { if r.ipRoute == nil { return nil } err := netlink.RouteDel(r.ipRoute) if err != nil { return []error{fmt.Errorf("error while deleting route: %w", err)} } r.ipRoute = nil return nil } func (r *IPSetToLink) getUnusedMarkAndTable() (mark uint32, table int, err error) { // Find unused mark and table markMap := make(map[uint32]struct{}) tableMap := map[int]struct{}{0: {}, 253: {}, 254: {}, 255: {}} rules, err := netlink.RuleList(nl.FAMILY_ALL) if err != nil { return 0, 0, fmt.Errorf("error while getting rules: %w", err) } for _, rule := range rules { markMap[rule.Mark] = struct{}{} tableMap[rule.Table] = struct{}{} } routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) if err != nil { return 0, 0, fmt.Errorf("error while getting routes: %w", err) } for _, route := range routes { tableMap[route.Table] = struct{}{} } for table = 0; table < 0x7ffffffe; table++ { if _, exists := tableMap[table]; !exists { break } } for mark = 0; mark < 0xfffffffe; mark++ { if _, exists := markMap[mark]; !exists { break } } return mark, table, nil } func (r *IPSetToLink) enable() error { // Release used mark and table r.Disable() var err error r.mark, r.table, err = r.getUnusedMarkAndTable() if err != nil { return err } err = r.IPTables.ClearChain("mangle", r.ChainName) if err != nil { return fmt.Errorf("failed to clear chain: %w", err) } err = r.IPTables.ClearChain("nat", r.ChainName) if err != nil { return fmt.Errorf("failed to clear chain: %w", err) } // IPTables rules err = r.insertIPTablesRules("") if err != nil { return err } err = r.insertIPRule() if err != nil { return err } err = r.insertIPRoute() if err != nil { return err } r.enabled = true return nil } func (r *IPSetToLink) Enable() error { if r.enabled { return nil } err := r.enable() if err != nil { r.Disable() return err } return nil } func (r *IPSetToLink) Disable() []error { var errs []error errs = append(errs, r.deleteIPRoute()...) errs = append(errs, r.deleteIPRule()...) errs = append(errs, r.deleteIPTablesRules()...) r.enabled = false return errs } func (r *IPSetToLink) NetfilterDHook(table string) error { if !r.enabled { return nil } return r.insertIPTablesRules(table) } func (r *IPSetToLink) LinkUpdateHook(event netlink.LinkUpdate) error { if !r.enabled || event.Change != 1 || event.Link.Attrs().Name != r.IfaceName { return nil } return r.insertIPRoute() } func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string) *IPSetToLink { return &IPSetToLink{ IPTables: nh.IPTables, ChainName: name, IfaceName: ifaceName, IPSetName: ipsetName, } }