Compare commits

...

4 Commits

Author SHA1 Message Date
60e1f4c540 change 128 bit IDs to 32 bit 2025-02-11 23:19:16 +03:00
184956829b build script 2025-02-11 23:05:32 +03:00
d5b4645ec0 config extension 2025-02-11 23:05:24 +03:00
fdb1038ba9 group refactoring 2025-02-11 22:44:07 +03:00
10 changed files with 228 additions and 202 deletions

1
build_mipsel.sh Normal file
View File

@ -0,0 +1 @@
GOOS=linux GOMIPS=softfloat GOARCH=mipsle go build -v -a -trimpath -ldflags="-w -s" .

View File

@ -1,79 +0,0 @@
package main
import (
"fmt"
"net"
"time"
"kvas2-go/models"
"kvas2-go/netfilter-helper"
"github.com/coreos/go-iptables/iptables"
)
type Group struct {
*models.Group
Enabled bool
iptables *iptables.IPTables
ipset *netfilterHelper.IPSet
ipsetToLink *netfilterHelper.IPSetToLink
}
func (g *Group) AddIP(address net.IP, ttl time.Duration) error {
ttlSeconds := uint32(ttl.Seconds())
return g.ipset.AddIP(address, &ttlSeconds)
}
func (g *Group) DelIP(address net.IP) error {
return g.ipset.DelIP(address)
}
func (g *Group) ListIP() (map[string]*uint32, error) {
return g.ipset.ListIPs()
}
func (g *Group) Enable() error {
if g.Enabled {
return nil
}
defer func() {
if !g.Enabled {
_ = g.Disable()
}
}()
if g.FixProtect {
err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
if err != nil {
return fmt.Errorf("failed to fix protect: %w", err)
}
}
err := g.ipsetToLink.Enable()
if err != nil {
return err
}
g.Enabled = true
return nil
}
func (g *Group) Disable() []error {
var errs []error
if !g.Enabled {
return nil
}
err := g.ipsetToLink.Disable()
if err != nil {
errs = append(errs, err...)
}
g.Enabled = false
return errs
}

183
group/group.go Normal file
View File

@ -0,0 +1,183 @@
package group
import (
"fmt"
"net"
"time"
"kvas2-go/models"
"kvas2-go/netfilter-helper"
"kvas2-go/records"
"github.com/coreos/go-iptables/iptables"
"github.com/rs/zerolog/log"
"github.com/vishvananda/netlink"
)
type Group struct {
*models.Group
enabled bool
iptables *iptables.IPTables
ipset *netfilterHelper.IPSet
ipsetToLink *netfilterHelper.IPSetToLink
}
func (g *Group) AddIP(address net.IP, ttl uint32) error {
return g.ipset.AddIP(address, &ttl)
}
func (g *Group) DelIP(address net.IP) error {
return g.ipset.DelIP(address)
}
func (g *Group) ListIP() (map[string]*uint32, error) {
return g.ipset.ListIPs()
}
func (g *Group) Enable() error {
if g.enabled {
return nil
}
defer func() {
if !g.enabled {
_ = g.Disable()
}
}()
if g.FixProtect {
err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
if err != nil {
return fmt.Errorf("failed to fix protect: %w", err)
}
}
err := g.ipsetToLink.Enable()
if err != nil {
return err
}
g.enabled = true
return nil
}
func (g *Group) Disable() []error {
var errs []error
if !g.enabled {
return nil
}
if g.FixProtect {
err := g.iptables.Delete("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
if err != nil {
errs = append(errs, fmt.Errorf("failed to remove fix protect: %w", err))
}
}
err := g.ipsetToLink.Disable()
if err != nil {
errs = append(errs, err...)
}
g.enabled = false
return errs
}
func (g *Group) Sync(records *records.Records) error {
now := time.Now()
addresses := make(map[string]uint32)
knownDomains := records.ListKnownDomains()
for _, domain := range g.Rules {
if !domain.IsEnabled() {
continue
}
for _, domainName := range knownDomains {
if !domain.IsMatch(domainName) {
continue
}
domainAddresses := records.GetARecords(domainName)
for _, address := range domainAddresses {
ttl := uint32(now.Sub(address.Deadline).Seconds())
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
addresses[string(address.Address)] = ttl
}
}
}
}
currentAddresses, err := g.ListIP()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
for addr, ttl := range addresses {
// TODO: Check TTL
if _, exists := currentAddresses[addr]; exists {
continue
}
ip := net.IP(addr)
err = g.AddIP(ip, ttl)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to add address")
} else {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("add address")
}
}
for addr := range currentAddresses {
if _, ok := addresses[addr]; ok {
continue
}
ip := net.IP(addr)
err = g.DelIP(ip)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to delete address")
} else {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("del address")
}
}
return nil
}
func (g *Group) NetfilterDHook(table string) error {
return g.ipsetToLink.NetfilterDHook(table)
}
func (g *Group) LinkUpdateHook(event netlink.LinkUpdate) error {
return g.ipsetToLink.LinkUpdateHook(event)
}
func NewGroup(group *models.Group, nh4 *netfilterHelper.NetfilterHelper, chainPrefix, ipsetNamePrefix string) (*Group, error) {
ipsetName := fmt.Sprintf("%s%8x", ipsetNamePrefix, group.ID)
ipset, err := nh4.IPSet(ipsetName)
if err != nil {
return nil, fmt.Errorf("failed to initialize ipset: %w", err)
}
ipsetToLink := nh4.IPSetToLink(fmt.Sprintf("%s%8x", chainPrefix, group.ID), group.Interface, ipsetName)
return &Group{
Group: group,
iptables: nh4.IPTables,
ipset: ipset,
ipsetToLink: ipsetToLink,
}, nil
}

134
kvas2.go
View File

@ -2,8 +2,11 @@ package main
import ( import (
"context" "context"
"encoding/binary"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"os" "os"
"strconv" "strconv"
@ -11,11 +14,11 @@ import (
"time" "time"
"kvas2-go/dns-mitm-proxy" "kvas2-go/dns-mitm-proxy"
"kvas2-go/group"
"kvas2-go/models" "kvas2-go/models"
"kvas2-go/netfilter-helper" "kvas2-go/netfilter-helper"
"kvas2-go/records" "kvas2-go/records"
"github.com/google/uuid"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
@ -27,12 +30,19 @@ var (
ErrGroupIDConflict = errors.New("group id conflict") ErrGroupIDConflict = errors.New("group id conflict")
) )
func randomId() [4]byte {
id := make([]byte, 4)
binary.BigEndian.PutUint32(id, rand.Uint32())
return [4]byte(id)
}
type Config struct { type Config struct {
MinimalTTL time.Duration AdditionalTTL uint32
ChainPrefix string ChainPrefix string
IpSetPrefix string IpSetPrefix string
LinkName string LinkName string
TargetDNSServerAddress string TargetDNSServerAddress string
TargetDNSServerPort uint16
ListenDNSPort uint16 ListenDNSPort uint16
} }
@ -43,7 +53,7 @@ type App struct {
NetfilterHelper4 *netfilterHelper.NetfilterHelper NetfilterHelper4 *netfilterHelper.NetfilterHelper
NetfilterHelper6 *netfilterHelper.NetfilterHelper NetfilterHelper6 *netfilterHelper.NetfilterHelper
Records *records.Records Records *records.Records
Groups map[uuid.UUID]*Group Groups map[[4]byte]*group.Group
Link netlink.Link Link netlink.Link
@ -68,9 +78,9 @@ func (a *App) handleLink(event netlink.LinkUpdate) {
continue continue
} }
err := group.ipsetToLink.LinkUpdateHook() err := group.LinkUpdateHook(event)
if err != nil { if err != nil {
log.Error().Str("group", group.ID.String()).Err(err).Msg("error while handling interface up") log.Error().Str("group", hex.EncodeToString(group.ID[:])).Err(err).Msg("error while handling interface up")
} }
} }
} }
@ -204,16 +214,16 @@ func (a *App) start(ctx context.Context) (err error) {
args := strings.Split(string(buf[:n]), ":") args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" { if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event") log.Debug().Str("table", args[2]).Msg("netfilter.d event")
err = a.dnsOverrider4.NetfilerDHook(args[2]) err = a.dnsOverrider4.NetfilterDHook(args[2])
if err != nil { if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
} }
err = a.dnsOverrider6.NetfilerDHook(args[2]) err = a.dnsOverrider6.NetfilterDHook(args[2])
if err != nil { if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
} }
for _, group := range a.Groups { for _, group := range a.Groups {
err := group.ipsetToLink.NetfilerDHook(args[2]) err := group.NetfilterDHook(args[2])
if err != nil { if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
} }
@ -276,97 +286,17 @@ func (a *App) Start(ctx context.Context) (err error) {
return err return err
} }
func (a *App) AddGroup(group *models.Group) error { func (a *App) AddGroup(groupModel *models.Group) error {
if _, exists := a.Groups[group.ID]; exists { if _, exists := a.Groups[groupModel.ID]; exists {
return ErrGroupIDConflict return ErrGroupIDConflict
} }
ipsetName := fmt.Sprintf("%s%8x", a.Config.IpSetPrefix, group.ID.ID()) grp, err := group.NewGroup(groupModel, a.NetfilterHelper4, a.Config.ChainPrefix, a.Config.IpSetPrefix)
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize ipset: %w", err) return fmt.Errorf("failed to create group: %w", err)
}
grp := &Group{
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: ipset,
ipsetToLink: a.NetfilterHelper4.IPSetToLink(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
} }
a.Groups[grp.ID] = grp a.Groups[grp.ID] = grp
return a.SyncGroup(grp) return grp.Sync(a.Records)
}
func (a *App) SyncGroup(group *Group) error {
now := time.Now()
addresses := make(map[string]time.Duration)
knownDomains := a.Records.ListKnownDomains()
for _, domain := range group.Rules {
if !domain.IsEnabled() {
continue
}
for _, domainName := range knownDomains {
if !domain.IsMatch(domainName) {
continue
}
domainAddresses := a.Records.GetARecords(domainName)
for _, address := range domainAddresses {
ttl := now.Sub(address.Deadline)
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
addresses[string(address.Address)] = ttl
}
}
}
}
currentAddresses, err := group.ListIP()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
for addr, ttl := range addresses {
// TODO: Check TTL
if _, exists := currentAddresses[addr]; exists {
continue
}
ip := net.IP(addr)
err = group.AddIP(ip, ttl)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to add address")
} else {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("add address")
}
}
for addr := range currentAddresses {
if _, ok := addresses[addr]; ok {
continue
}
ip := net.IP(addr)
err = group.DelIP(ip)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to delete address")
} else {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("del address")
}
}
return nil
} }
func (a *App) ListInterfaces() ([]net.Interface, error) { func (a *App) ListInterfaces() ([]net.Interface, error) {
@ -395,10 +325,7 @@ func (a *App) processARecord(aRecord dns.A) {
Int("ttl", int(aRecord.Hdr.Ttl)). Int("ttl", int(aRecord.Hdr.Ttl)).
Msg("processing a record") Msg("processing a record")
ttlDuration := time.Duration(aRecord.Hdr.Ttl) * time.Second ttlDuration := aRecord.Hdr.Ttl + a.Config.AdditionalTTL
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration) a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration)
@ -441,10 +368,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
Int("ttl", int(cNameRecord.Hdr.Ttl)). Int("ttl", int(cNameRecord.Hdr.Ttl)).
Msg("processing cname record") Msg("processing cname record")
ttlDuration := time.Duration(cNameRecord.Hdr.Ttl) * time.Second ttlDuration := cNameRecord.Hdr.Ttl + a.Config.AdditionalTTL
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration) a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration)
@ -463,7 +387,7 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
continue continue
} }
for _, aRecord := range aRecords { for _, aRecord := range aRecords {
err := group.AddIP(aRecord.Address, now.Sub(aRecord.Deadline)) err := group.AddIP(aRecord.Address, uint32(now.Sub(aRecord.Deadline).Seconds()))
if err != nil { if err != nil {
log.Error(). log.Error().
Str("address", aRecord.Address.String()). Str("address", aRecord.Address.String()).
@ -508,7 +432,7 @@ func New(config Config) (*App, error) {
app.DNSMITM = dnsMitmProxy.New() app.DNSMITM = dnsMitmProxy.New()
app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress
app.DNSMITM.TargetDNSServerPort = 53 app.DNSMITM.TargetDNSServerPort = app.Config.TargetDNSServerPort
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) { app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
// TODO: Need to understand why it not works in proxy mode // TODO: Need to understand why it not works in proxy mode
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR { if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
@ -544,7 +468,7 @@ func New(config Config) (*App, error) {
} }
app.Records = records.New() app.Records = records.New()
app.Groups = make(map[uuid.UUID]*Group, 0) app.Groups = make(map[[4]byte]*group.Group)
link, err := netlink.LinkByName(app.Config.LinkName) link, err := netlink.LinkByName(app.Config.LinkName)
if err != nil { if err != nil {
@ -572,7 +496,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("failed to clear iptables: %w", err) return nil, fmt.Errorf("failed to clear iptables: %w", err)
} }
app.Groups = make(map[uuid.UUID]*Group) app.Groups = make(map[[4]byte]*group.Group)
return app, nil return app, nil
} }

View File

@ -2,12 +2,11 @@ package main
import ( import (
"context" "context"
"github.com/rs/zerolog"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -15,12 +14,13 @@ func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
app, err := New(Config{ app, err := New(Config{
MinimalTTL: time.Hour, AdditionalTTL: 216000, // 1 hour
ChainPrefix: "KVAS2_", ChainPrefix: "KVAS2_",
IpSetPrefix: "kvas2_", IpSetPrefix: "kvas2_",
LinkName: "br0", LinkName: "br0",
TargetDNSServerAddress: "127.0.0.1", TargetDNSServerAddress: "127.0.0.1",
ListenDNSPort: 7553, TargetDNSServerPort: 53,
ListenDNSPort: 3553,
}) })
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("failed to initialize application") log.Fatal().Err(err).Msg("failed to initialize application")

View File

@ -1,9 +1,7 @@
package models package models
import "github.com/google/uuid"
type Group struct { type Group struct {
ID uuid.UUID ID [4]byte
Name string Name string
Interface string Interface string
Rules []*Rule Rules []*Rule

View File

@ -4,11 +4,10 @@ import (
"regexp" "regexp"
"github.com/IGLOU-EU/go-wildcard/v2" "github.com/IGLOU-EU/go-wildcard/v2"
"github.com/google/uuid"
) )
type Rule struct { type Rule struct {
ID uuid.UUID ID [4]byte
Name string Name string
Type string Type string
Rule string Rule string

View File

@ -265,21 +265,21 @@ func (r *IPSetToLink) Disable() []error {
return errs return errs
} }
func (r *IPSetToLink) NetfilerDHook(table string) error { func (r *IPSetToLink) NetfilterDHook(table string) error {
if !r.enabled { if !r.enabled {
return nil return nil
} }
return r.insertIPTablesRules(table) return r.insertIPTablesRules(table)
} }
func (r *IPSetToLink) LinkUpdateHook() error { func (r *IPSetToLink) LinkUpdateHook(event netlink.LinkUpdate) error {
if !r.enabled { if !r.enabled || event.Change != 1 || event.Link.Attrs().Name != r.IfaceName || event.Attrs().OperState != netlink.OperUp {
return nil return nil
} }
return r.insertIPRoute() return r.insertIPRoute()
} }
func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string, softwareMode bool) *IPSetToLink { func (nh *NetfilterHelper) IPSetToLink(name string, ifaceName, ipsetName string) *IPSetToLink {
return &IPSetToLink{ return &IPSetToLink{
IPTables: nh.IPTables, IPTables: nh.IPTables,
ChainName: name, ChainName: name,

View File

@ -105,7 +105,7 @@ func (r *PortRemap) Disable() []error {
return errs return errs
} }
func (r *PortRemap) NetfilerDHook(table string) error { func (r *PortRemap) NetfilterDHook(table string) error {
if !r.enabled { if !r.enabled {
return nil return nil
} }

View File

@ -22,7 +22,7 @@ type Records struct {
records map[string]interface{} records map[string]interface{}
} }
func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) { func (r *Records) AddCNameRecord(domainName, alias string, ttl uint32) {
if domainName == alias { if domainName == alias {
return return
} }
@ -30,16 +30,16 @@ func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) {
r.mux.Lock() r.mux.Lock()
r.records[domainName] = &CNameRecord{ r.records[domainName] = &CNameRecord{
Alias: alias, Alias: alias,
Deadline: time.Now().Add(ttl), Deadline: time.Now().Add(time.Duration(ttl) * time.Second),
} }
r.mux.Unlock() r.mux.Unlock()
} }
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) { func (r *Records) AddARecord(domainName string, addr net.IP, ttl uint32) {
r.mux.Lock() r.mux.Lock()
defer r.mux.Unlock() defer r.mux.Unlock()
deadline := time.Now().Add(ttl) deadline := time.Now().Add(time.Duration(ttl) * time.Second)
aRecords, _ := r.records[domainName].([]*ARecord) aRecords, _ := r.records[domainName].([]*ARecord)
for _, aRecord := range aRecords { for _, aRecord := range aRecords {