MagiTrickle/kvas2.go

472 lines
10 KiB
Go
Raw Normal View History

2024-08-26 19:10:40 +03:00
package main
import (
"context"
2024-08-27 03:07:58 +03:00
"errors"
2024-08-26 19:10:40 +03:00
"fmt"
2024-08-30 04:44:08 +03:00
"net"
2024-09-06 16:50:56 +03:00
"strings"
2024-08-27 03:07:58 +03:00
"sync"
"time"
2024-09-05 03:53:10 +03:00
"kvas2-go/dns-proxy"
2024-08-26 19:10:40 +03:00
"kvas2-go/models"
2024-09-06 14:24:55 +03:00
"kvas2-go/netfilter-helper"
2024-09-04 09:15:03 +03:00
"github.com/rs/zerolog/log"
2024-09-06 15:48:15 +03:00
"github.com/vishvananda/netlink"
2024-08-27 03:07:58 +03:00
)
// Sentinel errors returned by App methods; compare with errors.Is.
var (
	// ErrAlreadyRunning is returned by Listen when the App is already listening.
	ErrAlreadyRunning = errors.New("already running")
	// ErrGroupIDConflict is returned by AddGroup when a group with the same ID
	// is already registered.
	ErrGroupIDConflict = errors.New("group id conflict")
)
// Config holds the settings the App is created with.
type Config struct {
	// MinimalTTL is the lower bound applied to the TTL of every processed
	// A and CNAME record.
	MinimalTTL time.Duration
	// ChainPostfix is placed in front of the generated netfilter chain names
	// (despite the field name it is used as a prefix).
	ChainPostfix string
	// IpSetPostfix is placed in front of the group ID to build the ipset name
	// (despite the field name it is used as a prefix).
	IpSetPostfix string
	// TargetDNSServerAddress is handed to the DNS proxy as its target server.
	TargetDNSServerAddress string
	// ListenPort is the port handed to the DNS proxy; port 53 is remapped to
	// it while the App is listening.
	ListenPort uint16
	// UseSoftwareRouting is not referenced in this file.
	// NOTE(review): presumably consumed elsewhere in the package — confirm.
	UseSoftwareRouting bool
}
type App struct {
Config Config
2024-09-14 15:16:50 +03:00
DNSProxy *dnsProxy.DNSProxy
NetfilterHelper4 *netfilterHelper.NetfilterHelper
Records *Records
Groups map[int]*Group
2024-08-30 04:29:05 +03:00
2024-09-14 15:16:50 +03:00
isRunning bool
dnsOverrider4 *netfilterHelper.PortRemap
2024-08-26 19:10:40 +03:00
}
func (a *App) Listen(ctx context.Context) []error {
2024-08-30 04:29:05 +03:00
if a.isRunning {
return []error{ErrAlreadyRunning}
}
a.isRunning = true
defer func() { a.isRunning = false }()
2024-08-26 19:10:40 +03:00
errs := make([]error, 0)
isError := make(chan struct{})
var once sync.Once
var errsMu sync.Mutex
handleError := func(err error) {
errsMu.Lock()
defer errsMu.Unlock()
errs = append(errs, err)
once.Do(func() { close(isError) })
}
2024-09-06 14:24:55 +03:00
handleErrors := func(errs2 []error) {
errsMu.Lock()
defer errsMu.Unlock()
errs = append(errs, errs2...)
once.Do(func() { close(isError) })
}
2024-08-26 19:10:40 +03:00
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
handleError(err)
} else {
handleError(fmt.Errorf("%v", r))
}
}
}()
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
2024-09-14 15:16:50 +03:00
a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPostfix), 53, a.Config.ListenPort)
err := a.dnsOverrider4.Enable()
2024-08-26 19:10:40 +03:00
2024-09-06 14:52:42 +03:00
for _, group := range a.Groups {
err = group.Enable()
2024-08-27 03:19:10 +03:00
if err != nil {
2024-08-30 04:30:33 +03:00
handleError(fmt.Errorf("failed to enable group: %w", err))
2024-08-27 03:19:10 +03:00
return errs
}
}
2024-08-26 19:10:40 +03:00
go func() {
if err := a.DNSProxy.Listen(newCtx); err != nil {
handleError(fmt.Errorf("failed to initialize DNS proxy: %v", err))
}
}()
2024-09-06 15:48:15 +03:00
link := make(chan netlink.LinkUpdate)
done := make(chan struct{})
netlink.LinkSubscribe(link, done)
2024-09-06 16:50:56 +03:00
exitListenerLoop := false
listener, err := net.Listen("unix", "/opt/var/run/kvas2-go.sock")
if err != nil {
handleError(fmt.Errorf("error while serve UNIX socket: %v", err))
2024-10-20 22:49:23 +03:00
return errs
2024-09-06 16:50:56 +03:00
}
defer listener.Close()
go func() {
for {
if exitListenerLoop {
break
}
conn, err := listener.Accept()
if err != nil {
log.Error().Err(err).Msg("error while listening unix socket")
}
go func(conn net.Conn) {
defer conn.Close()
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
return
}
args := strings.Split(string(buf[:n]), ":")
if len(args) == 3 && args[0] == "netfilter.d" {
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
2024-09-14 15:16:50 +03:00
if a.dnsOverrider4.Enabled {
err := a.dnsOverrider4.PutIPTable(args[2])
2024-09-06 16:50:56 +03:00
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
for _, group := range a.Groups {
if group.ifaceToIPSet.Enabled {
err := group.ifaceToIPSet.PutIPTable(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
}
}
}
}(conn)
}
}()
2024-09-06 15:48:15 +03:00
Loop:
for {
select {
case event := <-link:
switch event.Change {
case 0x00000001:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Str("operstatestr", event.Attrs().OperState.String()).
Int("operstate", int(event.Attrs().OperState)).
Msg("interface change")
if event.Attrs().OperState != netlink.OperDown {
for _, group := range a.Groups {
if group.Interface == event.Link.Attrs().Name {
err = group.ifaceToIPSet.IfaceHandle()
if err != nil {
log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up")
}
}
}
}
case 0xFFFFFFFF:
switch event.Header.Type {
case 16:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Int("type", int(event.Header.Type)).
Msg("interface add")
case 17:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Int("type", int(event.Header.Type)).
Msg("interface del")
}
}
case <-ctx.Done():
break Loop
case <-isError:
break Loop
}
2024-08-26 19:10:40 +03:00
}
2024-09-14 20:54:26 +03:00
exitListenerLoop = true
2024-09-06 15:48:15 +03:00
close(done)
2024-09-14 15:16:50 +03:00
errs2 := a.dnsOverrider4.Disable()
2024-09-06 14:24:55 +03:00
if errs2 != nil {
handleErrors(errs2)
2024-08-27 03:19:10 +03:00
}
2024-09-06 14:52:42 +03:00
for _, group := range a.Groups {
errs2 = group.Disable()
2024-09-06 14:24:55 +03:00
if errs2 != nil {
handleErrors(errs2)
2024-08-30 04:29:05 +03:00
}
}
2024-09-06 14:24:55 +03:00
return errs
2024-08-27 03:07:58 +03:00
}
2024-09-14 16:13:59 +03:00
func (a *App) AddGroup(group *models.Group) error {
if _, exists := a.Groups[group.ID]; exists {
return ErrGroupIDConflict
}
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID)
2024-09-14 18:20:44 +03:00
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
if err != nil {
return fmt.Errorf("failed to initialize ipset: %w", err)
}
2024-09-14 16:13:59 +03:00
grp := &Group{
Group: group,
iptables: a.NetfilterHelper4.IPTables,
2024-09-14 18:20:44 +03:00
ipset: ipset,
2024-09-14 16:13:59 +03:00
ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false),
}
a.Groups[group.ID] = grp
2024-09-14 18:20:44 +03:00
return a.SyncGroup(grp)
}
2024-09-14 16:13:59 +03:00
2024-09-14 18:20:44 +03:00
func (a *App) SyncGroup(group *Group) error {
2024-09-14 16:13:59 +03:00
processedDomains := make(map[string]struct{})
2024-09-14 18:20:44 +03:00
newIpsetAddressesMap := make(map[string]time.Duration)
now := time.Now()
oldIpsetAddresses, err := group.ListIPv4()
if err != nil {
return fmt.Errorf("failed to get old ipset list: %w", err)
}
knownDomains := a.Records.ListKnownDomains()
for _, domain := range group.Domains {
if !domain.IsEnabled() {
2024-09-14 16:13:59 +03:00
continue
}
2024-09-14 18:20:44 +03:00
for _, domainName := range knownDomains {
2024-09-14 16:13:59 +03:00
if !domain.IsMatch(domainName) {
continue
}
cnames := a.Records.GetCNameRecords(domainName, true)
2024-09-14 18:20:44 +03:00
if len(cnames) == 0 {
continue
}
2024-09-14 16:13:59 +03:00
for _, cname := range cnames {
processedDomains[cname] = struct{}{}
}
addresses := a.Records.GetARecords(domainName)
for _, address := range addresses {
2024-09-14 18:20:44 +03:00
ttl := now.Sub(address.Deadline)
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
newIpsetAddressesMap[string(address.Address)] = ttl
2024-09-14 16:13:59 +03:00
}
}
2024-09-14 18:20:44 +03:00
}
}
for addr, ttl := range newIpsetAddressesMap {
if _, exists := oldIpsetAddresses[addr]; exists {
continue
}
ip := net.IP(addr)
err = group.AddIPv4(ip, ttl)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to add address")
}
}
for addr, _ := range oldIpsetAddresses {
if _, exists := newIpsetAddressesMap[addr]; exists {
continue
}
ip := net.IP(addr)
err = group.DelIPv4(ip)
if err != nil {
log.Error().
Str("address", ip.String()).
Err(err).
Msg("failed to delete address")
2024-09-14 19:26:23 +03:00
} else {
log.Trace().
Str("address", ip.String()).
Err(err).
Msg("add address")
2024-09-14 16:13:59 +03:00
}
}
return nil
}
2024-08-30 04:44:08 +03:00
// ListInterfaces returns the system network interfaces that have the
// point-to-point flag set. The result is empty, not nil, when none qualify.
func (a *App) ListInterfaces() ([]net.Interface, error) {
	all, err := net.Interfaces()
	if err != nil {
		return nil, fmt.Errorf("failed to get interfaces: %w", err)
	}

	pointToPoint := make([]net.Interface, 0)
	for i := range all {
		if all[i].Flags&net.FlagPointToPoint != 0 {
			pointToPoint = append(pointToPoint, all[i])
		}
	}
	return pointToPoint, nil
}
2024-08-27 01:44:17 +03:00
func (a *App) processARecord(aRecord dnsProxy.Address) {
2024-09-04 09:15:03 +03:00
log.Trace().
Str("name", aRecord.Name.String()).
Str("address", aRecord.Address.String()).
Int("ttl", int(aRecord.TTL)).
Msg("processing a record")
2024-08-27 01:44:17 +03:00
ttlDuration := time.Duration(aRecord.TTL) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
2024-09-14 16:13:59 +03:00
a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
2024-08-27 01:44:17 +03:00
2024-09-14 16:13:59 +03:00
names := a.Records.GetCNameRecords(aRecord.Name.String(), true)
2024-08-27 01:44:17 +03:00
for _, group := range a.Groups {
2024-09-14 19:26:23 +03:00
Domain:
2024-09-14 18:20:44 +03:00
for _, domain := range group.Domains {
2024-09-14 18:36:23 +03:00
if !domain.IsEnabled() {
continue
}
2024-09-14 18:20:44 +03:00
for _, name := range names {
if !domain.IsMatch(name) {
continue
}
err := group.AddIPv4(aRecord.Address, ttlDuration)
if err != nil {
log.Error().
Str("address", aRecord.Address.String()).
Err(err).
Msg("failed to add address")
2024-09-14 19:26:23 +03:00
} else {
log.Trace().
Str("address", aRecord.Address.String()).
Str("aRecordDomain", aRecord.Name.String()).
Str("cNameDomain", name).
Err(err).
Msg("add address")
2024-09-14 18:20:44 +03:00
}
2024-09-14 19:26:23 +03:00
break Domain
2024-09-14 18:20:44 +03:00
}
2024-08-27 01:44:17 +03:00
}
}
}
func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
2024-09-14 18:20:44 +03:00
log.Trace().
Str("name", cNameRecord.Name.String()).
Str("cname", cNameRecord.CName.String()).
Int("ttl", int(cNameRecord.TTL)).
Msg("processing cname record")
2024-08-27 01:44:17 +03:00
ttlDuration := time.Duration(cNameRecord.TTL) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
2024-09-14 16:13:59 +03:00
a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration)
2024-09-14 19:26:23 +03:00
// TODO: Optimization
now := time.Now()
aRecords := a.Records.GetARecords(cNameRecord.Name.String())
names := a.Records.GetCNameRecords(cNameRecord.Name.String(), true)
for _, group := range a.Groups {
Domain:
for _, domain := range group.Domains {
if !domain.IsEnabled() {
continue
}
for _, name := range names {
if !domain.IsMatch(name) {
continue
}
for _, aRecord := range aRecords {
err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline))
if err != nil {
log.Error().
Str("address", aRecord.Address.String()).
Err(err).
Msg("failed to add address")
} else {
log.Trace().
Str("address", aRecord.Address.String()).
Str("cNameDomain", name).
Err(err).
Msg("add address")
}
}
continue Domain
}
}
}
2024-08-27 01:44:17 +03:00
}
2024-08-30 03:37:00 +03:00
func (a *App) handleRecord(rr dnsProxy.ResourceRecord) {
switch v := rr.(type) {
case dnsProxy.Address:
2024-09-14 18:20:44 +03:00
// TODO: Optimize equals domain A records
2024-08-30 03:37:00 +03:00
a.processARecord(v)
case dnsProxy.CName:
a.processCNameRecord(v)
default:
2024-08-26 19:10:40 +03:00
}
2024-08-30 03:37:00 +03:00
}
2024-08-26 19:10:40 +03:00
2024-08-30 03:37:00 +03:00
func (a *App) handleMessage(msg *dnsProxy.Message) {
for _, rr := range msg.AN {
a.handleRecord(rr)
2024-08-26 19:10:40 +03:00
}
2024-08-30 03:37:00 +03:00
for _, rr := range msg.NS {
a.handleRecord(rr)
2024-08-26 19:10:40 +03:00
}
2024-08-30 03:37:00 +03:00
for _, rr := range msg.AR {
a.handleRecord(rr)
2024-08-26 19:10:40 +03:00
}
}
func New(config Config) (*App, error) {
var err error
app := &App{}
app.Config = config
app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress)
2024-08-30 03:37:00 +03:00
app.DNSProxy.MsgHandler = app.handleMessage
2024-08-26 19:10:40 +03:00
app.Records = NewRecords()
2024-09-14 15:16:50 +03:00
nh4, err := netfilterHelper.New(false)
2024-08-26 19:10:40 +03:00
if err != nil {
2024-09-06 14:24:55 +03:00
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
2024-08-26 19:10:40 +03:00
}
2024-09-14 15:16:50 +03:00
app.NetfilterHelper4 = nh4
2024-08-26 19:10:40 +03:00
2024-08-27 03:19:10 +03:00
app.Groups = make(map[int]*Group)
2024-08-26 19:10:40 +03:00
return app, nil
}