diff --git a/pkg/dns-proxy/dns-proxy.go b/dns-proxy/dns-proxy.go similarity index 100% rename from pkg/dns-proxy/dns-proxy.go rename to dns-proxy/dns-proxy.go diff --git a/pkg/dns-proxy/parser.go b/dns-proxy/parser.go similarity index 100% rename from pkg/dns-proxy/parser.go rename to dns-proxy/parser.go diff --git a/pkg/dns-proxy/types.go b/dns-proxy/types.go similarity index 100% rename from pkg/dns-proxy/types.go rename to dns-proxy/types.go diff --git a/pkg/dns-proxy/types_test.go b/dns-proxy/types_test.go similarity index 100% rename from pkg/dns-proxy/types_test.go rename to dns-proxy/types_test.go diff --git a/group.go b/group.go index 3b233ae..71a05fd 100644 --- a/group.go +++ b/group.go @@ -7,10 +7,10 @@ import ( "time" "kvas2-go/models" - "kvas2-go/pkg/ip-helper" "github.com/rs/zerolog/log" "github.com/vishvananda/netlink" + "github.com/vishvananda/netlink/nl" ) type GroupOptions struct { @@ -62,12 +62,52 @@ func (g *Group) Enable() error { var err error - rule := netlink.NewRule() - rule.Mark, err = ipHelper.GetUnusedFwMark(1) - if err != nil { - return fmt.Errorf("error while getting free fwmark: %w", err) + markMap := make(map[uint32]struct{}) + tableMap := map[int]struct{}{ + 0: {}, + 253: {}, + 254: {}, + 255: {}, } - rule.Table, err = ipHelper.GetUnusedTable(1) + var table int + var mark uint32 + + rules, err := netlink.RuleList(nl.FAMILY_ALL) + if err != nil { + return fmt.Errorf("error while getting rules: %w", err) + } + for _, rule := range rules { + markMap[rule.Mark] = struct{}{} + tableMap[rule.Table] = struct{}{} + } + + routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) + if err != nil { + return fmt.Errorf("error while getting routes: %w", err) + } + for _, route := range routes { + tableMap[route.Table] = struct{}{} + } + + for { + if _, exists := tableMap[table]; exists { + table++ + continue + } + break + } + + for { + if _, exists := markMap[mark]; exists { + mark++ + continue + } + break + } + + rule := netlink.NewRule() + rule.Mark = mark + rule.Table = table if err != nil { return fmt.Errorf("error while getting free table: %w", err) } diff --git a/kvas2.go b/kvas2.go index ddec60e..da2c250 100644 --- a/kvas2.go +++ b/kvas2.go @@ -5,13 +5,14 @@ import ( "errors" "fmt" "net" + "strconv" "sync" "time" + "kvas2-go/dns-proxy" "kvas2-go/models" - "kvas2-go/pkg/dns-proxy" - "kvas2-go/pkg/iptables-helper" + "github.com/coreos/go-iptables/iptables" "github.com/rs/zerolog/log" ) @@ -31,10 +32,10 @@ type Config struct { type App struct { Config Config - DNSProxy *dnsProxy.DNSProxy - DNSOverrider *iptablesHelper.DNSOverrider - Records *Records - Groups map[int]*Group + DNSProxy *dnsProxy.DNSProxy + IPTables *iptables.IPTables + Records *Records + Groups map[int]*Group isRunning bool } @@ -72,13 +73,28 @@ func (a *App) Listen(ctx context.Context) []error { newCtx, cancel := context.WithCancel(ctx) defer cancel() - if err := a.DNSOverrider.Enable(); err != nil { - handleError(fmt.Errorf("failed to override DNS: %w", err)) + chainName := fmt.Sprintf("%sDNSOVERRIDER", a.Config.ChainPostfix) + + err := a.IPTables.ClearChain("nat", chainName) + if err != nil { + handleError(fmt.Errorf("failed to clear chain: %w", err)) + return errs + } + + err = a.IPTables.AppendUnique("nat", chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(a.Config.ListenPort))) + if err != nil { + handleError(fmt.Errorf("failed to create rule: %w", err)) + return errs + } + + err = a.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", chainName) + if err != nil { + handleError(fmt.Errorf("failed to linking chain: %w", err)) return errs } for idx, _ := range a.Groups { - err := a.Groups[idx].Enable() + err = a.Groups[idx].Enable() if err != nil { handleError(fmt.Errorf("failed to enable group: %w", err)) return errs @@ -97,15 +113,23 @@ func (a *App) Listen(ctx context.Context) []error { } for idx, _ := range a.Groups { - err := a.Groups[idx].Disable() + err = a.Groups[idx].Disable() if err != nil { handleError(fmt.Errorf("failed to disable group: %w", err)) return errs } } - if err := a.DNSOverrider.Disable(); err != nil { - handleError(fmt.Errorf("failed to rollback override DNS changes: %w", err)) + err = a.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", chainName) + if err != nil { + handleError(fmt.Errorf("failed to unlinking chain: %w", err)) + return errs + } + + err = a.IPTables.ClearAndDeleteChain("nat", chainName) + if err != nil { + handleError(fmt.Errorf("failed to delete chain: %w", err)) + return errs } return errs @@ -221,10 +245,11 @@ func New(config Config) (*App, error) { app.Records = NewRecords() - app.DNSOverrider, err = iptablesHelper.NewDNSOverrider(fmt.Sprintf("%sDNSOVERRIDER", app.Config.ChainPostfix), app.Config.ListenPort) + ipt, err := iptables.New() if err != nil { - return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err) + return nil, fmt.Errorf("iptables init fail: %w", err) } + app.IPTables = ipt app.Groups = make(map[int]*Group) diff --git a/pkg/ip-helper/ip-helper.go b/pkg/ip-helper/ip-helper.go deleted file mode 100644 index 0740474..0000000 --- a/pkg/ip-helper/ip-helper.go +++ /dev/null @@ -1,118 +0,0 @@ -package ipHelper - -import ( - "bytes" - "errors" - "fmt" - "os/exec" - "slices" - - "github.com/vishvananda/netlink" - "github.com/vishvananda/netlink/nl" -) - -var ( - ErrMaxTableSize = errors.New("max table size") - ErrMaxFwMarkSize = errors.New("max fwmark size") -) - -func ExecIp(args ...string) ([]byte, error) { - cmd := exec.Command("ip", args...) - var out bytes.Buffer - cmd.Stdout = &out - err := cmd.Run() - if err != nil { - return nil, err - } - return out.Bytes(), nil -} - -func GetUsedFwMarks() ([]uint32, error) { - markMap := make(map[uint32]struct{}) - - rules, err := netlink.RuleList(nl.FAMILY_ALL) - if err != nil { - return nil, fmt.Errorf("error while getting rules: %w", err) - } - - for _, rule := range rules { - markMap[rule.Mark] = struct{}{} - } - - marks := make([]uint32, len(markMap)) - counter := 0 - for mark, _ := range markMap { - marks[counter] = mark - counter++ - } - - return marks, nil -} - -func GetUnusedFwMark(startFrom uint32) (uint32, error) { - usedFwMarks, err := GetUsedFwMarks() - if err != nil { - return 0, fmt.Errorf("error while getting used fwmarks: %w", err) - } - - fwmark := startFrom - for slices.Contains(usedFwMarks, fwmark) { - fwmark++ - if fwmark == 0xFFFFFFFF { - return 0, ErrMaxFwMarkSize - } - } - return fwmark, nil -} - -func GetUsedTables() ([]int, error) { - tableMap := map[int]struct{}{ - 0: {}, - 253: {}, - 254: {}, - 255: {}, - } - - routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE) - if err != nil { - return nil, fmt.Errorf("error while getting routes: %w", err) - } - - for _, route := range routes { - tableMap[route.Table] = struct{}{} - } - - rules, err := netlink.RuleList(nl.FAMILY_ALL) - if err != nil { - return nil, fmt.Errorf("error while getting rules: %w", err) - } - - for _, rule := range rules { - tableMap[rule.Table] = struct{}{} - } - - tables := make([]int, len(tableMap)) - counter := 0 - for table, _ := range tableMap { - tables[counter] = table - counter++ - } - - return tables, nil -} - -func GetUnusedTable(startFrom int) (int, error) { - usedTables, err := GetUsedTables() - if err != nil { - return 0, fmt.Errorf("error while getting used tables: %w", err) - } - - tableID := startFrom - for slices.Contains(usedTables, tableID) { - tableID++ - if tableID > 0x3FF { - return 0, ErrMaxTableSize - } - } - return tableID, nil -} diff --git a/pkg/iptables-helper/iptables-helper.go b/pkg/iptables-helper/iptables-helper.go deleted file mode 100644 index 96f63ef..0000000 --- a/pkg/iptables-helper/iptables-helper.go +++ /dev/null @@ -1,60 +0,0 @@ -package iptablesHelper - -import ( - "fmt" - "strconv" - - "github.com/coreos/go-iptables/iptables" -) - -type DNSOverrider struct { - ipt *iptables.IPTables - chainName string - destPort uint16 -} - -func (o DNSOverrider) Enable() error { - err := o.ipt.ClearChain("nat", o.chainName) - if err != nil { - return fmt.Errorf("failed to clear chain: %w", err) - } - - err = o.ipt.AppendUnique("nat", o.chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(o.destPort))) - if err != nil { - return fmt.Errorf("failed to create rule: %w", err) - } - - err = o.ipt.InsertUnique("nat", "PREROUTING", 1, "-j", o.chainName) - if err != nil { - return fmt.Errorf("failed to linking chain: %w", err) - } - - return nil -} - -func (o DNSOverrider) Disable() error { - err := o.ipt.DeleteIfExists("nat", "PREROUTING", "-j", o.chainName) - if err != nil { - return fmt.Errorf("failed to unlinking chain: %w", err) - } - - err = o.ipt.ClearAndDeleteChain("nat", o.chainName) - if err != nil { - return fmt.Errorf("failed to delete chain: %w", err) - } - - return nil -} - -func NewDNSOverrider(chainName string, destPort uint16) (*DNSOverrider, error) { - ipt, err := iptables.New() - if err != nil { - return nil, fmt.Errorf("iptables init fail: %w", err) - } - - return &DNSOverrider{ - ipt: ipt, - chainName: chainName, - destPort: destPort, - }, nil -}