package main import ( "context" "errors" "fmt" "net" "os" "strings" "time" "kvas2-go/dns-proxy" "kvas2-go/models" "kvas2-go/netfilter-helper" "github.com/rs/zerolog/log" "github.com/vishvananda/netlink" ) var ( ErrAlreadyRunning = errors.New("already running") ErrGroupIDConflict = errors.New("group id conflict") ) type Config struct { MinimalTTL time.Duration ChainPrefix string IpSetPrefix string TargetDNSServerAddress string ListenPort uint16 UseSoftwareRouting bool } type App struct { Config Config DNSProxy *dnsProxy.DNSProxy NetfilterHelper4 *netfilterHelper.NetfilterHelper Records *Records Groups map[int]*Group isRunning bool dnsOverrider4 *netfilterHelper.PortRemap } func (a *App) handleLink(event netlink.LinkUpdate) { switch event.Change { case 0x00000001: log.Debug(). Str("interface", event.Link.Attrs().Name). Str("operstatestr", event.Attrs().OperState.String()). Int("operstate", int(event.Attrs().OperState)). Msg("interface change") if event.Attrs().OperState != netlink.OperDown { for _, group := range a.Groups { if group.Interface == event.Link.Attrs().Name { err := group.ifaceToIPSet.IfaceHandle() if err != nil { log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up") } } } } case 0xFFFFFFFF: switch event.Header.Type { case 16: log.Debug(). Str("interface", event.Link.Attrs().Name). Int("type", int(event.Header.Type)). Msg("interface add") case 17: log.Debug(). Str("interface", event.Link.Attrs().Name). Int("type", int(event.Header.Type)). Msg("interface del") } } } func (a *App) listen(ctx context.Context) (err error) { errChan := make(chan error) newCtx, cancel := context.WithCancel(ctx) defer cancel() go func() { err := a.DNSProxy.Listen(newCtx) if err != nil { errChan <- fmt.Errorf("failed to serve DNS proxy: %v", err) } }() a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenPort) err = a.dnsOverrider4.Enable() if err != nil { return fmt.Errorf("failed to override DNS: %v", err) } defer func() { // TODO: Handle error _ = a.dnsOverrider4.Disable() }() for _, group := range a.Groups { err = group.Enable() if err != nil { return fmt.Errorf("failed to enable group: %w", err) } } defer func() { for _, group := range a.Groups { // TODO: Handle error _ = group.Disable() } }() socketPath := "/opt/var/run/kvas2-go.sock" err = os.Remove(socketPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("failed to remove existed UNIX socket: %w", err) } socket, err := net.Listen("unix", socketPath) if err != nil { return fmt.Errorf("error while serve UNIX socket: %v", err) } defer func() { // TODO: Handle error _ = socket.Close() _ = os.Remove(socketPath) }() go func() { for { conn, err := socket.Accept() if err != nil { if !strings.Contains(err.Error(), "use of closed network connection") { log.Error().Err(err).Msg("error while listening unix socket") } break } go func(conn net.Conn) { defer func() { // TODO: Handle error _ = conn.Close() }() buf := make([]byte, 1024) n, err := conn.Read(buf) if err != nil { return } args := strings.Split(string(buf[:n]), ":") if len(args) == 3 && args[0] == "netfilter.d" { log.Debug().Str("table", args[2]).Msg("netfilter.d event") if a.dnsOverrider4.Enabled { err := a.dnsOverrider4.PutIPTable(args[2]) if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } } for _, group := range a.Groups { if group.ifaceToIPSet.Enabled { err := group.ifaceToIPSet.PutIPTable(args[2]) if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } } } } }(conn) } }() link := make(chan netlink.LinkUpdate) done := make(chan struct{}) err = netlink.LinkSubscribe(link, done) if err != nil { return fmt.Errorf("failed to subscribe to link updates: %w", err) } defer func() { close(done) }() for { select { case event := <-link: a.handleLink(event) case err := <-errChan: return err case <-ctx.Done(): return nil } } } func (a *App) Listen(ctx context.Context) (err error) { if a.isRunning { return ErrAlreadyRunning } a.isRunning = true defer func() { a.isRunning = false }() defer func() { if r := recover(); r != nil { var recoveredError error var ok bool if recoveredError, ok = r.(error); !ok { recoveredError = fmt.Errorf("%v", r) } err = fmt.Errorf("recovered error: %w", recoveredError) } }() appErr := a.listen(ctx) if appErr != nil { return appErr } return err } func (a *App) AddGroup(group *models.Group) error { if _, exists := a.Groups[group.ID]; exists { return ErrGroupIDConflict } ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPrefix, group.ID) ipset, err := a.NetfilterHelper4.IPSet(ipsetName) if err != nil { return fmt.Errorf("failed to initialize ipset: %w", err) } grp := &Group{ Group: group, iptables: a.NetfilterHelper4.IPTables, ipset: ipset, ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false), } a.Groups[group.ID] = grp return a.SyncGroup(grp) } func (a *App) SyncGroup(group *Group) error { processedDomains := make(map[string]struct{}) newIpsetAddressesMap := make(map[string]time.Duration) now := time.Now() oldIpsetAddresses, err := group.ListIPv4() if err != nil { return fmt.Errorf("failed to get old ipset list: %w", err) } knownDomains := a.Records.ListKnownDomains() for _, domain := range group.Domains { if !domain.IsEnabled() { continue } for _, domainName := range knownDomains { if !domain.IsMatch(domainName) { continue } cnames := a.Records.GetCNameRecords(domainName, true) if len(cnames) == 0 { continue } for _, cname := range cnames { processedDomains[cname] = struct{}{} } addresses := a.Records.GetARecords(domainName) for _, address := range addresses { ttl := now.Sub(address.Deadline) if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL { newIpsetAddressesMap[string(address.Address)] = ttl } } } } for addr, ttl := range newIpsetAddressesMap { if _, exists := oldIpsetAddresses[addr]; exists { continue } ip := net.IP(addr) err = group.AddIPv4(ip, ttl) if err != nil { log.Error(). Str("address", ip.String()). Err(err). Msg("failed to add address") } } for addr := range oldIpsetAddresses { if _, exists := newIpsetAddressesMap[addr]; exists { continue } ip := net.IP(addr) err = group.DelIPv4(ip) if err != nil { log.Error(). Str("address", ip.String()). Err(err). Msg("failed to delete address") } else { log.Trace(). Str("address", ip.String()). Err(err). Msg("add address") } } return nil } func (a *App) ListInterfaces() ([]net.Interface, error) { interfaceNames := make([]net.Interface, 0) interfaces, err := net.Interfaces() if err != nil { return nil, fmt.Errorf("failed to get interfaces: %w", err) } for _, iface := range interfaces { if iface.Flags&net.FlagPointToPoint == 0 { continue } interfaceNames = append(interfaceNames, iface) } return interfaceNames, nil } func (a *App) processARecord(aRecord dnsProxy.Address) { log.Trace(). Str("name", aRecord.Name.String()). Str("address", aRecord.Address.String()). Int("ttl", int(aRecord.TTL)). Msg("processing a record") ttlDuration := time.Duration(aRecord.TTL) * time.Second if ttlDuration < a.Config.MinimalTTL { ttlDuration = a.Config.MinimalTTL } a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration) names := a.Records.GetCNameRecords(aRecord.Name.String(), true) for _, group := range a.Groups { Domain: for _, domain := range group.Domains { if !domain.IsEnabled() { continue } for _, name := range names { if !domain.IsMatch(name) { continue } err := group.AddIPv4(aRecord.Address, ttlDuration) if err != nil { log.Error(). Str("address", aRecord.Address.String()). Err(err). Msg("failed to add address") } else { log.Trace(). Str("address", aRecord.Address.String()). Str("aRecordDomain", aRecord.Name.String()). Str("cNameDomain", name). Err(err). Msg("add address") } break Domain } } } } func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { log.Trace(). Str("name", cNameRecord.Name.String()). Str("cname", cNameRecord.CName.String()). Int("ttl", int(cNameRecord.TTL)). Msg("processing cname record") ttlDuration := time.Duration(cNameRecord.TTL) * time.Second if ttlDuration < a.Config.MinimalTTL { ttlDuration = a.Config.MinimalTTL } a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration) // TODO: Optimization now := time.Now() aRecords := a.Records.GetARecords(cNameRecord.Name.String()) names := a.Records.GetCNameRecords(cNameRecord.Name.String(), true) for _, group := range a.Groups { Domain: for _, domain := range group.Domains { if !domain.IsEnabled() { continue } for _, name := range names { if !domain.IsMatch(name) { continue } for _, aRecord := range aRecords { err := group.AddIPv4(aRecord.Address, now.Sub(aRecord.Deadline)) if err != nil { log.Error(). Str("address", aRecord.Address.String()). Err(err). Msg("failed to add address") } else { log.Trace(). Str("address", aRecord.Address.String()). Str("cNameDomain", name). Err(err). Msg("add address") } } continue Domain } } } } func (a *App) handleRecord(rr dnsProxy.ResourceRecord) { switch v := rr.(type) { case dnsProxy.Address: // TODO: Optimize equals domain A records a.processARecord(v) case dnsProxy.CName: a.processCNameRecord(v) default: } } func (a *App) handleMessage(msg *dnsProxy.Message) { for _, rr := range msg.AN { a.handleRecord(rr) } for _, rr := range msg.NS { a.handleRecord(rr) } for _, rr := range msg.AR { a.handleRecord(rr) } } func New(config Config) (*App, error) { var err error app := &App{} app.Config = config app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress) app.DNSProxy.MsgHandler = app.handleMessage app.Records = NewRecords() nh4, err := netfilterHelper.New(false) if err != nil { return nil, fmt.Errorf("netfilter helper init fail: %w", err) } app.NetfilterHelper4 = nh4 err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix) if err != nil { return nil, fmt.Errorf("failed to clear iptables: %w", err) } app.Groups = make(map[int]*Group) return app, nil }