dns-proxy refactoring

This commit is contained in:
Vladimir Avtsenov 2024-08-24 17:46:34 +03:00
parent 7ea3ec1b70
commit 27e8086c50
5 changed files with 747 additions and 161 deletions

120
dns-proxy/dns-proxy.go Normal file
View File

@ -0,0 +1,120 @@
package dnsProxy
import (
"encoding/hex"
"fmt"
"log"
"net"
"time"
)
const (
DNSMaxUDPPackageSize = 4096
DNSMaxTCPPackageSize = 65536
)
type DNSProxy struct {
listenAddr string
upstreamAddr string
udpConn *net.UDPConn
MsgHandler func(*Message)
}
func (p DNSProxy) Close() error {
return p.udpConn.Close()
}
func (p DNSProxy) sendToUpstream(isTCP bool, request []byte) ([]byte, error) {
protocol := "udp"
if isTCP {
protocol = "tcp"
}
conn, err := net.Dial(protocol, p.upstreamAddr)
if err != nil {
return nil, fmt.Errorf("failed to dial upstream DNS: %w", err)
}
defer conn.Close()
_, err = conn.Write(request)
if err != nil {
return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err)
}
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if err != nil {
return nil, fmt.Errorf("failed to set timeout: %w", err)
}
var response []byte
if !isTCP {
response = make([]byte, DNSMaxUDPPackageSize)
} else {
response = make([]byte, DNSMaxTCPPackageSize)
}
n, err := conn.Read(response)
if err != nil {
return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err)
}
return response[:n], nil
}
func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) {
upstreamResponse, err := p.sendToUpstream(false, buffer)
if err != nil {
log.Printf("Failed to get response from upstream DNS: %v", err)
return
}
log.Printf("Response: %s", hex.EncodeToString(upstreamResponse))
msg, err := ParseResponse(upstreamResponse)
if err == nil {
if p.MsgHandler != nil {
p.MsgHandler(msg)
}
} else {
log.Printf("error while parsing response: %v", err)
}
_, err = p.udpConn.WriteToUDP(upstreamResponse, clientAddr)
if err != nil {
log.Printf("Failed to send DNS response: %v", err)
}
}
func (p DNSProxy) Listen() error {
var err error
udpAddr, err := net.ResolveUDPAddr("udp", p.listenAddr)
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %v", err)
}
p.udpConn, err = net.ListenUDP("udp", udpAddr)
if err != nil {
return fmt.Errorf("failed to listen on UDP: %v", err)
}
for {
buffer := make([]byte, DNSMaxUDPPackageSize)
n, clientAddr, err := p.udpConn.ReadFromUDP(buffer)
if err != nil {
log.Printf("Failed to read from UDP: %v", err)
continue
}
go p.handleDNSRequest(clientAddr, buffer[:n])
}
}
func New(listenAddr string, listenPort uint16, upstreamAddr string, upstreamPort uint16) *DNSProxy {
return &DNSProxy{
listenAddr: fmt.Sprintf("%s:%d", listenAddr, listenPort),
upstreamAddr: fmt.Sprintf("%s:%d", upstreamAddr, upstreamPort),
}
}

178
dns-proxy/parser.go Normal file
View File

@ -0,0 +1,178 @@
package dnsProxy
import (
"encoding/binary"
"errors"
"fmt"
)
var (
ErrInvalidDNSMessageHeader = errors.New("invalid DNS message header")
ErrInvalidDNSResourceRecordHeader = errors.New("invalid DNS resource record header")
ErrInvalidDNSResourceRecordData = errors.New("invalid DNS resource record data")
ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data")
)
func parseName(response []byte, pos int) (Name, int) {
var nameParts []string
var jumped bool
var outPos int
responseLen := len(response)
for {
length := int(response[pos])
pos++
if length == 0 {
break
}
if length&0xC0 == 0xC0 {
if !jumped {
outPos = pos + 1
}
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
jumped = true
continue
}
if pos+length > responseLen {
break
}
nameParts = append(nameParts, string(response[pos:pos+length]))
pos += length
}
if !jumped {
outPos = pos
}
return Name{Parts: nameParts}, outPos
}
func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) {
responseLen := len(response)
var rhname Name
rhname, pos = parseName(response, pos)
if responseLen < pos+10 {
return nil, pos, ErrInvalidDNSResourceRecordHeader
}
rh := ResourceRecordHeader{
Name: rhname,
Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]),
}
rdlength := int(binary.BigEndian.Uint16(response[pos+8 : pos+10]))
pos += 10
if pos+rdlength > responseLen {
return nil, pos, ErrInvalidDNSResourceRecordData
}
switch rh.Type {
case 1:
if rdlength == 4 {
return Address{
ResourceRecordHeader: rh,
Address: response[pos+0 : pos+4],
}, pos + 4, nil
} else {
return nil, pos, ErrInvalidDNSAddressResourceData
}
case 2:
var ns Name
ns, pos = parseName(response, pos)
return NameServer{
ResourceRecordHeader: rh,
NSDName: ns,
}, pos, nil
case 5:
var cname Name
cname, pos = parseName(response, pos)
return CName{
ResourceRecordHeader: rh,
CName: cname,
}, pos, nil
}
return Unknown{
ResourceRecordHeader: rh,
Data: response[pos+0 : pos+rdlength],
}, pos + rdlength, nil
}
func ParseResponse(response []byte) (*Message, error) {
var err error
responseLen := len(response)
if responseLen < 12 {
return nil, ErrInvalidDNSMessageHeader
}
msg := new(Message)
msg.ID = binary.BigEndian.Uint16(response[0:2])
flagsRAW := binary.BigEndian.Uint16(response[2:4])
msg.Flags = Flags{
QR: uint8(flagsRAW >> 15 & 0x1),
Opcode: uint8(flagsRAW >> 11 & 0xF),
AA: uint8(flagsRAW >> 10 & 0x1),
TC: uint8(flagsRAW >> 9 & 0x1),
RD: uint8(flagsRAW >> 8 & 0x1),
RA: uint8(flagsRAW >> 7 & 0x1),
Z1: uint8(flagsRAW >> 6 & 0x1),
Z2: uint8(flagsRAW >> 5 & 0x1),
Z3: uint8(flagsRAW >> 4 & 0x1),
RCode: uint8(flagsRAW >> 0 & 0xF),
}
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
anCount := int(binary.BigEndian.Uint16(response[6:8]))
nsCount := int(binary.BigEndian.Uint16(response[8:10]))
arCount := int(binary.BigEndian.Uint16(response[10:12]))
pos := 12
msg.QD = make([]Question, qdCount)
for i := 0; i < qdCount; i++ {
var name Name
name, pos = parseName(response, pos)
msg.QD[i] = Question{
QName: name,
QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
}
pos += 4
}
msg.AN = make([]ResourceRecord, anCount)
for i := 0; i < anCount; i++ {
msg.AN[i], pos, err = parseResourceRecord(response, pos)
if err != nil {
return nil, fmt.Errorf("error while parsing AN record: %w", err)
}
}
msg.NS = make([]ResourceRecord, nsCount)
for i := 0; i < nsCount; i++ {
msg.NS[i], pos, err = parseResourceRecord(response, pos)
if err != nil {
return nil, fmt.Errorf("error while parsing NS record: %w", err)
}
}
msg.AR = make([]ResourceRecord, arCount)
for i := 0; i < arCount; i++ {
msg.AR[i], pos, err = parseResourceRecord(response, pos)
if err != nil {
return nil, fmt.Errorf("error while parsing AR record: %w", err)
}
}
return msg, nil
}

196
dns-proxy/types.go Normal file
View File

@ -0,0 +1,196 @@
package dnsProxy
import (
"bytes"
"encoding/binary"
"net"
"strings"
)
type ResourceRecord interface {
EncodeResource() []byte
}
type ResourceRecordHeader struct {
Name Name
Type uint16
Class uint16
TTL uint32
}
func (q ResourceRecordHeader) EncodeHeader() []byte {
buf := bytes.NewBuffer([]byte{})
buf.Write(q.Name.Encode())
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type))
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class))
buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL))
return buf.Bytes()
}
type Name struct {
Parts []string
}
func (n Name) String() string {
return strings.Join(n.Parts, ".")
}
func (n Name) Encode() []byte {
buf := bytes.NewBuffer([]byte{})
for _, part := range n.Parts {
partLen := byte(len(part)) & 0x3F
buf.WriteByte(partLen)
buf.Write([]byte(part)[0:partLen])
}
buf.WriteByte(0)
return buf.Bytes()
}
type Flags struct {
QR uint8
Opcode uint8
AA uint8
TC uint8
RD uint8
RA uint8
Z1 uint8
Z2 uint8
Z3 uint8
RCode uint8
}
func (f Flags) Encode() []byte {
return []byte{
f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0,
f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0,
}
}
type Question struct {
QName Name
QType uint16
QClass uint16
}
func (q Question) EncodeQuestion() []byte {
buf := bytes.NewBuffer([]byte{})
buf.Write(q.QName.Encode())
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType))
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass))
return buf.Bytes()
}
type Address struct {
ResourceRecordHeader
Address net.IP
}
func (a Address) EncodeResource() []byte {
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write([]byte{0x00, 0x04})
rr.Write(a.Address[:])
return rr.Bytes()
}
type NameServer struct {
ResourceRecordHeader
NSDName Name
}
func (a NameServer) EncodeResource() []byte {
rdataBytes := a.NSDName.Encode()
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
rr.Write(rdataBytes)
return rr.Bytes()
}
type CName struct {
ResourceRecordHeader
CName Name
}
func (a CName) EncodeResource() []byte {
rdataBytes := a.CName.Encode()
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
rr.Write(rdataBytes)
return rr.Bytes()
}
type Authority struct {
ResourceRecordHeader
MName Name
RName Name
Serial uint32
Refresh uint32
Retry uint32
Expire uint32
Minimum uint32
}
func (a Authority) EncodeResource() []byte {
rdata := bytes.NewBuffer([]byte{})
rdata.Write(a.MName.Encode())
rdata.Write(a.RName.Encode())
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum))
rdataBytes := rdata.Bytes()
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
rr.Write(rdataBytes)
return rr.Bytes()
}
type Unknown struct {
ResourceRecordHeader
Data []byte
}
func (u Unknown) EncodeResource() []byte {
rr := bytes.NewBuffer([]byte{})
rr.Write(u.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data))))
rr.Write(u.Data)
return rr.Bytes()
}
type Message struct {
ID uint16
Flags Flags
QD []Question
AN []ResourceRecord
NS []ResourceRecord
AR []ResourceRecord
}
func (m Message) Encode() []byte {
rr := bytes.NewBuffer([]byte{})
rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID))
rr.Write(m.Flags.Encode())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD))))
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN))))
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS))))
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR))))
for _, q := range m.QD {
rr.Write(q.EncodeQuestion())
}
for _, a := range m.AN {
rr.Write(a.EncodeResource())
}
for _, ns := range m.NS {
rr.Write(ns.EncodeResource())
}
for _, a := range m.AR {
rr.Write(a.EncodeResource())
}
return rr.Bytes()
}

225
dns-proxy/types_test.go Normal file
View File

@ -0,0 +1,225 @@
package dnsProxy
import (
"bytes"
"testing"
)
func TestDNSResourceRecordHeaderEncode(t *testing.T) {
recordHeader := ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
}
recordHeaderEncoded := recordHeader.EncodeHeader()
recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0}
if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 {
t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood)
}
}
func TestDNSNameString(t *testing.T) {
dnsName := Name{Parts: []string{"example", "com"}}
dnsNameString := dnsName.String()
dnsNameStringGood := "example.com"
if dnsNameString != dnsNameStringGood {
t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood)
}
}
func TestDNSNameEncode(t *testing.T) {
dnsName := Name{Parts: []string{"example", "com"}}
dnsNameEncoded := dnsName.Encode()
dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 {
t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood)
}
}
func TestDNSFlagsEncode(t *testing.T) {
dnsFlags := Flags{
QR: 0x1,
Opcode: 0xF,
AA: 0x0,
TC: 0x0,
RD: 0x1,
RA: 0x1,
Z1: 0x0,
Z2: 0x0,
Z3: 0x0,
RCode: 0xF,
}
dnsFlagsEncoded := dnsFlags.Encode()
dnsFlagsEncodedGood := []byte{0xf9, 0x8f}
if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 {
t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood)
}
}
func TestDNSQuestionEncode(t *testing.T) {
dnsQuestion := Question{
QName: Name{Parts: []string{"example", "com"}},
QType: 0x001c,
QClass: 0x0001,
}
dnsQuestionEncoded := dnsQuestion.EncodeQuestion()
dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01}
if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 {
t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood)
}
}
func TestDNSAddressEncode(t *testing.T) {
dnsAddress := Address{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Address: []byte{192, 168, 1, 1},
}
dnsAddressEncoded := dnsAddress.EncodeResource()
dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1}
if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 {
t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood)
}
}
func TestDNSNameServerEncode(t *testing.T) {
dnsNameServer := NameServer{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
NSDName: Name{Parts: []string{"example", "com"}},
}
dnsNameServerEncoded := dnsNameServer.EncodeResource()
dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 {
t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood)
}
}
func TestDNSCNameEncode(t *testing.T) {
dnsCName := CName{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
CName: Name{Parts: []string{"example", "com"}},
}
dnsCNameEncoded := dnsCName.EncodeResource()
dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 {
t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood)
}
}
func TestDNSAuthorityEncode(t *testing.T) {
dnsAuthority := Authority{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
MName: Name{Parts: []string{"example", "com"}},
RName: Name{Parts: []string{"example", "com"}},
Serial: 0x12345678,
Refresh: 0x12345678,
Retry: 0x12345678,
Expire: 0x12345678,
Minimum: 0x12345678,
}
dnsAuthorityEncoded := dnsAuthority.EncodeResource()
dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}
if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 {
t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood)
}
}
func TestDNSUnknownEncode(t *testing.T) {
dnsUnknown := Unknown{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Data: []byte{0x01, 0x02, 0x03},
}
dnsUnknownEncoded := dnsUnknown.EncodeResource()
dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 {
t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood)
}
}
func TestDNSMessageEncode(t *testing.T) {
dnsMessage := Message{
ID: 0x00FF,
Flags: Flags{
QR: 0x1,
Opcode: 0xF,
AA: 0x0,
TC: 0x0,
RD: 0x1,
RA: 0x1,
Z1: 0x0,
Z2: 0x0,
Z3: 0x0,
RCode: 0xF,
},
QD: []Question{
{
QName: Name{Parts: []string{"example", "com"}},
QType: 0x001c,
QClass: 0x0001,
},
},
AN: []ResourceRecord{
Unknown{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Data: []byte{0x01, 0x02, 0x03},
},
},
NS: []ResourceRecord{
Unknown{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Data: []byte{0x01, 0x02, 0x03},
},
},
AR: []ResourceRecord{
Unknown{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Data: []byte{0x01, 0x02, 0x03},
},
},
}
dnsMessageEncoded := dnsMessage.Encode()
dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 {
t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood)
}
}

189
main.go
View File

@ -1,176 +1,43 @@
package main package main
import ( import (
"encoding/binary"
"encoding/hex"
"fmt" "fmt"
dnsProxy "kvas2-go/dns-proxy"
"log" "log"
"net"
"strings"
"time"
) )
var ( var (
ListenPort = 7548 ListenPort = uint16(7548)
UsableDNSServerAddress = "127.0.0.1" UsableDNSServerAddress = "127.0.0.1"
UsableDNSServerPort = 53 UsableDNSServerPort = uint16(53)
DNSMaxPackageSize = 4096
) )
func parseName(response []byte, pos int) (string, int) {
var nameParts []string
var jumped bool
var outPos int
responseLen := len(response)
for {
length := int(response[pos])
pos++
if length == 0 {
break
}
if length&0xC0 == 0xC0 {
if !jumped {
outPos = pos + 1
}
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
jumped = true
continue
}
if pos+length > responseLen {
break
}
nameParts = append(nameParts, string(response[pos:pos+length]))
pos += length
}
if !jumped {
outPos = pos
}
return strings.Join(nameParts, "."), outPos
}
func sendToUpstream(upstreamAddr string, request []byte) ([]byte, error) {
conn, err := net.Dial("udp", upstreamAddr)
if err != nil {
return nil, fmt.Errorf("failed to dial upstream DNS: %w", err)
}
defer conn.Close()
_, err = conn.Write(request)
if err != nil {
return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err)
}
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if err != nil {
return nil, fmt.Errorf("failed to set timeout: %w", err)
}
response := make([]byte, DNSMaxPackageSize)
n, err := conn.Read(response)
if err != nil {
return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err)
}
return response[:n], nil
}
func main() { func main() {
addr := fmt.Sprintf(":%d", ListenPort) proxy := dnsProxy.New("", ListenPort, UsableDNSServerAddress, UsableDNSServerPort)
udpAddr, err := net.ResolveUDPAddr("udp", addr) proxy.MsgHandler = func(msg *dnsProxy.Message) {
if err != nil { for _, q := range msg.QD {
log.Fatalf("Failed to resolve address: %v", err) fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String())
} }
for _, a := range msg.AN {
conn, err := net.ListenUDP("udp", udpAddr) switch v := a.(type) {
if err != nil { case dnsProxy.Address:
log.Fatalf("Failed to listen on UDP: %v", err) fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
} case dnsProxy.CName:
defer conn.Close() fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
default:
fmt.Printf("DNS server is running on %s...\n", addr) fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
}
for { }
buffer := make([]byte, DNSMaxPackageSize) for _, a := range msg.NS {
n, clientAddr, err := conn.ReadFromUDP(buffer) fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource())
if err != nil { }
log.Printf("Failed to read from UDP: %v", err) for _, a := range msg.AR {
continue fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource())
} }
go handleDNSRequest(conn, clientAddr, buffer[:n])
}
}
func process(response []byte) {
responseLen := len(response)
if responseLen <= 12 {
return
}
qCount := int(binary.LittleEndian.Uint16(response[5:7]))
aCount := int(binary.LittleEndian.Uint16(response[7:9]))
pos := 12
for i := 0; i < qCount; i++ {
var name string
name, pos = parseName(response, pos)
fmt.Printf("Requested name: %s\n", name)
pos += 4
}
for i := 0; i < aCount; i++ {
name, newPos := parseName(response, pos)
pos = newPos
if pos+10 > responseLen {
break
}
qtype := binary.BigEndian.Uint16(response[pos : pos+2])
pos += 2
qclass := binary.BigEndian.Uint16(response[pos : pos+2])
pos += 2
ttl := binary.BigEndian.Uint32(response[pos : pos+4])
pos += 4
rdlength := binary.BigEndian.Uint16(response[pos : pos+2])
pos += 2
if pos+int(rdlength) > responseLen {
break
}
if qtype == 1 && qclass == 1 && rdlength == 4 {
ip := net.IPv4(response[pos], response[pos+1], response[pos+2], response[pos+3])
fmt.Printf("Parsed A record: %s -> %s, TTL: %d\n", name, ip, ttl)
}
pos += int(rdlength)
}
}
func handleDNSRequest(conn *net.UDPConn, clientAddr *net.UDPAddr, buffer []byte) {
upstreamAddr := fmt.Sprintf("%s:%d", UsableDNSServerAddress, UsableDNSServerPort)
upstreamResponse, err := sendToUpstream(upstreamAddr, buffer)
if err != nil {
log.Printf("Failed to get response from upstream DNS: %v", err)
return
}
log.Printf("Response: %s", hex.EncodeToString(upstreamResponse))
process(upstreamResponse)
_, err = conn.WriteToUDP(upstreamResponse, clientAddr)
if err != nil {
log.Printf("Failed to send DNS response: %v", err)
} }
err := proxy.Listen()
if err != nil {
log.Fatal(err)
}
defer proxy.Close()
} }