diff --git a/README.md b/README.md index 3cce9ba..791ee70 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,10 @@ Better implementation of [KVAS](https://github.com/qzeleza/kvas) -Roadmap: -- [x] DNS Proxy -- [x] DNS Records table +Realized features: +- [x] DNS Proxy (UDP) +- [ ] DNS Proxy (TCP) +- [x] Records table - [x] IPTables rules to remap DNS server [1] - [ ] Rule composer - [ ] List loading/watching (temporary) diff --git a/kvas2.go b/kvas2.go new file mode 100644 index 0000000..daeb21e --- /dev/null +++ b/kvas2.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "kvas2-go/models" + "kvas2-go/pkg/dns-proxy" + "kvas2-go/pkg/iptables-helper" + "sync" + "time" +) + +type Config struct { + MinimalTTL time.Duration + ChainPostfix string + TargetDNSServerAddress string + ListenPort uint16 +} + +type App struct { + Config Config + + DNSProxy *dnsProxy.DNSProxy + DNSOverrider *iptablesHelper.DNSOverrider + Records *Records + Groups []*models.Group +} + +func (a *App) Listen(ctx context.Context) []error { + errs := make([]error, 0) + isError := make(chan struct{}) + + var once sync.Once + var errsMu sync.Mutex + handleError := func(err error) { + errsMu.Lock() + defer errsMu.Unlock() + + errs = append(errs, err) + once.Do(func() { close(isError) }) + } + + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + handleError(err) + } else { + handleError(fmt.Errorf("%v", r)) + } + } + }() + + newCtx, cancel := context.WithCancel(ctx) + defer cancel() + + if err := a.DNSOverrider.Enable(); err != nil { + handleError(fmt.Errorf("failed to override DNS: %w", err)) + return errs + } + + go func() { + if err := a.DNSProxy.Listen(newCtx); err != nil { + handleError(fmt.Errorf("failed to initialize DNS proxy: %v", err)) + } + }() + + select { + case <-ctx.Done(): + case <-isError: + } + + if err := a.DNSOverrider.Disable(); err != nil { + handleError(fmt.Errorf("failed to rollback override DNS changes: %w", err)) + } + + return errs +} + +func (a *App) handleRecord(msg *dnsProxy.Message) { + printKnownRecords := func() { + for _, q := range msg.QD { + fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) + for idx, addr := range a.Records.GetARecords(q.QName.String(), true) { + fmt.Printf("%04x: #%d: %s\n", msg.ID, idx, addr.String()) + } + } + } + parseResponseRecord := func(rr dnsProxy.ResourceRecord) { + switch v := rr.(type) { + case dnsProxy.Address: + fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) + ttlDuration := time.Duration(v.TTL) * time.Second + if ttlDuration < a.Config.MinimalTTL { + ttlDuration = a.Config.MinimalTTL + } + a.Records.PutARecord(v.Name.String(), v.Address, ttlDuration) + case dnsProxy.CName: + fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) + ttlDuration := time.Duration(v.TTL) * time.Second + if ttlDuration < a.Config.MinimalTTL { + ttlDuration = a.Config.MinimalTTL + } + a.Records.PutCNameRecord(v.Name.String(), v.CName.String(), ttlDuration) + default: + fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) + } + } + + printKnownRecords() + for _, q := range msg.QD { + fmt.Printf("%04x: <- Request name: %s\n", msg.ID, q.QName.String()) + } + for _, a := range msg.AN { + parseResponseRecord(a) + } + for _, a := range msg.NS { + parseResponseRecord(a) + } + for _, a := range msg.AR { + parseResponseRecord(a) + } + printKnownRecords() +} + +func New(config Config) (*App, error) { + var err error + + app := &App{} + + app.Config = config + + app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress) + app.DNSProxy.MsgHandler = app.handleRecord + + app.Records = NewRecords() + + app.DNSOverrider, err = iptablesHelper.NewDNSOverrider(fmt.Sprintf("%s_DNSOVERRIDER", app.Config.ChainPostfix), app.Config.ListenPort) + if err != nil { + return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err) + } + + app.Groups = make([]*models.Group, 0) + + return app, nil + +} diff --git a/main.go b/main.go index 8d8a70f..1d82fee 100644 --- a/main.go +++ b/main.go @@ -7,93 +7,45 @@ import ( "os" "os/signal" "syscall" - - dnsProxy "kvas2-go/dns-proxy" - iptablesHelper "kvas2-go/iptables-helper" - ruleComposer "kvas2-go/rule-composer" -) - -var ( - ChainPostfix = "KVAS2" - ListenPort = uint16(7548) - TargetDNSServerAddress = "127.0.0.1:53" + "time" ) func main() { - records := ruleComposer.NewRecords() - proxy := dnsProxy.New(ListenPort, TargetDNSServerAddress) - dnsOverrider, err := iptablesHelper.NewDNSOverrider(fmt.Sprintf("%s_DNSOVERRIDER", ChainPostfix), ListenPort) + app, err := New(Config{ + MinimalTTL: time.Hour, + ChainPostfix: "KVAS2", + TargetDNSServerAddress: "127.0.0.1:53", + ListenPort: 7548, + }) if err != nil { - log.Fatalf("failed to initialize DNS overrider: %v", err) - } - - proxy.MsgHandler = func(msg *dnsProxy.Message) { - printKnownRecords := func() { - for _, q := range msg.QD { - fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) - for idx, addr := range records.GetARecords(q.QName.String(), true) { - fmt.Printf("%04x: #%d: %s\n", msg.ID, idx, addr.String()) - } - } - } - parseResponseRecord := func(rr dnsProxy.ResourceRecord) { - switch v := rr.(type) { - case dnsProxy.Address: - fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) - records.PutARecord(v.Name.String(), v.Address, int64(v.TTL)) - case dnsProxy.CName: - fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) - records.PutCNameRecord(v.Name.String(), v.CName.String(), int64(v.TTL)) - default: - fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) - } - } - - printKnownRecords() - for _, q := range msg.QD { - fmt.Printf("%04x: <- Request name: %s\n", msg.ID, q.QName.String()) - } - for _, a := range msg.AN { - parseResponseRecord(a) - } - for _, a := range msg.NS { - parseResponseRecord(a) - } - for _, a := range msg.AR { - parseResponseRecord(a) - } - printKnownRecords() + log.Fatalf("failed to initialize application: %v", err) } ctx, cancel := context.WithCancel(context.Background()) + appErrsChan := make(chan []error) go func() { - err := proxy.Listen(ctx) - if err != nil { - log.Fatalf("failed to initialize DNS proxy: %v", err) - } + errs := app.Listen(ctx) + appErrsChan <- errs + }() - err = dnsOverrider.Enable() - if err != nil { - log.Fatalf("failed to override DNS: %v", err) - } - - fmt.Printf("Started service on port '%d'\n", ListenPort) + fmt.Println("Started service...") c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) for { select { + case appErrs, _ := <-appErrsChan: + for _, err := range appErrs { + // TODO: Error log level + log.Printf("failed to start application: %v", err) + } + return case <-c: fmt.Println("Graceful shutdown...") cancel() - err = dnsOverrider.Disable() - if err != nil { - log.Fatalf("failed to rollback override DNS changes: %v", err) - } - return } } } diff --git a/models/domain.go b/models/domain.go new file mode 100644 index 0000000..07ac2d0 --- /dev/null +++ b/models/domain.go @@ -0,0 +1,10 @@ +package models + +type Domain struct { + ID int + Group *Group + Type string + Domain string + Enable bool + Comment string +} diff --git a/models/group.go b/models/group.go new file mode 100644 index 0000000..7c527b0 --- /dev/null +++ b/models/group.go @@ -0,0 +1,8 @@ +package models + +type Group struct { + ID int + Name string + Interface string + Domains []*Domain +} diff --git a/dns-proxy/dns-proxy.go b/pkg/dns-proxy/dns-proxy.go similarity index 98% rename from dns-proxy/dns-proxy.go rename to pkg/dns-proxy/dns-proxy.go index 146aa6e..c1b5eac 100644 --- a/dns-proxy/dns-proxy.go +++ b/pkg/dns-proxy/dns-proxy.go @@ -47,7 +47,6 @@ func (p DNSProxy) Listen(ctx context.Context) error { for { select { case <-ctx.Done(): - log.Println("Shutting down DNS proxy...") return nil default: buffer := make([]byte, DNSMaxUDPPackageSize) diff --git a/dns-proxy/parser.go b/pkg/dns-proxy/parser.go similarity index 99% rename from dns-proxy/parser.go rename to pkg/dns-proxy/parser.go index dfd961a..c7b97a9 100644 --- a/dns-proxy/parser.go +++ b/pkg/dns-proxy/parser.go @@ -27,7 +27,7 @@ func parseName(response []byte, pos int) (*Name, int, error) { break } - if length&0xC0 == 0xC0 { + if length&0xC0 != 0 { if responseLen < pos+1 { return nil, pos, io.EOF } diff --git a/dns-proxy/types.go b/pkg/dns-proxy/types.go similarity index 100% rename from dns-proxy/types.go rename to pkg/dns-proxy/types.go diff --git a/dns-proxy/types_test.go b/pkg/dns-proxy/types_test.go similarity index 100% rename from dns-proxy/types_test.go rename to pkg/dns-proxy/types_test.go diff --git a/ip-helper/ip-helper.go b/pkg/ip-helper/ip-helper.go similarity index 100% rename from ip-helper/ip-helper.go rename to pkg/ip-helper/ip-helper.go diff --git a/iptables-helper/iptables-helper.go b/pkg/iptables-helper/iptables-helper.go similarity index 100% rename from iptables-helper/iptables-helper.go rename to pkg/iptables-helper/iptables-helper.go diff --git a/rule-composer/records.go b/records.go similarity index 92% rename from rule-composer/records.go rename to records.go index cc29816..079226e 100644 --- a/rule-composer/records.go +++ b/records.go @@ -1,4 +1,4 @@ -package ruleComposer +package main import ( "net" @@ -73,7 +73,7 @@ func (r *Records) GetARecords(domainName string, recursive bool) []net.IP { return aRecords } -func (r *Records) PutCNameRecord(domainName string, cName string, ttl int64) { +func (r *Records) PutCNameRecord(domainName string, cName string, ttl time.Duration) { r.mutex.Lock() defer r.mutex.Unlock() @@ -81,10 +81,10 @@ func (r *Records) PutCNameRecord(domainName string, cName string, ttl int64) { r.cnameRecords[domainName] = make(map[string]time.Time) } - r.cnameRecords[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl)) + r.cnameRecords[domainName][cName] = time.Now().Add(ttl) } -func (r *Records) PutARecord(domainName string, addr net.IP, ttl int64) { +func (r *Records) PutARecord(domainName string, addr net.IP, ttl time.Duration) { r.mutex.Lock() defer r.mutex.Unlock() @@ -92,7 +92,7 @@ func (r *Records) PutARecord(domainName string, addr net.IP, ttl int64) { r.aRecords[domainName] = make(map[string]time.Time) } - r.aRecords[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl)) + r.aRecords[domainName][string(addr)] = time.Now().Add(ttl) } func (r *Records) Cleanup() {