merging some code
This commit is contained in:
parent
81d061a316
commit
3c3286fa34
52
group.go
52
group.go
@ -7,10 +7,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"kvas2-go/models"
|
"kvas2-go/models"
|
||||||
"kvas2-go/pkg/ip-helper"
|
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
"github.com/vishvananda/netlink/nl"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GroupOptions struct {
|
type GroupOptions struct {
|
||||||
@ -62,12 +62,52 @@ func (g *Group) Enable() error {
|
|||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
rule := netlink.NewRule()
|
markMap := make(map[uint32]struct{})
|
||||||
rule.Mark, err = ipHelper.GetUnusedFwMark(1)
|
tableMap := map[int]struct{}{
|
||||||
if err != nil {
|
0: {},
|
||||||
return fmt.Errorf("error while getting free fwmark: %w", err)
|
253: {},
|
||||||
|
254: {},
|
||||||
|
255: {},
|
||||||
}
|
}
|
||||||
rule.Table, err = ipHelper.GetUnusedTable(1)
|
var table int
|
||||||
|
var mark uint32
|
||||||
|
|
||||||
|
rules, err := netlink.RuleList(nl.FAMILY_ALL)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while getting rules: %w", err)
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
markMap[rule.Mark] = struct{}{}
|
||||||
|
tableMap[rule.Table] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error while getting routes: %w", err)
|
||||||
|
}
|
||||||
|
for _, route := range routes {
|
||||||
|
tableMap[route.Table] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if _, exists := tableMap[table]; exists {
|
||||||
|
table++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
if _, exists := markMap[mark]; exists {
|
||||||
|
mark++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := netlink.NewRule()
|
||||||
|
rule.Mark = mark
|
||||||
|
rule.Table = table
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error while getting free table: %w", err)
|
return fmt.Errorf("error while getting free table: %w", err)
|
||||||
}
|
}
|
||||||
|
47
kvas2.go
47
kvas2.go
@ -5,13 +5,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"kvas2-go/dns-proxy"
|
||||||
"kvas2-go/models"
|
"kvas2-go/models"
|
||||||
"kvas2-go/pkg/dns-proxy"
|
|
||||||
"kvas2-go/pkg/iptables-helper"
|
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ type App struct {
|
|||||||
Config Config
|
Config Config
|
||||||
|
|
||||||
DNSProxy *dnsProxy.DNSProxy
|
DNSProxy *dnsProxy.DNSProxy
|
||||||
DNSOverrider *iptablesHelper.DNSOverrider
|
IPTables *iptables.IPTables
|
||||||
Records *Records
|
Records *Records
|
||||||
Groups map[int]*Group
|
Groups map[int]*Group
|
||||||
|
|
||||||
@ -72,13 +73,28 @@ func (a *App) Listen(ctx context.Context) []error {
|
|||||||
newCtx, cancel := context.WithCancel(ctx)
|
newCtx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
if err := a.DNSOverrider.Enable(); err != nil {
|
chainName := fmt.Sprintf("%sDNSOVERRIDER", a.Config.ChainPostfix)
|
||||||
handleError(fmt.Errorf("failed to override DNS: %w", err))
|
|
||||||
|
err := a.IPTables.ClearChain("nat", chainName)
|
||||||
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to clear chain: %w", err))
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.IPTables.AppendUnique("nat", chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(a.Config.ListenPort)))
|
||||||
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to create rule: %w", err))
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", chainName)
|
||||||
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to linking chain: %w", err))
|
||||||
return errs
|
return errs
|
||||||
}
|
}
|
||||||
|
|
||||||
for idx, _ := range a.Groups {
|
for idx, _ := range a.Groups {
|
||||||
err := a.Groups[idx].Enable()
|
err = a.Groups[idx].Enable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(fmt.Errorf("failed to enable group: %w", err))
|
handleError(fmt.Errorf("failed to enable group: %w", err))
|
||||||
return errs
|
return errs
|
||||||
@ -97,15 +113,23 @@ func (a *App) Listen(ctx context.Context) []error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for idx, _ := range a.Groups {
|
for idx, _ := range a.Groups {
|
||||||
err := a.Groups[idx].Disable()
|
err = a.Groups[idx].Disable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handleError(fmt.Errorf("failed to disable group: %w", err))
|
handleError(fmt.Errorf("failed to disable group: %w", err))
|
||||||
return errs
|
return errs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := a.DNSOverrider.Disable(); err != nil {
|
err = a.IPTables.DeleteIfExists("nat", "PREROUTING", "-j", chainName)
|
||||||
handleError(fmt.Errorf("failed to rollback override DNS changes: %w", err))
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to unlinking chain: %w", err))
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
|
||||||
|
err = a.IPTables.ClearAndDeleteChain("nat", chainName)
|
||||||
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to delete chain: %w", err))
|
||||||
|
return errs
|
||||||
}
|
}
|
||||||
|
|
||||||
return errs
|
return errs
|
||||||
@ -221,10 +245,11 @@ func New(config Config) (*App, error) {
|
|||||||
|
|
||||||
app.Records = NewRecords()
|
app.Records = NewRecords()
|
||||||
|
|
||||||
app.DNSOverrider, err = iptablesHelper.NewDNSOverrider(fmt.Sprintf("%sDNSOVERRIDER", app.Config.ChainPostfix), app.Config.ListenPort)
|
ipt, err := iptables.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
|
return nil, fmt.Errorf("iptables init fail: %w", err)
|
||||||
}
|
}
|
||||||
|
app.IPTables = ipt
|
||||||
|
|
||||||
app.Groups = make(map[int]*Group)
|
app.Groups = make(map[int]*Group)
|
||||||
|
|
||||||
|
@ -1,118 +0,0 @@
|
|||||||
package ipHelper
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
"github.com/vishvananda/netlink/nl"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrMaxTableSize = errors.New("max table size")
|
|
||||||
ErrMaxFwMarkSize = errors.New("max fwmark size")
|
|
||||||
)
|
|
||||||
|
|
||||||
func ExecIp(args ...string) ([]byte, error) {
|
|
||||||
cmd := exec.Command("ip", args...)
|
|
||||||
var out bytes.Buffer
|
|
||||||
cmd.Stdout = &out
|
|
||||||
err := cmd.Run()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return out.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUsedFwMarks() ([]uint32, error) {
|
|
||||||
markMap := make(map[uint32]struct{})
|
|
||||||
|
|
||||||
rules, err := netlink.RuleList(nl.FAMILY_ALL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error while getting rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
markMap[rule.Mark] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
marks := make([]uint32, len(markMap))
|
|
||||||
counter := 0
|
|
||||||
for mark, _ := range markMap {
|
|
||||||
marks[counter] = mark
|
|
||||||
counter++
|
|
||||||
}
|
|
||||||
|
|
||||||
return marks, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUnusedFwMark(startFrom uint32) (uint32, error) {
|
|
||||||
usedFwMarks, err := GetUsedFwMarks()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error while getting used fwmarks: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fwmark := startFrom
|
|
||||||
for slices.Contains(usedFwMarks, fwmark) {
|
|
||||||
fwmark++
|
|
||||||
if fwmark == 0xFFFFFFFF {
|
|
||||||
return 0, ErrMaxFwMarkSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fwmark, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUsedTables() ([]int, error) {
|
|
||||||
tableMap := map[int]struct{}{
|
|
||||||
0: {},
|
|
||||||
253: {},
|
|
||||||
254: {},
|
|
||||||
255: {},
|
|
||||||
}
|
|
||||||
|
|
||||||
routes, err := netlink.RouteListFiltered(nl.FAMILY_ALL, &netlink.Route{}, netlink.RT_FILTER_TABLE)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error while getting routes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, route := range routes {
|
|
||||||
tableMap[route.Table] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := netlink.RuleList(nl.FAMILY_ALL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error while getting rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
tableMap[rule.Table] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
tables := make([]int, len(tableMap))
|
|
||||||
counter := 0
|
|
||||||
for table, _ := range tableMap {
|
|
||||||
tables[counter] = table
|
|
||||||
counter++
|
|
||||||
}
|
|
||||||
|
|
||||||
return tables, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetUnusedTable(startFrom int) (int, error) {
|
|
||||||
usedTables, err := GetUsedTables()
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error while getting used tables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tableID := startFrom
|
|
||||||
for slices.Contains(usedTables, tableID) {
|
|
||||||
tableID++
|
|
||||||
if tableID > 0x3FF {
|
|
||||||
return 0, ErrMaxTableSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tableID, nil
|
|
||||||
}
|
|
@ -1,60 +0,0 @@
|
|||||||
package iptablesHelper
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DNSOverrider struct {
|
|
||||||
ipt *iptables.IPTables
|
|
||||||
chainName string
|
|
||||||
destPort uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o DNSOverrider) Enable() error {
|
|
||||||
err := o.ipt.ClearChain("nat", o.chainName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to clear chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = o.ipt.AppendUnique("nat", o.chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(o.destPort)))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = o.ipt.InsertUnique("nat", "PREROUTING", 1, "-j", o.chainName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to linking chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o DNSOverrider) Disable() error {
|
|
||||||
err := o.ipt.DeleteIfExists("nat", "PREROUTING", "-j", o.chainName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to unlinking chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = o.ipt.ClearAndDeleteChain("nat", o.chainName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete chain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDNSOverrider(chainName string, destPort uint16) (*DNSOverrider, error) {
|
|
||||||
ipt, err := iptables.New()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("iptables init fail: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &DNSOverrider{
|
|
||||||
ipt: ipt,
|
|
||||||
chainName: chainName,
|
|
||||||
destPort: destPort,
|
|
||||||
}, nil
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user