506 lines
12 KiB
Go
506 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"kvas2-go/dns-proxy"
|
|
"kvas2-go/models"
|
|
"kvas2-go/netfilter-helper"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/vishvananda/netlink"
|
|
)
|
|
|
|
// Sentinel errors returned by App methods; compare with errors.Is.
var (
	// ErrAlreadyRunning is returned by App.Listen when the application is
	// already marked as running.
	ErrAlreadyRunning = errors.New("already running")

	// ErrGroupIDConflict is returned by App.AddGroup when a group with the
	// same ID is already registered.
	ErrGroupIDConflict = errors.New("group id conflict")
)
|
|
|
|
// Config holds the settings an App is created with.
type Config struct {
	// MinimalTTL is the lower bound applied to the TTL of every processed
	// A and CNAME record.
	MinimalTTL time.Duration

	// ChainPrefix is prepended to the names of the iptables chains managed by
	// the application and is passed to ClearIPTables on start-up.
	ChainPrefix string

	// IpSetPrefix is prepended to the group ID to build the group's ipset name.
	IpSetPrefix string

	// TargetDNSServerAddress is handed to the DNS proxy as its target server.
	TargetDNSServerAddress string

	// ListenPort is the port the DNS proxy is created with; port 53 is
	// remapped to it while the application is listening.
	ListenPort uint16

	// UseSoftwareRouting is not read in this file.
	// NOTE(review): semantics are defined by its consumers elsewhere — confirm.
	UseSoftwareRouting bool
}
|
|
|
|
type App struct {
|
|
Config Config
|
|
|
|
DNSProxy *dnsProxy.DNSProxy
|
|
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
|
NetfilterHelper6 *netfilterHelper.NetfilterHelper
|
|
Records *Records
|
|
Groups map[int]*Group
|
|
|
|
isRunning bool
|
|
dnsOverrider4 *netfilterHelper.PortRemap
|
|
dnsOverrider6 *netfilterHelper.PortRemap
|
|
}
|
|
|
|
func (a *App) handleLink(event netlink.LinkUpdate) {
|
|
switch event.Change {
|
|
case 0x00000001:
|
|
log.Debug().
|
|
Str("interface", event.Link.Attrs().Name).
|
|
Str("operstatestr", event.Attrs().OperState.String()).
|
|
Int("operstate", int(event.Attrs().OperState)).
|
|
Msg("interface change")
|
|
if event.Attrs().OperState != netlink.OperDown {
|
|
for _, group := range a.Groups {
|
|
if group.Interface == event.Link.Attrs().Name {
|
|
err := group.ifaceToIPSet.IfaceHandle()
|
|
if err != nil {
|
|
log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
case 0xFFFFFFFF:
|
|
switch event.Header.Type {
|
|
case 16:
|
|
log.Debug().
|
|
Str("interface", event.Link.Attrs().Name).
|
|
Int("type", int(event.Header.Type)).
|
|
Msg("interface add")
|
|
case 17:
|
|
log.Debug().
|
|
Str("interface", event.Link.Attrs().Name).
|
|
Int("type", int(event.Header.Type)).
|
|
Msg("interface del")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *App) listen(ctx context.Context) (err error) {
|
|
errChan := make(chan error)
|
|
|
|
newCtx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
go func() {
|
|
err := a.DNSProxy.Listen(newCtx)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("failed to serve DNS proxy: %v", err)
|
|
}
|
|
}()
|
|
|
|
a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort)
|
|
err = a.dnsOverrider4.Enable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to override DNS (IPv4): %v", err)
|
|
}
|
|
defer func() {
|
|
// TODO: Handle error
|
|
_ = a.dnsOverrider4.Disable()
|
|
}()
|
|
|
|
a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort)
|
|
err = a.dnsOverrider6.Enable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to override DNS (IPv6): %v", err)
|
|
}
|
|
defer func() {
|
|
// TODO: Handle error
|
|
_ = a.dnsOverrider6.Disable()
|
|
}()
|
|
|
|
for _, group := range a.Groups {
|
|
err = group.Enable()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to enable group: %w", err)
|
|
}
|
|
}
|
|
defer func() {
|
|
for _, group := range a.Groups {
|
|
// TODO: Handle error
|
|
_ = group.Disable()
|
|
}
|
|
}()
|
|
|
|
socketPath := "/opt/var/run/kvas2-go.sock"
|
|
err = os.Remove(socketPath)
|
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
return fmt.Errorf("failed to remove existed UNIX socket: %w", err)
|
|
}
|
|
socket, err := net.Listen("unix", socketPath)
|
|
if err != nil {
|
|
return fmt.Errorf("error while serve UNIX socket: %v", err)
|
|
}
|
|
defer func() {
|
|
// TODO: Handle error
|
|
_ = socket.Close()
|
|
_ = os.Remove(socketPath)
|
|
}()
|
|
|
|
go func() {
|
|
for {
|
|
conn, err := socket.Accept()
|
|
if err != nil {
|
|
if !strings.Contains(err.Error(), "use of closed network connection") {
|
|
log.Error().Err(err).Msg("error while listening unix socket")
|
|
}
|
|
break
|
|
}
|
|
|
|
go func(conn net.Conn) {
|
|
defer func() {
|
|
// TODO: Handle error
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
buf := make([]byte, 1024)
|
|
n, err := conn.Read(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
args := strings.Split(string(buf[:n]), ":")
|
|
if len(args) == 3 && args[0] == "netfilter.d" {
|
|
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
|
|
if a.dnsOverrider4.Enabled {
|
|
err := a.dnsOverrider4.PutIPTable(args[2])
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
|
}
|
|
}
|
|
for _, group := range a.Groups {
|
|
if group.ifaceToIPSet.Enabled {
|
|
err := group.ifaceToIPSet.PutIPTable(args[2])
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}(conn)
|
|
}
|
|
}()
|
|
|
|
link := make(chan netlink.LinkUpdate)
|
|
done := make(chan struct{})
|
|
err = netlink.LinkSubscribe(link, done)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to subscribe to link updates: %w", err)
|
|
}
|
|
defer func() {
|
|
close(done)
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case event := <-link:
|
|
a.handleLink(event)
|
|
case err := <-errChan:
|
|
return err
|
|
case <-ctx.Done():
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *App) Listen(ctx context.Context) (err error) {
|
|
if a.isRunning {
|
|
return ErrAlreadyRunning
|
|
}
|
|
a.isRunning = true
|
|
defer func() {
|
|
a.isRunning = false
|
|
}()
|
|
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
var recoveredError error
|
|
var ok bool
|
|
if recoveredError, ok = r.(error); !ok {
|
|
recoveredError = fmt.Errorf("%v", r)
|
|
}
|
|
|
|
err = fmt.Errorf("recovered error: %w", recoveredError)
|
|
}
|
|
}()
|
|
|
|
appErr := a.listen(ctx)
|
|
if appErr != nil {
|
|
return appErr
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (a *App) AddGroup(group *models.Group) error {
|
|
if _, exists := a.Groups[group.ID]; exists {
|
|
return ErrGroupIDConflict
|
|
}
|
|
|
|
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPrefix, group.ID)
|
|
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize ipset: %w", err)
|
|
}
|
|
|
|
grp := &Group{
|
|
Group: group,
|
|
iptables: a.NetfilterHelper4.IPTables,
|
|
ipset: ipset,
|
|
ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false),
|
|
}
|
|
a.Groups[group.ID] = grp
|
|
return a.SyncGroup(grp)
|
|
}
|
|
|
|
func (a *App) SyncGroup(group *Group) error {
|
|
processedDomains := make(map[string]struct{})
|
|
newIpsetAddressesMap := make(map[string]time.Duration)
|
|
now := time.Now()
|
|
|
|
oldIpsetAddresses, err := group.ListIPv4()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get old ipset list: %w", err)
|
|
}
|
|
|
|
knownDomains := a.Records.ListKnownDomains()
|
|
for _, domain := range group.Domains {
|
|
if !domain.IsEnabled() {
|
|
continue
|
|
}
|
|
|
|
for _, domainName := range knownDomains {
|
|
if !domain.IsMatch(domainName) {
|
|
continue
|
|
}
|
|
|
|
cnames := a.Records.GetCNameRecords(domainName, true)
|
|
if len(cnames) == 0 {
|
|
continue
|
|
}
|
|
for _, cname := range cnames {
|
|
processedDomains[cname] = struct{}{}
|
|
}
|
|
|
|
addresses := a.Records.GetARecords(domainName)
|
|
for _, address := range addresses {
|
|
ttl := now.Sub(address.Deadline)
|
|
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
|
|
newIpsetAddressesMap[string(address.Address)] = ttl
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
for addr, ttl := range newIpsetAddressesMap {
|
|
if _, exists := oldIpsetAddresses[addr]; exists {
|
|
continue
|
|
}
|
|
ip := net.IP(addr)
|
|
err = group.AddIPv4(ip, ttl)
|
|
if err != nil {
|
|
log.Error().
|
|
Str("address", ip.String()).
|
|
Err(err).
|
|
Msg("failed to add address")
|
|
}
|
|
}
|
|
|
|
for addr := range oldIpsetAddresses {
|
|
if _, exists := newIpsetAddressesMap[addr]; exists {
|
|
continue
|
|
}
|
|
ip := net.IP(addr)
|
|
err = group.DelIPv4(ip)
|
|
if err != nil {
|
|
log.Error().
|
|
Str("address", ip.String()).
|
|
Err(err).
|
|
Msg("failed to delete address")
|
|
} else {
|
|
log.Trace().
|
|
Str("address", ip.String()).
|
|
Err(err).
|
|
Msg("add address")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *App) ListInterfaces() ([]net.Interface, error) {
|
|
interfaceNames := make([]net.Interface, 0)
|
|
|
|
interfaces, err := net.Interfaces()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get interfaces: %w", err)
|
|
}
|
|
|
|
for _, iface := range interfaces {
|
|
if iface.Flags&net.FlagPointToPoint == 0 {
|
|
continue
|
|
}
|
|
|
|
interfaceNames = append(interfaceNames, iface)
|
|
}
|
|
|
|
return interfaceNames, nil
|
|
}
|
|
|
|
func (a *App) processARecord(aRecord dnsProxy.Address) {
|
|
log.Trace().
|
|
Str("name", aRecord.Name.String()).
|
|
Str("address", aRecord.Address.String()).
|
|
Int("ttl", int(aRecord.TTL)).
|
|
Msg("processing a record")
|
|
|
|
ttlDuration := time.Duration(aRecord.TTL) * time.Second
|
|
if ttlDuration < a.Config.MinimalTTL {
|
|
ttlDuration = a.Config.MinimalTTL
|
|
}
|
|
|
|
a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
|
|
|
|
names := a.Records.GetCNameRecords(aRecord.Name.String(), true)
|
|
for _, group := range a.Groups {
|
|
Domain:
|
|
for _, domain := range group.Domains {
|
|
if !domain.IsEnabled() {
|
|
continue
|
|
}
|
|
for _, name := range names {
|
|
if !domain.IsMatch(name) {
|
|
continue
|
|
}
|
|
err := group.AddIPv4(aRecord.Address, ttlDuration)
|
|
if err != nil {
|
|
log.Error().
|
|
Str("address", aRecord.Address.String()).
|
|
Err(err).
|
|
Msg("failed to add address")
|
|
} else {
|
|
log.Trace().
|
|
Str("address", aRecord.Address.String()).
|
|
Str("aRecordDomain", aRecord.Name.String()).
|
|
Str("cNameDomain", name).
|
|
Err(err).
|
|
Msg("add address")
|
|
}
|
|
break Domain
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
|
|
log.Trace().
|
|
Str("name", cNameRecord.Name.String()).
|
|
Str("cname", cNameRecord.CName.String()).
|
|
Int("ttl", int(cNameRecord.TTL)).
|
|
Msg("processing cname record")
|
|
|
|
ttlDuration := time.Duration(cNameRecord.TTL) * time.Second
|
|
if ttlDuration < a.Config.MinimalTTL {
|
|
ttlDuration = a.Config.MinimalTTL
|
|
}
|
|
|
|
a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration)
|
|
|
|
// TODO: Optimization
|
|
now := time.Now()
|
|
aRecords := a.Records.GetARecords(cNameRecord.Name.String())
|
|
names := a.Records.GetCNameRecords(cNameRecord.Name.String(), true)
|
|
for _, group := range a.Groups {
|
|
Domain:
|
|
for _, domain := range group.Domains {
|
|
if !domain.IsEnabled() {
|
|
continue
|
|
}
|
|
for _, name := range names {
|
|
if !domain.IsMatch(name) {
|
|
continue
|
|
}
|
|
for _, aRecord := range aRecords {
|
|
err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline))
|
|
if err != nil {
|
|
log.Error().
|
|
Str("address", aRecord.Address.String()).
|
|
Err(err).
|
|
Msg("failed to add address")
|
|
} else {
|
|
log.Trace().
|
|
Str("address", aRecord.Address.String()).
|
|
Str("cNameDomain", name).
|
|
Err(err).
|
|
Msg("add address")
|
|
}
|
|
}
|
|
continue Domain
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *App) handleRecord(rr dnsProxy.ResourceRecord) {
|
|
switch v := rr.(type) {
|
|
case dnsProxy.Address:
|
|
// TODO: Optimize equals domain A records
|
|
a.processARecord(v)
|
|
case dnsProxy.CName:
|
|
a.processCNameRecord(v)
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (a *App) handleMessage(msg *dnsProxy.Message) {
|
|
for _, rr := range msg.AN {
|
|
a.handleRecord(rr)
|
|
}
|
|
for _, rr := range msg.NS {
|
|
a.handleRecord(rr)
|
|
}
|
|
for _, rr := range msg.AR {
|
|
a.handleRecord(rr)
|
|
}
|
|
}
|
|
|
|
func New(config Config) (*App, error) {
|
|
var err error
|
|
|
|
app := &App{}
|
|
|
|
app.Config = config
|
|
|
|
app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress)
|
|
app.DNSProxy.MsgHandler = app.handleMessage
|
|
|
|
app.Records = NewRecords()
|
|
|
|
nh4, err := netfilterHelper.New(false)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
|
|
}
|
|
app.NetfilterHelper4 = nh4
|
|
err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to clear iptables: %w", err)
|
|
}
|
|
|
|
nh6, err := netfilterHelper.New(true)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
|
|
}
|
|
app.NetfilterHelper6 = nh6
|
|
err = app.NetfilterHelper6.ClearIPTables(app.Config.ChainPrefix)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to clear iptables: %w", err)
|
|
}
|
|
|
|
app.Groups = make(map[int]*Group)
|
|
|
|
return app, nil
|
|
}
|