merging some code

This commit is contained in:
Vladimir Avtsenov 2024-09-05 03:53:10 +03:00
parent 81d061a316
commit 3c3286fa34
8 changed files with 85 additions and 198 deletions

View File

@ -7,10 +7,10 @@ import (
"time"
"kvas2-go/models"
"kvas2-go/pkg/ip-helper"
"github.com/rs/zerolog/log"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
)
type GroupOptions struct {
@ -62,12 +62,52 @@ func (g *Group) Enable() error {
var err error
rule := netlink.NewRule()
rule.Mark, err = ipHelper.GetUnusedFwMark(1)
if err != nil {
return fmt.Errorf("error while getting free fwmark: %w", err)
markMap := make(map[uint32]struct{})
tableMap := map[int]struct{}{
0: {},
253: {},
254: {},
255: {},
}
rule.Table, err = ipHelper.GetUnusedTable(1)
var table int
var mark uint32
rules, err := netlink.RuleList(nl.FAMILY_ALL)
if err != nil {
return fmt.Errorf("error while getting rules: %w", err)
}
for _, rule := range rules {
markMap[rule.Mark] = struct{}{}
tableMap[rule.Table] = struct{}{}
}
routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE)
if err != nil {
return fmt.Errorf("error while getting routes: %w", err)
}
for _, route := range routes {
tableMap[route.Table] = struct{}{}
}
for {
if _, exists := tableMap[table]; exists {
table++
continue
}
break
}
for {
if _, exists := markMap[mark]; exists {
mark++
continue
}
break
}
rule := netlink.NewRule()
rule.Mark = mark
rule.Table = table
if err != nil {
return fmt.Errorf("error while getting free table: %w", err)
}

View File

@ -5,13 +5,14 @@ import (
"errors"
"fmt"
"net"
"strconv"
"sync"
"time"
"kvas2-go/dns-proxy"
"kvas2-go/models"
"kvas2-go/pkg/dns-proxy"
"kvas2-go/pkg/iptables-helper"
"github.com/coreos/go-iptables/iptables"
"github.com/rs/zerolog/log"
)
@ -32,7 +33,7 @@ type App struct {
Config Config
DNSProxy *dnsProxy.DNSProxy
DNSOverrider *iptablesHelper.DNSOverrider
IPTables *iptables.IPTables
Records *Records
Groups map[int]*Group
@ -72,13 +73,28 @@ func (a *App) Listen(ctx context.Context) []error {
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := a.DNSOverrider.Enable(); err != nil {
handleError(fmt.Errorf("failed to override DNS: %w", err))
chainName := fmt.Sprintf("%sDNSOVERRIDER", a.Config.ChainPostfix)
err := a.IPTables.ClearChain("nat", chainName)
if err != nil {
handleError(fmt.Errorf("failed to clear chain: %w", err))
return errs
}
err = a.IPTables.AppendUnique("nat", chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(a.Config.ListenPort)))
if err != nil {
handleError(fmt.Errorf("failed to create rule: %w", err))
return errs
}
err = a.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", chainName)
if err != nil {
handleError(fmt.Errorf("failed to linking chain: %w", err))
return errs
}
for idx, _ := range a.Groups {
err := a.Groups[idx].Enable()
err = a.Groups[idx].Enable()
if err != nil {
handleError(fmt.Errorf("failed to enable group: %w", err))
return errs
@ -97,15 +113,23 @@ func (a *App) Listen(ctx context.Context) []error {
}
for idx, _ := range a.Groups {
err := a.Groups[idx].Disable()
err = a.Groups[idx].Disable()
if err != nil {
handleError(fmt.Errorf("failed to disable group: %w", err))
return errs
}
}
if err := a.DNSOverrider.Disable(); err != nil {
handleError(fmt.Errorf("failed to rollback override DNS changes: %w", err))
err = a.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", chainName)
if err != nil {
handleError(fmt.Errorf("failed to unlinking chain: %w", err))
return errs
}
err = a.IPTables.ClearAndDeleteChain("nat", chainName)
if err != nil {
handleError(fmt.Errorf("failed to delete chain: %w", err))
return errs
}
return errs
@ -221,10 +245,11 @@ func New(config Config) (*App, error) {
app.Records = NewRecords()
app.DNSOverrider, err = iptablesHelper.NewDNSOverrider(fmt.Sprintf("%sDNSOVERRIDER", app.Config.ChainPostfix), app.Config.ListenPort)
ipt, err := iptables.New()
if err != nil {
return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
return nil, fmt.Errorf("iptables init fail: %w", err)
}
app.IPTables = ipt
app.Groups = make(map[int]*Group)

View File

@ -1,118 +0,0 @@
package ipHelper
import (
"bytes"
"errors"
"fmt"
"os/exec"
"slices"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
)
var (
ErrMaxTableSize = errors.New("max table size")
ErrMaxFwMarkSize = errors.New("max fwmark size")
)
func ExecIp(args ...string) ([]byte, error) {
cmd := exec.Command("ip", args...)
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
return nil, err
}
return out.Bytes(), nil
}
func GetUsedFwMarks() ([]uint32, error) {
markMap := make(map[uint32]struct{})
rules, err := netlink.RuleList(nl.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("error while getting rules: %w", err)
}
for _, rule := range rules {
markMap[rule.Mark] = struct{}{}
}
marks := make([]uint32, len(markMap))
counter := 0
for mark, _ := range markMap {
marks[counter] = mark
counter++
}
return marks, nil
}
func GetUnusedFwMark(startFrom uint32) (uint32, error) {
usedFwMarks, err := GetUsedFwMarks()
if err != nil {
return 0, fmt.Errorf("error while getting used fwmarks: %w", err)
}
fwmark := startFrom
for slices.Contains(usedFwMarks, fwmark) {
fwmark++
if fwmark == 0xFFFFFFFF {
return 0, ErrMaxFwMarkSize
}
}
return fwmark, nil
}
func GetUsedTables() ([]int, error) {
tableMap := map[int]struct{}{
0: {},
253: {},
254: {},
255: {},
}
routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE)
if err != nil {
return nil, fmt.Errorf("error while getting routes: %w", err)
}
for _, route := range routes {
tableMap[route.Table] = struct{}{}
}
rules, err := netlink.RuleList(nl.FAMILY_ALL)
if err != nil {
return nil, fmt.Errorf("error while getting rules: %w", err)
}
for _, rule := range rules {
tableMap[rule.Table] = struct{}{}
}
tables := make([]int, len(tableMap))
counter := 0
for table, _ := range tableMap {
tables[counter] = table
counter++
}
return tables, nil
}
func GetUnusedTable(startFrom int) (int, error) {
usedTables, err := GetUsedTables()
if err != nil {
return 0, fmt.Errorf("error while getting used tables: %w", err)
}
tableID := startFrom
for slices.Contains(usedTables, tableID) {
tableID++
if tableID > 0x3FF {
return 0, ErrMaxTableSize
}
}
return tableID, nil
}

View File

@ -1,60 +0,0 @@
package iptablesHelper
import (
"fmt"
"strconv"
"github.com/coreos/go-iptables/iptables"
)
type DNSOverrider struct {
ipt *iptables.IPTables
chainName string
destPort uint16
}
func (o DNSOverrider) Enable() error {
err := o.ipt.ClearChain("nat", o.chainName)
if err != nil {
return fmt.Errorf("failed to clear chain: %w", err)
}
err = o.ipt.AppendUnique("nat", o.chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(o.destPort)))
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
}
err = o.ipt.InsertUnique("nat", "PREROUTING", 1, "-j", o.chainName)
if err != nil {
return fmt.Errorf("failed to linking chain: %w", err)
}
return nil
}
func (o DNSOverrider) Disable() error {
err := o.ipt.DeleteIfExists("nat", "PREROUTING", "-j", o.chainName)
if err != nil {
return fmt.Errorf("failed to unlinking chain: %w", err)
}
err = o.ipt.ClearAndDeleteChain("nat", o.chainName)
if err != nil {
return fmt.Errorf("failed to delete chain: %w", err)
}
return nil
}
func NewDNSOverrider(chainName string, destPort uint16) (*DNSOverrider, error) {
ipt, err := iptables.New()
if err != nil {
return nil, fmt.Errorf("iptables init fail: %w", err)
}
return &DNSOverrider{
ipt: ipt,
chainName: chainName,
destPort: destPort,
}, nil
}