app instance

This commit is contained in:
Vladimir Avtsenov 2024-08-26 19:10:40 +03:00
parent 55d623dff2
commit 051aef824d
12 changed files with 193 additions and 77 deletions

View File

@ -2,9 +2,10 @@
Better implementation of [KVAS](https://github.com/qzeleza/kvas)
Roadmap:
- [x] DNS Proxy
- [x] DNS Records table
Realized features:
- [x] DNS Proxy (UDP)
- [ ] DNS Proxy (TCP)
- [x] Records table
- [x] IPTables rules to remap DNS server [1]
- [ ] Rule composer
- [ ] List loading/watching (temporary)

146
kvas2.go Normal file
View File

@ -0,0 +1,146 @@
package main
import (
"context"
"fmt"
"kvas2-go/models"
"kvas2-go/pkg/dns-proxy"
"kvas2-go/pkg/iptables-helper"
"sync"
"time"
)
type Config struct {
MinimalTTL time.Duration
ChainPostfix string
TargetDNSServerAddress string
ListenPort uint16
}
type App struct {
Config Config
DNSProxy *dnsProxy.DNSProxy
DNSOverrider *iptablesHelper.DNSOverrider
Records *Records
Groups []*models.Group
}
func (a *App) Listen(ctx context.Context) []error {
errs := make([]error, 0)
isError := make(chan struct{})
var once sync.Once
var errsMu sync.Mutex
handleError := func(err error) {
errsMu.Lock()
defer errsMu.Unlock()
errs = append(errs, err)
once.Do(func() { close(isError) })
}
defer func() {
if r := recover(); r != nil {
if err, ok := r.(error); ok {
handleError(err)
} else {
handleError(fmt.Errorf("%v", r))
}
}
}()
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
if err := a.DNSOverrider.Enable(); err != nil {
handleError(fmt.Errorf("failed to override DNS: %w", err))
return errs
}
go func() {
if err := a.DNSProxy.Listen(newCtx); err != nil {
handleError(fmt.Errorf("failed to initialize DNS proxy: %v", err))
}
}()
select {
case <-ctx.Done():
case <-isError:
}
if err := a.DNSOverrider.Disable(); err != nil {
handleError(fmt.Errorf("failed to rollback override DNS changes: %w", err))
}
return errs
}
func (a *App) handleRecord(msg *dnsProxy.Message) {
printKnownRecords := func() {
for _, q := range msg.QD {
fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String())
for idx, addr := range a.Records.GetARecords(q.QName.String(), true) {
fmt.Printf("%04x: #%d: %s\n", msg.ID, idx, addr.String())
}
}
}
parseResponseRecord := func(rr dnsProxy.ResourceRecord) {
switch v := rr.(type) {
case dnsProxy.Address:
fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
ttlDuration := time.Duration(v.TTL) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.PutARecord(v.Name.String(), v.Address, ttlDuration)
case dnsProxy.CName:
fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
ttlDuration := time.Duration(v.TTL) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.PutCNameRecord(v.Name.String(), v.CName.String(), ttlDuration)
default:
fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
}
}
printKnownRecords()
for _, q := range msg.QD {
fmt.Printf("%04x: <- Request name: %s\n", msg.ID, q.QName.String())
}
for _, a := range msg.AN {
parseResponseRecord(a)
}
for _, a := range msg.NS {
parseResponseRecord(a)
}
for _, a := range msg.AR {
parseResponseRecord(a)
}
printKnownRecords()
}
func New(config Config) (*App, error) {
var err error
app := &App{}
app.Config = config
app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress)
app.DNSProxy.MsgHandler = app.handleRecord
app.Records = NewRecords()
app.DNSOverrider, err = iptablesHelper.NewDNSOverrider(fmt.Sprintf("%s_DNSOVERRIDER", app.Config.ChainPostfix), app.Config.ListenPort)
if err != nil {
return nil, fmt.Errorf("failed to initialize DNS overrider: %w", err)
}
app.Groups = make([]*models.Group, 0)
return app, nil
}

86
main.go
View File

@ -7,93 +7,45 @@ import (
"os"
"os/signal"
"syscall"
dnsProxy "kvas2-go/dns-proxy"
iptablesHelper "kvas2-go/iptables-helper"
ruleComposer "kvas2-go/rule-composer"
)
var (
ChainPostfix = "KVAS2"
ListenPort = uint16(7548)
TargetDNSServerAddress = "127.0.0.1:53"
"time"
)
func main() {
records := ruleComposer.NewRecords()
proxy := dnsProxy.New(ListenPort, TargetDNSServerAddress)
dnsOverrider, err := iptablesHelper.NewDNSOverrider(fmt.Sprintf("%s_DNSOVERRIDER", ChainPostfix), ListenPort)
app, err := New(Config{
MinimalTTL: time.Hour,
ChainPostfix: "KVAS2",
TargetDNSServerAddress: "127.0.0.1:53",
ListenPort: 7548,
})
if err != nil {
log.Fatalf("failed to initialize DNS overrider: %v", err)
}
proxy.MsgHandler = func(msg *dnsProxy.Message) {
printKnownRecords := func() {
for _, q := range msg.QD {
fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String())
for idx, addr := range records.GetARecords(q.QName.String(), true) {
fmt.Printf("%04x: #%d: %s\n", msg.ID, idx, addr.String())
}
}
}
parseResponseRecord := func(rr dnsProxy.ResourceRecord) {
switch v := rr.(type) {
case dnsProxy.Address:
fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
records.PutARecord(v.Name.String(), v.Address, int64(v.TTL))
case dnsProxy.CName:
fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
records.PutCNameRecord(v.Name.String(), v.CName.String(), int64(v.TTL))
default:
fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
}
}
printKnownRecords()
for _, q := range msg.QD {
fmt.Printf("%04x: <- Request name: %s\n", msg.ID, q.QName.String())
}
for _, a := range msg.AN {
parseResponseRecord(a)
}
for _, a := range msg.NS {
parseResponseRecord(a)
}
for _, a := range msg.AR {
parseResponseRecord(a)
}
printKnownRecords()
log.Fatalf("failed to initialize application: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
appErrsChan := make(chan []error)
go func() {
err := proxy.Listen(ctx)
if err != nil {
log.Fatalf("failed to initialize DNS proxy: %v", err)
}
errs := app.Listen(ctx)
appErrsChan <- errs
}()
err = dnsOverrider.Enable()
if err != nil {
log.Fatalf("failed to override DNS: %v", err)
}
fmt.Printf("Started service on port '%d'\n", ListenPort)
fmt.Println("Started service...")
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
for {
select {
case appErrs, _ := <-appErrsChan:
for _, err := range appErrs {
// TODO: Error log level
log.Printf("failed to start application: %v", err)
}
return
case <-c:
fmt.Println("Graceful shutdown...")
cancel()
err = dnsOverrider.Disable()
if err != nil {
log.Fatalf("failed to rollback override DNS changes: %v", err)
}
return
}
}
}

10
models/domain.go Normal file
View File

@ -0,0 +1,10 @@
package models
type Domain struct {
ID int
Group *Group
Type string
Domain string
Enable bool
Comment string
}

8
models/group.go Normal file
View File

@ -0,0 +1,8 @@
package models
type Group struct {
ID int
Name string
Interface string
Domains []*Domain
}

View File

@ -47,7 +47,6 @@ func (p DNSProxy) Listen(ctx context.Context) error {
for {
select {
case <-ctx.Done():
log.Println("Shutting down DNS proxy...")
return nil
default:
buffer := make([]byte, DNSMaxUDPPackageSize)

View File

@ -27,7 +27,7 @@ func parseName(response []byte, pos int) (*Name, int, error) {
break
}
if length&0xC0 == 0xC0 {
if length&0xC0 != 0 {
if responseLen < pos+1 {
return nil, pos, io.EOF
}

View File

@ -1,4 +1,4 @@
package ruleComposer
package main
import (
"net"
@ -73,7 +73,7 @@ func (r *Records) GetARecords(domainName string, recursive bool) []net.IP {
return aRecords
}
func (r *Records) PutCNameRecord(domainName string, cName string, ttl int64) {
func (r *Records) PutCNameRecord(domainName string, cName string, ttl time.Duration) {
r.mutex.Lock()
defer r.mutex.Unlock()
@ -81,10 +81,10 @@ func (r *Records) PutCNameRecord(domainName string, cName string, ttl int64) {
r.cnameRecords[domainName] = make(map[string]time.Time)
}
r.cnameRecords[domainName][cName] = time.Now().Add(time.Second * time.Duration(ttl))
r.cnameRecords[domainName][cName] = time.Now().Add(ttl)
}
func (r *Records) PutARecord(domainName string, addr net.IP, ttl int64) {
func (r *Records) PutARecord(domainName string, addr net.IP, ttl time.Duration) {
r.mutex.Lock()
defer r.mutex.Unlock()
@ -92,7 +92,7 @@ func (r *Records) PutARecord(domainName string, addr net.IP, ttl int64) {
r.aRecords[domainName] = make(map[string]time.Time)
}
r.aRecords[domainName][string(addr)] = time.Now().Add(time.Second * time.Duration(ttl))
r.aRecords[domainName][string(addr)] = time.Now().Add(ttl)
}
func (r *Records) Cleanup() {