package main

import (
	"context"
	"fmt"
	"sync"
	"time"

	"kvas2-go/models"
	"kvas2-go/pkg/dns-proxy"
	"kvas2-go/pkg/iptables-helper"
)

// Config holds the settings an App is built from.
type Config struct {
	// MinimalTTL is the lower bound applied to the TTL of every A and CNAME
	// record before it is stored in App.Records.
	MinimalTTL time.Duration
	// ChainPostfix is formatted into the name handed to the DNS overrider
	// ("<ChainPostfix>_DNSOVERRIDER").
	ChainPostfix string
	// TargetDNSServerAddress is passed to the DNS proxy constructor.
	TargetDNSServerAddress string
	// ListenPort is passed to both the DNS proxy and the DNS overrider.
	ListenPort uint16
}

// App ties together the DNS proxy, the DNS overrider and the record store.
type App struct {
	Config       Config
	DNSProxy     *dnsProxy.DNSProxy
	DNSOverrider *iptablesHelper.DNSOverrider
	Records      *Records
	Groups       []*models.Group
}

// Listen enables the DNS override, starts the DNS proxy in a background
// goroutine and blocks until ctx is cancelled or the first error is recorded.
// Before returning it disables the DNS override again (unless enabling it
// failed in the first place).
//
// The returned slice is never nil; it is empty when no error was recorded.
// A panic raised on the calling goroutine is recovered and reported as one of
// the returned errors.
func (a *App) Listen(ctx context.Context) (result []error) {
	errs := make([]error, 0)
	isError := make(chan struct{})

	var (
		once   sync.Once
		errsMu sync.Mutex
	)
	// handleError records err and, on the first call only, wakes the select
	// below by closing isError. It is called from the proxy goroutine as well,
	// hence the mutex.
	handleError := func(err error) {
		errsMu.Lock()
		defer errsMu.Unlock()

		errs = append(errs, err)
		once.Do(func() { close(isError) })
	}

	defer func() {
		if r := recover(); r != nil {
			if err, ok := r.(error); ok {
				handleError(err)
			} else {
				handleError(fmt.Errorf("%v", r))
			}
		}

		// Publish a private copy taken under the lock. The result has to be a
		// named return value: with an unnamed one a recovered panic would make
		// the function return nil and drop the error recorded above. Copying
		// also keeps the caller's slice detached from the proxy goroutine,
		// which may still call handleError after this function has returned.
		errsMu.Lock()
		defer errsMu.Unlock()
		result = append(make([]error, 0, len(errs)), errs...)
	}()

	newCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	if err := a.DNSOverrider.Enable(); err != nil {
		handleError(fmt.Errorf("failed to override DNS: %w", err))
		return nil // replaced by the snapshot taken in the deferred function
	}

	go func() {
		if err := a.DNSProxy.Listen(newCtx); err != nil {
			handleError(fmt.Errorf("failed to initialize DNS proxy: %w", err))
		}
	}()

	// Block until the caller cancels or any component reports an error.
	select {
	case <-ctx.Done():
	case <-isError:
	}

	if err := a.DNSOverrider.Disable(); err != nil {
		handleError(fmt.Errorf("failed to rollback override DNS changes: %w", err))
	}

	return nil // replaced by the snapshot taken in the deferred function
}

// handleRecord is installed as the DNS proxy's message handler. It prints the
// addresses already known for each question name, prints the question names,
// stores every A and CNAME record found in the answer, authority and
// additional sections in a.Records (with the TTL raised to at least
// Config.MinimalTTL), and finally prints the known addresses again.
func (a *App) handleRecord(msg *dnsProxy.Message) {
	printKnownRecords := func() {
		for _, q := range msg.QD {
			fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String())
			for idx, addr := range a.Records.GetARecords(q.QName.String(), true) {
				fmt.Printf("%04x:     #%d: %s\n", msg.ID, idx, addr.String())
			}
		}
	}

	// clampTTL raises ttl to the configured minimum.
	clampTTL := func(ttl time.Duration) time.Duration {
		if ttl < a.Config.MinimalTTL {
			return a.Config.MinimalTTL
		}
		return ttl
	}

	parseResponseRecord := func(rr dnsProxy.ResourceRecord) {
		switch v := rr.(type) {
		case dnsProxy.Address:
			fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
			a.Records.PutARecord(v.Name.String(), v.Address, clampTTL(time.Duration(v.TTL)*time.Second))
		case dnsProxy.CName:
			fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
			a.Records.PutCNameRecord(v.Name.String(), v.CName.String(), clampTTL(time.Duration(v.TTL)*time.Second))
		default:
			fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
		}
	}

	printKnownRecords()

	for _, q := range msg.QD {
		fmt.Printf("%04x: <- Request name: %s\n", msg.ID, q.QName.String())
	}

	// The loop variables are deliberately not named "a" so they do not shadow
	// the receiver.
	for _, rr := range msg.AN {
		parseResponseRecord(rr)
	}
	for _, rr := range msg.NS {
		parseResponseRecord(rr)
	}
	for _, rr := range msg.AR {
		parseResponseRecord(rr)
	}

	printKnownRecords()
}

// New builds an App from config: it creates the DNS proxy, installs
// handleRecord as its message handler, creates the record store and the DNS
// overrider, and starts with an empty group list.
//
// It returns a nil App and a wrapped error when the DNS overrider cannot be
// created.
func New(config Config) (*App, error) {
	var err error

	app := &App{}

	app.Config = config

	app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress)
	app.DNSProxy.MsgHandler = app.handleRecord

	app.Records = NewRecords()

	app.DNSOverrider, err = iptablesHelper.NewDNSOverrider(fmt.Sprintf("%s_DNSOVERRIDER", app.Config.ChainPostfix), app.Config.ListenPort)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
	}

	app.Groups = make([]*models.Group, 0)

	return app, nil
}