add group fwmark and table using
This commit is contained in:
parent
41bfa1f39b
commit
b74ee760cb
9
group.go
Normal file
9
group.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "kvas2-go/models"
|
||||||
|
|
||||||
|
type Group struct {
|
||||||
|
*models.Group
|
||||||
|
FWMark uint32
|
||||||
|
Table uint16
|
||||||
|
}
|
50
kvas2.go
50
kvas2.go
@ -2,12 +2,20 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"kvas2-go/models"
|
"strconv"
|
||||||
"kvas2-go/pkg/dns-proxy"
|
|
||||||
"kvas2-go/pkg/iptables-helper"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"kvas2-go/models"
|
||||||
|
"kvas2-go/pkg/dns-proxy"
|
||||||
|
"kvas2-go/pkg/ip-helper"
|
||||||
|
"kvas2-go/pkg/iptables-helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrGroupIDConflict = errors.New("group id conflict")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@ -23,7 +31,7 @@ type App struct {
|
|||||||
DNSProxy *dnsProxy.DNSProxy
|
DNSProxy *dnsProxy.DNSProxy
|
||||||
DNSOverrider *iptablesHelper.DNSOverrider
|
DNSOverrider *iptablesHelper.DNSOverrider
|
||||||
Records *Records
|
Records *Records
|
||||||
Groups []*models.Group
|
Groups map[int]Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) Listen(ctx context.Context) []error {
|
func (a *App) Listen(ctx context.Context) []error {
|
||||||
@ -76,6 +84,38 @@ func (a *App) Listen(ctx context.Context) []error {
|
|||||||
return errs
|
return errs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *App) AppendGroup(group *models.Group) error {
|
||||||
|
if _, exists := a.Groups[group.ID]; exists {
|
||||||
|
return ErrGroupIDConflict
|
||||||
|
}
|
||||||
|
|
||||||
|
fwmark, err := ipHelper.GetUnusedFwMark()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while getting fwmark: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := ipHelper.GetUnusedTable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while getting table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := ipHelper.ExecIp("rule", "add", "fwmark", strconv.Itoa(int(fwmark)), "table", strconv.Itoa(int(table)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(out) != 0 {
|
||||||
|
return errors.New(string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Groups[group.ID] = Group{
|
||||||
|
Group: group,
|
||||||
|
FWMark: fwmark,
|
||||||
|
Table: table,
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (a *App) processARecord(aRecord dnsProxy.Address) {
|
func (a *App) processARecord(aRecord dnsProxy.Address) {
|
||||||
ttlDuration := time.Duration(aRecord.TTL) * time.Second
|
ttlDuration := time.Duration(aRecord.TTL) * time.Second
|
||||||
if ttlDuration < a.Config.MinimalTTL {
|
if ttlDuration < a.Config.MinimalTTL {
|
||||||
@ -169,7 +209,7 @@ func New(config Config) (*App, error) {
|
|||||||
return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
|
return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
app.Groups = make([]*models.Group, 0)
|
app.Groups = make(map[int]Group)
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
@ -3,15 +3,22 @@ package ipHelper
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func execIp(args ...string) ([]byte, error) {
|
var (
|
||||||
|
ErrMaxTableSize = errors.New("max table size")
|
||||||
|
ErrMaxFwMarkSize = errors.New("max fwmark size")
|
||||||
|
)
|
||||||
|
|
||||||
|
func ExecIp(args ...string) ([]byte, error) {
|
||||||
cmd := exec.Command("ip", args...)
|
cmd := exec.Command("ip", args...)
|
||||||
var out bytes.Buffer
|
var out bytes.Buffer
|
||||||
cmd.Stdout = &out
|
cmd.Stdout = &out
|
||||||
@ -22,10 +29,10 @@ func execIp(args ...string) ([]byte, error) {
|
|||||||
return out.Bytes(), nil
|
return out.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUsedFwMarks() ([]int, error) {
|
func GetUsedFwMarks() ([]uint32, error) {
|
||||||
markMap := make(map[int]struct{})
|
markMap := make(map[uint32]struct{})
|
||||||
|
|
||||||
out, err := execIp("rule", "show")
|
out, err := ExecIp("rule", "show")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error while getting rules: %w", err)
|
return nil, fmt.Errorf("error while getting rules: %w", err)
|
||||||
}
|
}
|
||||||
@ -40,11 +47,11 @@ func GetUsedFwMarks() ([]int, error) {
|
|||||||
hexStr := match[1]
|
hexStr := match[1]
|
||||||
hexValue, err := strconv.ParseInt(hexStr, 16, 64)
|
hexValue, err := strconv.ParseInt(hexStr, 16, 64)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
markMap[int(hexValue)] = struct{}{}
|
markMap[uint32(hexValue)] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
marks := make([]int, len(markMap))
|
marks := make([]uint32, len(markMap))
|
||||||
counter := 0
|
counter := 0
|
||||||
for mark, _ := range markMap {
|
for mark, _ := range markMap {
|
||||||
marks[counter] = mark
|
marks[counter] = mark
|
||||||
@ -54,8 +61,24 @@ func GetUsedFwMarks() ([]int, error) {
|
|||||||
return marks, nil
|
return marks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTableAliases() (map[string]int, error) {
|
func GetUnusedFwMark() (uint32, error) {
|
||||||
tables := map[string]int{
|
usedFwMarks, err := GetUsedFwMarks()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error while getting used fwmarks: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fwmark := uint32(1)
|
||||||
|
for slices.Contains(usedFwMarks, fwmark) {
|
||||||
|
fwmark++
|
||||||
|
if fwmark == 0xFFFFFFFF {
|
||||||
|
return 0, ErrMaxFwMarkSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fwmark, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTableAliases() (map[string]uint16, error) {
|
||||||
|
tables := map[string]uint16{
|
||||||
"unspec": 0,
|
"unspec": 0,
|
||||||
"default": 253,
|
"default": 253,
|
||||||
"main": 254,
|
"main": 254,
|
||||||
@ -82,7 +105,7 @@ func GetTableAliases() (map[string]int, error) {
|
|||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
tableID, err := strconv.Atoi(parts[0])
|
tableID, err := strconv.Atoi(parts[0])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
tables[parts[1]] = tableID
|
tables[parts[1]] = uint16(tableID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -94,8 +117,8 @@ func GetTableAliases() (map[string]int, error) {
|
|||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUsedTables() ([]int, error) {
|
func GetUsedTables() ([]uint16, error) {
|
||||||
tableMap := map[int]struct{}{
|
tableMap := map[uint16]struct{}{
|
||||||
0: {},
|
0: {},
|
||||||
253: {},
|
253: {},
|
||||||
254: {},
|
254: {},
|
||||||
@ -107,7 +130,7 @@ func GetUsedTables() ([]int, error) {
|
|||||||
return nil, fmt.Errorf("error while getting table aliases: %w", err)
|
return nil, fmt.Errorf("error while getting table aliases: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := execIp("route", "show", "table", "all")
|
out, err := ExecIp("route", "show", "table", "all")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error while getting routes: %w", err)
|
return nil, fmt.Errorf("error while getting routes: %w", err)
|
||||||
}
|
}
|
||||||
@ -119,7 +142,8 @@ func GetUsedTables() ([]int, error) {
|
|||||||
if part == "table" && i+1 < len(parts) {
|
if part == "table" && i+1 < len(parts) {
|
||||||
tableNum, ok := tableAliases[parts[i+1]]
|
tableNum, ok := tableAliases[parts[i+1]]
|
||||||
if !ok {
|
if !ok {
|
||||||
tableNum, _ = strconv.Atoi(parts[i+1])
|
tableNumInt, _ := strconv.Atoi(parts[i+1])
|
||||||
|
tableNum = uint16(tableNumInt)
|
||||||
}
|
}
|
||||||
tableMap[tableNum] = struct{}{}
|
tableMap[tableNum] = struct{}{}
|
||||||
}
|
}
|
||||||
@ -127,7 +151,28 @@ func GetUsedTables() ([]int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tables := make([]int, len(tableMap))
|
out, err = ExecIp("rule", "show")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error while getting rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range strings.Split(string(out), "\n") {
|
||||||
|
if strings.Contains(line, "lookup") {
|
||||||
|
parts := strings.Fields(line)
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "lookup" && i+1 < len(parts) {
|
||||||
|
tableNum, ok := tableAliases[parts[i+1]]
|
||||||
|
if !ok {
|
||||||
|
tableNumInt, _ := strconv.Atoi(parts[i+1])
|
||||||
|
tableNum = uint16(tableNumInt)
|
||||||
|
}
|
||||||
|
tableMap[tableNum] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tables := make([]uint16, len(tableMap))
|
||||||
counter := 0
|
counter := 0
|
||||||
for table, _ := range tableMap {
|
for table, _ := range tableMap {
|
||||||
tables[counter] = table
|
tables[counter] = table
|
||||||
@ -136,3 +181,19 @@ func GetUsedTables() ([]int, error) {
|
|||||||
|
|
||||||
return tables, nil
|
return tables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetUnusedTable() (uint16, error) {
|
||||||
|
usedTables, err := GetUsedTables()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error while getting used tables: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableID := uint16(1)
|
||||||
|
for slices.Contains(usedTables, tableID) {
|
||||||
|
tableID++
|
||||||
|
if tableID > 0x3FF {
|
||||||
|
return 0, ErrMaxTableSize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tableID, nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user