add group fwmark and table using

This commit is contained in:
Vladimir Avtsenov 2024-08-27 03:07:58 +03:00
parent 41bfa1f39b
commit b74ee760cb
3 changed files with 129 additions and 19 deletions

9
group.go Normal file
View File

@ -0,0 +1,9 @@
package main
import "kvas2-go/models"
type Group struct {
*models.Group
FWMark uint32
Table uint16
}

View File

@ -2,12 +2,20 @@ package main
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"kvas2-go/models" "strconv"
"kvas2-go/pkg/dns-proxy"
"kvas2-go/pkg/iptables-helper"
"sync" "sync"
"time" "time"
"kvas2-go/models"
"kvas2-go/pkg/dns-proxy"
"kvas2-go/pkg/ip-helper"
"kvas2-go/pkg/iptables-helper"
)
var (
ErrGroupIDConflict = errors.New("group id conflict")
) )
type Config struct { type Config struct {
@ -23,7 +31,7 @@ type App struct {
DNSProxy *dnsProxy.DNSProxy DNSProxy *dnsProxy.DNSProxy
DNSOverrider *iptablesHelper.DNSOverrider DNSOverrider *iptablesHelper.DNSOverrider
Records *Records Records *Records
Groups []*models.Group Groups map[int]Group
} }
func (a *App) Listen(ctx context.Context) []error { func (a *App) Listen(ctx context.Context) []error {
@ -76,6 +84,38 @@ func (a *App) Listen(ctx context.Context) []error {
return errs return errs
} }
func (a *App) AppendGroup(group *models.Group) error {
if _, exists := a.Groups[group.ID]; exists {
return ErrGroupIDConflict
}
fwmark, err := ipHelper.GetUnusedFwMark()
if err != nil {
return fmt.Errorf("error while getting fwmark: %w", err)
}
table, err := ipHelper.GetUnusedTable()
if err != nil {
return fmt.Errorf("error while getting table: %w", err)
}
out, err := ipHelper.ExecIp("rule", "add", "fwmark", strconv.Itoa(int(fwmark)), "table", strconv.Itoa(int(table)))
if err != nil {
return err
}
if len(out) != 0 {
return errors.New(string(out))
}
a.Groups[group.ID] = Group{
Group: group,
FWMark: fwmark,
Table: table,
}
return nil
}
func (a *App) processARecord(aRecord dnsProxy.Address) { func (a *App) processARecord(aRecord dnsProxy.Address) {
ttlDuration := time.Duration(aRecord.TTL) * time.Second ttlDuration := time.Duration(aRecord.TTL) * time.Second
if ttlDuration < a.Config.MinimalTTL { if ttlDuration < a.Config.MinimalTTL {
@ -169,7 +209,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err) return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
} }
app.Groups = make([]*models.Group, 0) app.Groups = make(map[int]Group)
return app, nil return app, nil
} }

View File

@ -3,15 +3,22 @@ package ipHelper
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"errors"
"fmt" "fmt"
"os" "os"
"os/exec" "os/exec"
"regexp" "regexp"
"slices"
"strconv" "strconv"
"strings" "strings"
) )
func execIp(args ...string) ([]byte, error) { var (
ErrMaxTableSize = errors.New("max table size")
ErrMaxFwMarkSize = errors.New("max fwmark size")
)
func ExecIp(args ...string) ([]byte, error) {
cmd := exec.Command("ip", args...) cmd := exec.Command("ip", args...)
var out bytes.Buffer var out bytes.Buffer
cmd.Stdout = &out cmd.Stdout = &out
@ -22,10 +29,10 @@ func execIp(args ...string) ([]byte, error) {
return out.Bytes(), nil return out.Bytes(), nil
} }
func GetUsedFwMarks() ([]int, error) { func GetUsedFwMarks() ([]uint32, error) {
markMap := make(map[int]struct{}) markMap := make(map[uint32]struct{})
out, err := execIp("rule", "show") out, err := ExecIp("rule", "show")
if err != nil { if err != nil {
return nil, fmt.Errorf("error while getting rules: %w", err) return nil, fmt.Errorf("error while getting rules: %w", err)
} }
@ -40,11 +47,11 @@ func GetUsedFwMarks() ([]int, error) {
hexStr := match[1] hexStr := match[1]
hexValue, err := strconv.ParseInt(hexStr, 16, 64) hexValue, err := strconv.ParseInt(hexStr, 16, 64)
if err == nil { if err == nil {
markMap[int(hexValue)] = struct{}{} markMap[uint32(hexValue)] = struct{}{}
} }
} }
marks := make([]int, len(markMap)) marks := make([]uint32, len(markMap))
counter := 0 counter := 0
for mark, _ := range markMap { for mark, _ := range markMap {
marks[counter] = mark marks[counter] = mark
@ -54,8 +61,24 @@ func GetUsedFwMarks() ([]int, error) {
return marks, nil return marks, nil
} }
func GetTableAliases() (map[string]int, error) { func GetUnusedFwMark() (uint32, error) {
tables := map[string]int{ usedFwMarks, err := GetUsedFwMarks()
if err != nil {
return 0, fmt.Errorf("error while getting used fwmarks: %w", err)
}
fwmark := uint32(1)
for slices.Contains(usedFwMarks, fwmark) {
fwmark++
if fwmark == 0xFFFFFFFF {
return 0, ErrMaxFwMarkSize
}
}
return fwmark, nil
}
func GetTableAliases() (map[string]uint16, error) {
tables := map[string]uint16{
"unspec": 0, "unspec": 0,
"default": 253, "default": 253,
"main": 254, "main": 254,
@ -82,7 +105,7 @@ func GetTableAliases() (map[string]int, error) {
if len(parts) >= 2 { if len(parts) >= 2 {
tableID, err := strconv.Atoi(parts[0]) tableID, err := strconv.Atoi(parts[0])
if err == nil { if err == nil {
tables[parts[1]] = tableID tables[parts[1]] = uint16(tableID)
} }
} }
} }
@ -94,8 +117,8 @@ func GetTableAliases() (map[string]int, error) {
return tables, nil return tables, nil
} }
func GetUsedTables() ([]int, error) { func GetUsedTables() ([]uint16, error) {
tableMap := map[int]struct{}{ tableMap := map[uint16]struct{}{
0: {}, 0: {},
253: {}, 253: {},
254: {}, 254: {},
@ -107,7 +130,7 @@ func GetUsedTables() ([]int, error) {
return nil, fmt.Errorf("error while getting table aliases: %w", err) return nil, fmt.Errorf("error while getting table aliases: %w", err)
} }
out, err := execIp("route", "show", "table", "all") out, err := ExecIp("route", "show", "table", "all")
if err != nil { if err != nil {
return nil, fmt.Errorf("error while getting routes: %w", err) return nil, fmt.Errorf("error while getting routes: %w", err)
} }
@ -119,7 +142,8 @@ func GetUsedTables() ([]int, error) {
if part == "table" && i+1 < len(parts) { if part == "table" && i+1 < len(parts) {
tableNum, ok := tableAliases[parts[i+1]] tableNum, ok := tableAliases[parts[i+1]]
if !ok { if !ok {
tableNum, _ = strconv.Atoi(parts[i+1]) tableNumInt, _ := strconv.Atoi(parts[i+1])
tableNum = uint16(tableNumInt)
} }
tableMap[tableNum] = struct{}{} tableMap[tableNum] = struct{}{}
} }
@ -127,7 +151,28 @@ func GetUsedTables() ([]int, error) {
} }
} }
tables := make([]int, len(tableMap)) out, err = ExecIp("rule", "show")
if err != nil {
return nil, fmt.Errorf("error while getting rules: %w", err)
}
for _, line := range strings.Split(string(out), "\n") {
if strings.Contains(line, "lookup") {
parts := strings.Fields(line)
for i, part := range parts {
if part == "lookup" && i+1 < len(parts) {
tableNum, ok := tableAliases[parts[i+1]]
if !ok {
tableNumInt, _ := strconv.Atoi(parts[i+1])
tableNum = uint16(tableNumInt)
}
tableMap[tableNum] = struct{}{}
}
}
}
}
tables := make([]uint16, len(tableMap))
counter := 0 counter := 0
for table, _ := range tableMap { for table, _ := range tableMap {
tables[counter] = table tables[counter] = table
@ -136,3 +181,19 @@ func GetUsedTables() ([]int, error) {
return tables, nil return tables, nil
} }
func GetUnusedTable() (uint16, error) {
usedTables, err := GetUsedTables()
if err != nil {
return 0, fmt.Errorf("error while getting used tables: %w", err)
}
tableID := uint16(1)
for slices.Contains(usedTables, tableID) {
tableID++
if tableID > 0x3FF {
return 0, ErrMaxTableSize
}
}
return tableID, nil
}