diff --git a/group.go b/group.go index 190abca..35e8752 100644 --- a/group.go +++ b/group.go @@ -2,8 +2,13 @@ package main import "kvas2-go/models" +type GroupOptions struct { + Enabled bool + FWMark uint32 + Table uint16 +} + type Group struct { *models.Group - FWMark uint32 - Table uint16 + options GroupOptions } diff --git a/kvas2.go b/kvas2.go index 2ddfc63..63a0a04 100644 --- a/kvas2.go +++ b/kvas2.go @@ -101,12 +101,16 @@ func (a *App) Listen(ctx context.Context) []error { } func (a *App) usingGroup(idx int) error { - fwmark, err := ipHelper.GetUnusedFwMark() + if a.Groups[idx].options.Enabled { + return nil + } + + fwmark, err := ipHelper.GetUnusedFwMark(1) if err != nil { return fmt.Errorf("error while getting fwmark: %w", err) } - table, err := ipHelper.GetUnusedTable() + table, err := ipHelper.GetUnusedTable(1) if err != nil { return fmt.Errorf("error while getting table: %w", err) } @@ -119,14 +123,21 @@ func (a *App) usingGroup(idx int) error { return errors.New(string(out)) } - a.Groups[idx].FWMark = fwmark - a.Groups[idx].Table = table + a.Groups[idx].options.Enabled = true + a.Groups[idx].options.FWMark = fwmark + a.Groups[idx].options.Table = table return nil } func (a *App) releaseGroup(idx int) error { - out, err := ipHelper.ExecIp("rule", "del", "fwmark", strconv.Itoa(int(a.Groups[idx].FWMark)), "table", strconv.Itoa(int(a.Groups[idx].Table))) + if !a.Groups[idx].options.Enabled { + return nil + } + + fwmark := strconv.Itoa(int(a.Groups[idx].options.FWMark)) + table := strconv.Itoa(int(a.Groups[idx].options.Table)) + out, err := ipHelper.ExecIp("rule", "del", "fwmark", fwmark, "table", table) if err != nil { return err } diff --git a/pkg/ip-helper/ip-helper.go b/pkg/ip-helper/ip-helper.go index a470730..352a22d 100644 --- a/pkg/ip-helper/ip-helper.go +++ b/pkg/ip-helper/ip-helper.go @@ -61,13 +61,13 @@ func GetUsedFwMarks() ([]uint32, error) { return marks, nil } -func GetUnusedFwMark() (uint32, error) { +func GetUnusedFwMark(startFrom uint32) (uint32, error) { usedFwMarks, err := GetUsedFwMarks() if err != nil { return 0, fmt.Errorf("error while getting used fwmarks: %w", err) } - fwmark := uint32(1) + fwmark := startFrom for slices.Contains(usedFwMarks, fwmark) { fwmark++ if fwmark == 0xFFFFFFFF { @@ -182,13 +182,13 @@ func GetUsedTables() ([]uint16, error) { return tables, nil } -func GetUnusedTable() (uint16, error) { +func GetUnusedTable(startFrom uint16) (uint16, error) { usedTables, err := GetUsedTables() if err != nil { return 0, fmt.Errorf("error while getting used tables: %w", err) } - tableID := uint16(1) + tableID := startFrom for slices.Contains(usedTables, tableID) { tableID++ if tableID > 0x3FF {