From 8c4d8c1d494144e13239cf5abe85b4f1ed4d34a6 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Sun, 25 Aug 2024 01:43:44 +0300 Subject: [PATCH] refactoring --- dns-proxy/dns-proxy.go | 202 ++++++++++++----------------- iptables-helper/iptables-helper.go | 60 +++++++++ main.go | 105 +++++++++------ 3 files changed, 208 insertions(+), 159 deletions(-) create mode 100644 iptables-helper/iptables-helper.go diff --git a/dns-proxy/dns-proxy.go b/dns-proxy/dns-proxy.go index c408a33..146aa6e 100644 --- a/dns-proxy/dns-proxy.go +++ b/dns-proxy/dns-proxy.go @@ -1,9 +1,9 @@ package dnsProxy import ( + "context" "encoding/hex" "fmt" - "github.com/coreos/go-iptables/iptables" "log" "net" "time" @@ -11,147 +11,111 @@ import ( const ( DNSMaxUDPPackageSize = 4096 - DNSMaxTCPPackageSize = 65536 ) type DNSProxy struct { - listenAddr string - upstreamAddr string + udpConn *net.UDPConn + listenPort uint16 - udpConn *net.UDPConn + targetDNSServerAddress string MsgHandler func(*Message) } -func (p DNSProxy) Close() error { - ipt, err := iptables.New() - if err != nil { - log.Fatalf("iptables init fail: %v", err) - } - - err = ipt.DeleteIfExists("nat", "PREROUTING", "-j", "KVAS2_DNSOVERRIDE") - if err != nil { - log.Fatalf("failed to attaching chain: %v", err) - } - - err = ipt.ClearAndDeleteChain("nat", "KVAS2_DNSOVERRIDE") - if err != nil { - log.Fatalf("failed to delete chain: %v", err) - } - - return nil - //return p.udpConn.Close() -} - -func (p DNSProxy) sendToUpstream(isTCP bool, request []byte) ([]byte, error) { - protocol := "udp" - if isTCP { - protocol = "tcp" - } - - conn, err := net.Dial(protocol, p.upstreamAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial upstream DNS: %w", err) - } - defer conn.Close() - - _, err = conn.Write(request) - if err != nil { - return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err) - } - - err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - if err != nil { - return nil, fmt.Errorf("failed to set timeout: %w", err) - } - - var response []byte - if !isTCP { - response = make([]byte, DNSMaxUDPPackageSize) - } else { - response = make([]byte, DNSMaxTCPPackageSize) - } - - n, err := conn.Read(response) - if err != nil { - return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err) - } - - return response[:n], nil -} - -func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) { - upstreamResponse, err := p.sendToUpstream(false, buffer) - if err != nil { - log.Printf("Failed to get response from upstream DNS: %v", err) - return - } - - log.Printf("Response: %s", hex.EncodeToString(upstreamResponse)) - - msg, err := ParseResponse(upstreamResponse) - if err == nil { - if p.MsgHandler != nil { - p.MsgHandler(msg) - } - } else { - log.Printf("error while parsing response: %v", err) - } - - _, err = p.udpConn.WriteToUDP(upstreamResponse, clientAddr) - if err != nil { - log.Printf("Failed to send DNS response: %v", err) - } -} - -func (p DNSProxy) Listen() error { +func (p DNSProxy) Listen(ctx context.Context) error { var err error - ipt, err := iptables.New() - if err != nil { - log.Fatalf("iptables init fail: %v", err) - } - - err = ipt.ClearChain("nat", "KVAS2_DNSOVERRIDE") - if err != nil { - log.Fatalf("failed to clean chain: %v", err) - } - - err = ipt.AppendUnique("nat", "KVAS2_DNSOVERRIDE", "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", "7548") - if err != nil { - log.Fatalf("failed to create rule: %v", err) - } - - err = ipt.InsertUnique("nat", "PREROUTING", 1, "-j", "KVAS2_DNSOVERRIDE") - if err != nil { - log.Fatalf("failed to attaching chain: %v", err) - } - - udpAddr, err := net.ResolveUDPAddr("udp", p.listenAddr) + udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", p.listenPort)) if err != nil { return fmt.Errorf("failed to resolve UDP address: %v", err) } p.udpConn, err = net.ListenUDP("udp", udpAddr) if err != nil { - return fmt.Errorf("failed to listen on UDP: %v", err) + return fmt.Errorf("failed to listen UDP address: %v", err) } + defer func() { + if p.udpConn != nil { + err := p.udpConn.Close() + if err != nil { + log.Printf("failed to close UDP connection: %v", err) + } + } + }() + for { - buffer := make([]byte, DNSMaxUDPPackageSize) - n, clientAddr, err := p.udpConn.ReadFromUDP(buffer) - if err != nil { - log.Printf("Failed to read from UDP: %v", err) - continue + select { + case <-ctx.Done(): + log.Println("Shutting down DNS proxy...") + return nil + default: + buffer := make([]byte, DNSMaxUDPPackageSize) + n, clientAddr, err := p.udpConn.ReadFromUDP(buffer) + if err != nil { + log.Printf("failed to read UDP packet: %v", err) + continue + } + + go p.handleDNSRequest(clientAddr, buffer[:n]) } - - go p.handleDNSRequest(clientAddr, buffer[:n]) } } -func New(listenAddr string, listenPort uint16, upstreamAddr string, upstreamPort uint16) *DNSProxy { +func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) { + conn, err := net.Dial("udp", p.targetDNSServerAddress) + if err != nil { + log.Printf("failed to dial target DNS: %v", err) + return + } + defer conn.Close() + + _, err = conn.Write(buffer) + if err != nil { + // TODO: Error log level + log.Printf("failed to send request to target DNS: %v", err) + return + } + + err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + // TODO: Error log level + log.Printf("failed to set read deadline: %v", err) + return + } + + response := make([]byte, DNSMaxUDPPackageSize) + n, err := conn.Read(response) + if err != nil { + // TODO: Error log level + log.Printf("failed to read response from target DNS: %v", err) + return + } + + // TODO: Debug log level + log.Printf("Response: %s", hex.EncodeToString(response[:n])) + + msg, err := ParseResponse(response[:n]) + if err == nil { + if p.MsgHandler != nil { + p.MsgHandler(msg) + } + } else { + // TODO: Warn log level + log.Printf("error while parsing DNS message: %v", err) + } + + _, err = p.udpConn.WriteToUDP(response[:n], clientAddr) + if err != nil { + // TODO: Error log level + log.Printf("failed to send DNS message: %v", err) + return + } +} + +func New(listenPort uint16, targetDNSServerAddress string) *DNSProxy { return &DNSProxy{ - listenAddr: fmt.Sprintf("%s:%d", listenAddr, listenPort), - upstreamAddr: fmt.Sprintf("%s:%d", upstreamAddr, upstreamPort), + listenPort: listenPort, + targetDNSServerAddress: targetDNSServerAddress, } } diff --git a/iptables-helper/iptables-helper.go b/iptables-helper/iptables-helper.go new file mode 100644 index 0000000..96f63ef --- /dev/null +++ b/iptables-helper/iptables-helper.go @@ -0,0 +1,60 @@ +package iptablesHelper + +import ( + "fmt" + "strconv" + + "github.com/coreos/go-iptables/iptables" +) + +type DNSOverrider struct { + ipt *iptables.IPTables + chainName string + destPort uint16 +} + +func (o DNSOverrider) Enable() error { + err := o.ipt.ClearChain("nat", o.chainName) + if err != nil { + return fmt.Errorf("failed to clear chain: %w", err) + } + + err = o.ipt.AppendUnique("nat", o.chainName, "-p", "udp", "--dport", "53", "-j", "REDIRECT", "--to-port", strconv.Itoa(int(o.destPort))) + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } + + err = o.ipt.InsertUnique("nat", "PREROUTING", 1, "-j", o.chainName) + if err != nil { + return fmt.Errorf("failed to linking chain: %w", err) + } + + return nil +} + +func (o DNSOverrider) Disable() error { + err := o.ipt.DeleteIfExists("nat", "PREROUTING", "-j", o.chainName) + if err != nil { + return fmt.Errorf("failed to unlinking chain: %w", err) + } + + err = o.ipt.ClearAndDeleteChain("nat", o.chainName) + if err != nil { + return fmt.Errorf("failed to delete chain: %w", err) + } + + return nil +} + +func NewDNSOverrider(chainName string, destPort uint16) (*DNSOverrider, error) { + ipt, err := iptables.New() + if err != nil { + return nil, fmt.Errorf("iptables init fail: %w", err) + } + + return &DNSOverrider{ + ipt: ipt, + chainName: chainName, + destPort: destPort, + }, nil +} diff --git a/main.go b/main.go index 42b562c..2aadf83 100644 --- a/main.go +++ b/main.go @@ -1,73 +1,98 @@ package main import ( + "context" "fmt" - dnsProxy "kvas2-go/dns-proxy" - ruleComposer "kvas2-go/rule-composer" "log" "os" "os/signal" "syscall" + + dnsProxy "kvas2-go/dns-proxy" + iptablesHelper "kvas2-go/iptables-helper" + ruleComposer "kvas2-go/rule-composer" ) var ( + ChainPostfix = "KVAS2" ListenPort = uint16(7548) - UsableDNSServerAddress = "127.0.0.1" - UsableDNSServerPort = uint16(53) + TargetDNSServerAddress = "127.0.0.1:53" ) func main() { records := ruleComposer.NewRecords() - proxy := dnsProxy.New("", ListenPort, UsableDNSServerAddress, UsableDNSServerPort) - proxy.MsgHandler = func(msg *dnsProxy.Message) { - for _, q := range msg.QD { - fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String()) - fmt.Printf("%x: DBG (Before) Known addresses for: %s\n", msg.ID, q.QName.String()) - for idx, addr := range records.GetARecords(q.QName.String(), true) { - fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String()) - } - } - for _, a := range msg.AN { - switch v := a.(type) { - case dnsProxy.Address: - fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) - records.PutARecord(v.Name.String(), v.Address, int64(v.TTL)) - case dnsProxy.CName: - fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) - records.PutCNameRecord(v.Name.String(), v.CName.String(), int64(v.TTL)) - default: - fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) - } - } - for _, a := range msg.NS { - fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource()) - } - for _, a := range msg.AR { - fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource()) - } - - for _, q := range msg.QD { - fmt.Printf("%x: DBG (After) Known addresses for: %s\n", msg.ID, q.QName.String()) - for idx, addr := range records.GetARecords(q.QName.String(), true) { - fmt.Printf("%x: #%d: %s\n", msg.ID, idx, addr.String()) - } - } + proxy := dnsProxy.New(ListenPort, TargetDNSServerAddress) + dnsOverrider, err := iptablesHelper.NewDNSOverrider(fmt.Sprintf("%s_DNSOVERRIDER", ChainPostfix), ListenPort) + if err != nil { + log.Fatalf("failed to initialize DNS overrider: %v", err) } + proxy.MsgHandler = func(msg *dnsProxy.Message) { + printKnownRecords := func() { + for _, q := range msg.QD { + fmt.Printf("%04x: DBG Known addresses for: %s\n", msg.ID, q.QName.String()) + for idx, addr := range records.GetARecords(q.QName.String(), true) { + fmt.Printf("%04x: #%d: %s\n", msg.ID, idx, addr.String()) + } + } + } + parseResponseRecord := func(rr dnsProxy.ResourceRecord) { + switch v := rr.(type) { + case dnsProxy.Address: + fmt.Printf("%04x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) + records.PutARecord(v.Name.String(), v.Address, int64(v.TTL)) + case dnsProxy.CName: + fmt.Printf("%04x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) + records.PutCNameRecord(v.Name.String(), v.CName.String(), int64(v.TTL)) + default: + fmt.Printf("%04x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) + } + } + + printKnownRecords() + for _, q := range msg.QD { + fmt.Printf("%04x: <- Request name: %s\n", msg.ID, q.QName.String()) + } + for _, a := range msg.AN { + parseResponseRecord(a) + } + for _, a := range msg.NS { + parseResponseRecord(a) + } + for _, a := range msg.AR { + parseResponseRecord(a) + } + printKnownRecords() + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { - err := proxy.Listen() + err := proxy.Listen(ctx) if err != nil { log.Fatal(err) } }() + err = dnsOverrider.Enable() + if err != nil { + log.Fatalf("failed to override DNS: %v", err) + } + + fmt.Printf("Started service on port '%d'\n", ListenPort) + c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) for { select { case <-c: - proxy.Close() + fmt.Printf("Graceful shutdown...") + cancel() + err = dnsOverrider.Disable() + if err != nil { + log.Fatalf("failed to rollback override DNS changes: %v", err) + } return } }