From b74ee760cb9379bdddbb10ed4323b37e9131bd1f Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Tue, 27 Aug 2024 03:07:58 +0300 Subject: [PATCH] add group fwmark and table using --- group.go | 9 ++++ kvas2.go | 50 ++++++++++++++++++--- pkg/ip-helper/ip-helper.go | 89 ++++++++++++++++++++++++++++++++------ 3 files changed, 129 insertions(+), 19 deletions(-) create mode 100644 group.go diff --git a/group.go b/group.go new file mode 100644 index 0000000..190abca --- /dev/null +++ b/group.go @@ -0,0 +1,9 @@ +package main + +import "kvas2-go/models" + +type Group struct { + *models.Group + FWMark uint32 + Table uint16 +} diff --git a/kvas2.go b/kvas2.go index 62286d1..f855409 100644 --- a/kvas2.go +++ b/kvas2.go @@ -2,12 +2,20 @@ package main import ( "context" + "errors" "fmt" - "kvas2-go/models" - "kvas2-go/pkg/dns-proxy" - "kvas2-go/pkg/iptables-helper" + "strconv" "sync" "time" + + "kvas2-go/models" + "kvas2-go/pkg/dns-proxy" + "kvas2-go/pkg/ip-helper" + "kvas2-go/pkg/iptables-helper" +) + +var ( + ErrGroupIDConflict = errors.New("group id conflict") ) type Config struct { @@ -23,7 +31,7 @@ type App struct { DNSProxy *dnsProxy.DNSProxy DNSOverrider *iptablesHelper.DNSOverrider Records *Records - Groups []*models.Group + Groups map[int]Group } func (a *App) Listen(ctx context.Context) []error { @@ -76,6 +84,38 @@ func (a *App) Listen(ctx context.Context) []error { return errs } +func (a *App) AppendGroup(group *models.Group) error { + if _, exists := a.Groups[group.ID]; exists { + return ErrGroupIDConflict + } + + fwmark, err := ipHelper.GetUnusedFwMark() + if err != nil { + return fmt.Errorf("error while getting fwmark: %w", err) + } + + table, err := ipHelper.GetUnusedTable() + if err != nil { + return fmt.Errorf("error while getting table: %w", err) + } + + out, err := ipHelper.ExecIp("rule", "add", "fwmark", strconv.Itoa(int(fwmark)), "table", strconv.Itoa(int(table))) + if err != nil { + return err + } + if len(out) != 0 { + return errors.New(string(out)) + } + + a.Groups[group.ID] = Group{ + Group: group, + FWMark: fwmark, + Table: table, + } + + return nil +} + func (a *App) processARecord(aRecord dnsProxy.Address) { ttlDuration := time.Duration(aRecord.TTL) * time.Second if ttlDuration < a.Config.MinimalTTL { @@ -169,7 +209,7 @@ func New(config Config) (*App, error) { return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err) } - app.Groups = make([]*models.Group, 0) + app.Groups = make(map[int]Group) return app, nil } diff --git a/pkg/ip-helper/ip-helper.go b/pkg/ip-helper/ip-helper.go index 1be76a9..a470730 100644 --- a/pkg/ip-helper/ip-helper.go +++ b/pkg/ip-helper/ip-helper.go @@ -3,15 +3,22 @@ package ipHelper import ( "bufio" "bytes" + "errors" "fmt" "os" "os/exec" "regexp" + "slices" "strconv" "strings" ) -func execIp(args ...string) ([]byte, error) { +var ( + ErrMaxTableSize = errors.New("max table size") + ErrMaxFwMarkSize = errors.New("max fwmark size") +) + +func ExecIp(args ...string) ([]byte, error) { cmd := exec.Command("ip", args...) var out bytes.Buffer cmd.Stdout = &out @@ -22,10 +29,10 @@ func execIp(args ...string) ([]byte, error) { return out.Bytes(), nil } -func GetUsedFwMarks() ([]int, error) { - markMap := make(map[int]struct{}) +func GetUsedFwMarks() ([]uint32, error) { + markMap := make(map[uint32]struct{}) - out, err := execIp("rule", "show") + out, err := ExecIp("rule", "show") if err != nil { return nil, fmt.Errorf("error while getting rules: %w", err) } @@ -40,11 +47,11 @@ func GetUsedFwMarks() ([]int, error) { hexStr := match[1] hexValue, err := strconv.ParseInt(hexStr, 16, 64) if err == nil { - markMap[int(hexValue)] = struct{}{} + markMap[uint32(hexValue)] = struct{}{} } } - marks := make([]int, len(markMap)) + marks := make([]uint32, len(markMap)) counter := 0 for mark, _ := range markMap { marks[counter] = mark @@ -54,8 +61,24 @@ func GetUsedFwMarks() ([]int, error) { return marks, nil } -func GetTableAliases() (map[string]int, error) { - tables := map[string]int{ +func GetUnusedFwMark() (uint32, error) { + usedFwMarks, err := GetUsedFwMarks() + if err != nil { + return 0, fmt.Errorf("error while getting used fwmarks: %w", err) + } + + fwmark := uint32(1) + for slices.Contains(usedFwMarks, fwmark) { + fwmark++ + if fwmark == 0xFFFFFFFF { + return 0, ErrMaxFwMarkSize + } + } + return fwmark, nil +} + +func GetTableAliases() (map[string]uint16, error) { + tables := map[string]uint16{ "unspec": 0, "default": 253, "main": 254, @@ -82,7 +105,7 @@ func GetTableAliases() (map[string]int, error) { if len(parts) >= 2 { tableID, err := strconv.Atoi(parts[0]) if err == nil { - tables[parts[1]] = tableID + tables[parts[1]] = uint16(tableID) } } } @@ -94,8 +117,8 @@ func GetTableAliases() (map[string]int, error) { return tables, nil } -func GetUsedTables() ([]int, error) { - tableMap := map[int]struct{}{ +func GetUsedTables() ([]uint16, error) { + tableMap := map[uint16]struct{}{ 0: {}, 253: {}, 254: {}, @@ -107,7 +130,7 @@ func GetUsedTables() ([]int, error) { return nil, fmt.Errorf("error while getting table aliases: %w", err) } - out, err := execIp("route", "show", "table", "all") + out, err := ExecIp("route", "show", "table", "all") if err != nil { return nil, fmt.Errorf("error while getting routes: %w", err) } @@ -119,7 +142,8 @@ func GetUsedTables() ([]int, error) { if part == "table" && i+1 < len(parts) { tableNum, ok := tableAliases[parts[i+1]] if !ok { - tableNum, _ = strconv.Atoi(parts[i+1]) + tableNumInt, _ := strconv.Atoi(parts[i+1]) + tableNum = uint16(tableNumInt) } tableMap[tableNum] = struct{}{} } @@ -127,7 +151,28 @@ func GetUsedTables() ([]int, error) { } } - tables := make([]int, len(tableMap)) + out, err = ExecIp("rule", "show") + if err != nil { + return nil, fmt.Errorf("error while getting rules: %w", err) + } + + for _, line := range strings.Split(string(out), "\n") { + if strings.Contains(line, "lookup") { + parts := strings.Fields(line) + for i, part := range parts { + if part == "lookup" && i+1 < len(parts) { + tableNum, ok := tableAliases[parts[i+1]] + if !ok { + tableNumInt, _ := strconv.Atoi(parts[i+1]) + tableNum = uint16(tableNumInt) + } + tableMap[tableNum] = struct{}{} + } + } + } + } + + tables := make([]uint16, len(tableMap)) counter := 0 for table, _ := range tableMap { tables[counter] = table @@ -136,3 +181,19 @@ func GetUsedTables() ([]int, error) { return tables, nil } + +func GetUnusedTable() (uint16, error) { + usedTables, err := GetUsedTables() + if err != nil { + return 0, fmt.Errorf("error while getting used tables: %w", err) + } + + tableID := uint16(1) + for slices.Contains(usedTables, tableID) { + tableID++ + if tableID > 0x3FF { + return 0, ErrMaxTableSize + } + } + return tableID, nil +}