From f8c31c9f015770a3e88b9b82fcbf3fd40759f5c1 Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Sat, 24 Aug 2024 01:16:10 +0300 Subject: [PATCH] first dirty implementation of dns proxy --- go.mod | 3 + main.go | 176 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+) create mode 100644 go.mod create mode 100644 main.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..514b814 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module kvas2-go + +go 1.21 \ No newline at end of file diff --git a/main.go b/main.go new file mode 100644 index 0000000..5381213 --- /dev/null +++ b/main.go @@ -0,0 +1,176 @@ +package main + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "log" + "net" + "strings" + "time" +) + +var ( + ListenPort = 7548 + UsableDNSServerAddress = "127.0.0.1" + UsableDNSServerPort = 53 + DNSMaxPackageSize = 4096 +) + +func parseName(response []byte, pos int) (string, int) { + var nameParts []string + var jumped bool + var outPos int + responseLen := len(response) + + for { + length := int(response[pos]) + pos++ + if length == 0 { + break + } + + if length&0xC0 == 0xC0 { + if !jumped { + outPos = pos + 1 + } + pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF) + jumped = true + continue + } + + if pos+length > responseLen { + break + } + + nameParts = append(nameParts, string(response[pos:pos+length])) + pos += length + } + + if !jumped { + outPos = pos + } + return strings.Join(nameParts, "."), outPos +} + +func sendToUpstream(upstreamAddr string, request []byte) ([]byte, error) { + conn, err := net.Dial("udp", upstreamAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial upstream DNS: %w", err) + } + defer conn.Close() + + _, err = conn.Write(request) + if err != nil { + return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err) + } + + err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to set timeout: %w", err) + } + + response := make([]byte, DNSMaxPackageSize) + n, err := conn.Read(response) + if err != nil { + return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err) + } + + return response[:n], nil +} + +func main() { + addr := fmt.Sprintf(":%d", ListenPort) + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + log.Fatalf("Failed to resolve address: %v", err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + log.Fatalf("Failed to listen on UDP: %v", err) + } + defer conn.Close() + + fmt.Printf("DNS server is running on %s...\n", addr) + + for { + buffer := make([]byte, DNSMaxPackageSize) + n, clientAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + log.Printf("Failed to read from UDP: %v", err) + continue + } + + go handleDNSRequest(conn, clientAddr, buffer[:n]) + } +} + +func process(response []byte) { + responseLen := len(response) + if responseLen <= 12 { + return + } + + qCount := int(binary.LittleEndian.Uint16(response[5:7])) + aCount := int(binary.LittleEndian.Uint16(response[7:9])) + + pos := 12 + + for i := 0; i < qCount; i++ { + var name string + name, pos = parseName(response, pos) + fmt.Printf("Requested name: %s\n", name) + pos += 4 + } + + for i := 0; i < aCount; i++ { + name, newPos := parseName(response, pos) + pos = newPos + + if pos+10 > responseLen { + break + } + + qtype := binary.BigEndian.Uint16(response[pos : pos+2]) + pos += 2 + + qclass := binary.BigEndian.Uint16(response[pos : pos+2]) + pos += 2 + + ttl := binary.BigEndian.Uint32(response[pos : pos+4]) + pos += 4 + + rdlength := binary.BigEndian.Uint16(response[pos : pos+2]) + pos += 2 + + if pos+int(rdlength) > responseLen { + break + } + + if qtype == 1 && qclass == 1 && rdlength == 4 { + ip := net.IPv4(response[pos], response[pos+1], response[pos+2], response[pos+3]) + fmt.Printf("Parsed A record: %s -> %s, TTL: %d\n", name, ip, ttl) + } + + pos += int(rdlength) + } +} + +func handleDNSRequest(conn *net.UDPConn, clientAddr *net.UDPAddr, buffer []byte) { + upstreamAddr := fmt.Sprintf("%s:%d", UsableDNSServerAddress, UsableDNSServerPort) + + upstreamResponse, err := sendToUpstream(upstreamAddr, buffer) + if err != nil { + log.Printf("Failed to get response from upstream DNS: %v", err) + return + } + log.Printf("Response: %s", hex.EncodeToString(upstreamResponse)) + + process(upstreamResponse) + + _, err = conn.WriteToUDP(upstreamResponse, clientAddr) + if err != nil { + log.Printf("Failed to send DNS response: %v", err) + } +}