dns-proxy refactoring
This commit is contained in:
parent
7ea3ec1b70
commit
27e8086c50
120
dns-proxy/dns-proxy.go
Normal file
120
dns-proxy/dns-proxy.go
Normal file
@ -0,0 +1,120 @@
|
||||
package dnsProxy
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
DNSMaxUDPPackageSize = 4096
|
||||
DNSMaxTCPPackageSize = 65536
|
||||
)
|
||||
|
||||
type DNSProxy struct {
|
||||
listenAddr string
|
||||
upstreamAddr string
|
||||
|
||||
udpConn *net.UDPConn
|
||||
|
||||
MsgHandler func(*Message)
|
||||
}
|
||||
|
||||
func (p DNSProxy) Close() error {
|
||||
return p.udpConn.Close()
|
||||
}
|
||||
|
||||
func (p DNSProxy) sendToUpstream(isTCP bool, request []byte) ([]byte, error) {
|
||||
protocol := "udp"
|
||||
if isTCP {
|
||||
protocol = "tcp"
|
||||
}
|
||||
|
||||
conn, err := net.Dial(protocol, p.upstreamAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial upstream DNS: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err)
|
||||
}
|
||||
|
||||
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set timeout: %w", err)
|
||||
}
|
||||
|
||||
var response []byte
|
||||
if !isTCP {
|
||||
response = make([]byte, DNSMaxUDPPackageSize)
|
||||
} else {
|
||||
response = make([]byte, DNSMaxTCPPackageSize)
|
||||
}
|
||||
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err)
|
||||
}
|
||||
|
||||
return response[:n], nil
|
||||
}
|
||||
|
||||
func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) {
|
||||
upstreamResponse, err := p.sendToUpstream(false, buffer)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get response from upstream DNS: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Response: %s", hex.EncodeToString(upstreamResponse))
|
||||
|
||||
msg, err := ParseResponse(upstreamResponse)
|
||||
if err == nil {
|
||||
if p.MsgHandler != nil {
|
||||
p.MsgHandler(msg)
|
||||
}
|
||||
} else {
|
||||
log.Printf("error while parsing response: %v", err)
|
||||
}
|
||||
|
||||
_, err = p.udpConn.WriteToUDP(upstreamResponse, clientAddr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p DNSProxy) Listen() error {
|
||||
var err error
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", p.listenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve UDP address: %v", err)
|
||||
}
|
||||
|
||||
p.udpConn, err = net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on UDP: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
buffer := make([]byte, DNSMaxUDPPackageSize)
|
||||
n, clientAddr, err := p.udpConn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read from UDP: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go p.handleDNSRequest(clientAddr, buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func New(listenAddr string, listenPort uint16, upstreamAddr string, upstreamPort uint16) *DNSProxy {
|
||||
return &DNSProxy{
|
||||
listenAddr: fmt.Sprintf("%s:%d", listenAddr, listenPort),
|
||||
upstreamAddr: fmt.Sprintf("%s:%d", upstreamAddr, upstreamPort),
|
||||
}
|
||||
}
|
178
dns-proxy/parser.go
Normal file
178
dns-proxy/parser.go
Normal file
@ -0,0 +1,178 @@
|
||||
package dnsProxy
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidDNSMessageHeader = errors.New("invalid DNS message header")
|
||||
ErrInvalidDNSResourceRecordHeader = errors.New("invalid DNS resource record header")
|
||||
ErrInvalidDNSResourceRecordData = errors.New("invalid DNS resource record data")
|
||||
ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data")
|
||||
)
|
||||
|
||||
func parseName(response []byte, pos int) (Name, int) {
|
||||
var nameParts []string
|
||||
var jumped bool
|
||||
var outPos int
|
||||
responseLen := len(response)
|
||||
|
||||
for {
|
||||
length := int(response[pos])
|
||||
pos++
|
||||
if length == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if length&0xC0 == 0xC0 {
|
||||
if !jumped {
|
||||
outPos = pos + 1
|
||||
}
|
||||
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
|
||||
jumped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if pos+length > responseLen {
|
||||
break
|
||||
}
|
||||
|
||||
nameParts = append(nameParts, string(response[pos:pos+length]))
|
||||
pos += length
|
||||
}
|
||||
|
||||
if !jumped {
|
||||
outPos = pos
|
||||
}
|
||||
return Name{Parts: nameParts}, outPos
|
||||
}
|
||||
|
||||
func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) {
|
||||
responseLen := len(response)
|
||||
|
||||
var rhname Name
|
||||
rhname, pos = parseName(response, pos)
|
||||
|
||||
if responseLen < pos+10 {
|
||||
return nil, pos, ErrInvalidDNSResourceRecordHeader
|
||||
}
|
||||
|
||||
rh := ResourceRecordHeader{
|
||||
Name: rhname,
|
||||
Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
|
||||
Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
|
||||
TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]),
|
||||
}
|
||||
rdlength := int(binary.BigEndian.Uint16(response[pos+8 : pos+10]))
|
||||
|
||||
pos += 10
|
||||
|
||||
if pos+rdlength > responseLen {
|
||||
return nil, pos, ErrInvalidDNSResourceRecordData
|
||||
}
|
||||
|
||||
switch rh.Type {
|
||||
case 1:
|
||||
if rdlength == 4 {
|
||||
return Address{
|
||||
ResourceRecordHeader: rh,
|
||||
Address: response[pos+0 : pos+4],
|
||||
}, pos + 4, nil
|
||||
} else {
|
||||
return nil, pos, ErrInvalidDNSAddressResourceData
|
||||
}
|
||||
case 2:
|
||||
var ns Name
|
||||
ns, pos = parseName(response, pos)
|
||||
return NameServer{
|
||||
ResourceRecordHeader: rh,
|
||||
NSDName: ns,
|
||||
}, pos, nil
|
||||
case 5:
|
||||
var cname Name
|
||||
cname, pos = parseName(response, pos)
|
||||
return CName{
|
||||
ResourceRecordHeader: rh,
|
||||
CName: cname,
|
||||
}, pos, nil
|
||||
}
|
||||
|
||||
return Unknown{
|
||||
ResourceRecordHeader: rh,
|
||||
Data: response[pos+0 : pos+rdlength],
|
||||
}, pos + rdlength, nil
|
||||
}
|
||||
|
||||
func ParseResponse(response []byte) (*Message, error) {
|
||||
var err error
|
||||
|
||||
responseLen := len(response)
|
||||
if responseLen < 12 {
|
||||
return nil, ErrInvalidDNSMessageHeader
|
||||
}
|
||||
|
||||
msg := new(Message)
|
||||
|
||||
msg.ID = binary.BigEndian.Uint16(response[0:2])
|
||||
|
||||
flagsRAW := binary.BigEndian.Uint16(response[2:4])
|
||||
msg.Flags = Flags{
|
||||
QR: uint8(flagsRAW >> 15 & 0x1),
|
||||
Opcode: uint8(flagsRAW >> 11 & 0xF),
|
||||
AA: uint8(flagsRAW >> 10 & 0x1),
|
||||
TC: uint8(flagsRAW >> 9 & 0x1),
|
||||
RD: uint8(flagsRAW >> 8 & 0x1),
|
||||
RA: uint8(flagsRAW >> 7 & 0x1),
|
||||
Z1: uint8(flagsRAW >> 6 & 0x1),
|
||||
Z2: uint8(flagsRAW >> 5 & 0x1),
|
||||
Z3: uint8(flagsRAW >> 4 & 0x1),
|
||||
RCode: uint8(flagsRAW >> 0 & 0xF),
|
||||
}
|
||||
|
||||
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
|
||||
anCount := int(binary.BigEndian.Uint16(response[6:8]))
|
||||
nsCount := int(binary.BigEndian.Uint16(response[8:10]))
|
||||
arCount := int(binary.BigEndian.Uint16(response[10:12]))
|
||||
|
||||
pos := 12
|
||||
|
||||
msg.QD = make([]Question, qdCount)
|
||||
for i := 0; i < qdCount; i++ {
|
||||
var name Name
|
||||
name, pos = parseName(response, pos)
|
||||
msg.QD[i] = Question{
|
||||
QName: name,
|
||||
QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
|
||||
QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
|
||||
}
|
||||
pos += 4
|
||||
}
|
||||
|
||||
msg.AN = make([]ResourceRecord, anCount)
|
||||
for i := 0; i < anCount; i++ {
|
||||
msg.AN[i], pos, err = parseResourceRecord(response, pos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while parsing AN record: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
msg.NS = make([]ResourceRecord, nsCount)
|
||||
for i := 0; i < nsCount; i++ {
|
||||
msg.NS[i], pos, err = parseResourceRecord(response, pos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while parsing NS record: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
msg.AR = make([]ResourceRecord, arCount)
|
||||
for i := 0; i < arCount; i++ {
|
||||
msg.AR[i], pos, err = parseResourceRecord(response, pos)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error while parsing AR record: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
196
dns-proxy/types.go
Normal file
196
dns-proxy/types.go
Normal file
@ -0,0 +1,196 @@
|
||||
package dnsProxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ResourceRecord interface {
|
||||
EncodeResource() []byte
|
||||
}
|
||||
|
||||
type ResourceRecordHeader struct {
|
||||
Name Name
|
||||
Type uint16
|
||||
Class uint16
|
||||
TTL uint32
|
||||
}
|
||||
|
||||
func (q ResourceRecordHeader) EncodeHeader() []byte {
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
buf.Write(q.Name.Encode())
|
||||
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type))
|
||||
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class))
|
||||
buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
type Name struct {
|
||||
Parts []string
|
||||
}
|
||||
|
||||
func (n Name) String() string {
|
||||
return strings.Join(n.Parts, ".")
|
||||
}
|
||||
|
||||
func (n Name) Encode() []byte {
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
for _, part := range n.Parts {
|
||||
partLen := byte(len(part)) & 0x3F
|
||||
buf.WriteByte(partLen)
|
||||
buf.Write([]byte(part)[0:partLen])
|
||||
}
|
||||
buf.WriteByte(0)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
type Flags struct {
|
||||
QR uint8
|
||||
Opcode uint8
|
||||
AA uint8
|
||||
TC uint8
|
||||
RD uint8
|
||||
RA uint8
|
||||
Z1 uint8
|
||||
Z2 uint8
|
||||
Z3 uint8
|
||||
RCode uint8
|
||||
}
|
||||
|
||||
func (f Flags) Encode() []byte {
|
||||
return []byte{
|
||||
f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0,
|
||||
f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0,
|
||||
}
|
||||
}
|
||||
|
||||
type Question struct {
|
||||
QName Name
|
||||
QType uint16
|
||||
QClass uint16
|
||||
}
|
||||
|
||||
func (q Question) EncodeQuestion() []byte {
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
buf.Write(q.QName.Encode())
|
||||
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType))
|
||||
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
type Address struct {
|
||||
ResourceRecordHeader
|
||||
Address net.IP
|
||||
}
|
||||
|
||||
func (a Address) EncodeResource() []byte {
|
||||
rr := bytes.NewBuffer([]byte{})
|
||||
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||
rr.Write([]byte{0x00, 0x04})
|
||||
rr.Write(a.Address[:])
|
||||
return rr.Bytes()
|
||||
}
|
||||
|
||||
type NameServer struct {
|
||||
ResourceRecordHeader
|
||||
NSDName Name
|
||||
}
|
||||
|
||||
func (a NameServer) EncodeResource() []byte {
|
||||
rdataBytes := a.NSDName.Encode()
|
||||
rr := bytes.NewBuffer([]byte{})
|
||||
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||
rr.Write(rdataBytes)
|
||||
return rr.Bytes()
|
||||
}
|
||||
|
||||
type CName struct {
|
||||
ResourceRecordHeader
|
||||
CName Name
|
||||
}
|
||||
|
||||
func (a CName) EncodeResource() []byte {
|
||||
rdataBytes := a.CName.Encode()
|
||||
rr := bytes.NewBuffer([]byte{})
|
||||
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||
rr.Write(rdataBytes)
|
||||
return rr.Bytes()
|
||||
}
|
||||
|
||||
type Authority struct {
|
||||
ResourceRecordHeader
|
||||
MName Name
|
||||
RName Name
|
||||
Serial uint32
|
||||
Refresh uint32
|
||||
Retry uint32
|
||||
Expire uint32
|
||||
Minimum uint32
|
||||
}
|
||||
|
||||
func (a Authority) EncodeResource() []byte {
|
||||
rdata := bytes.NewBuffer([]byte{})
|
||||
rdata.Write(a.MName.Encode())
|
||||
rdata.Write(a.RName.Encode())
|
||||
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial))
|
||||
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh))
|
||||
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry))
|
||||
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire))
|
||||
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum))
|
||||
rdataBytes := rdata.Bytes()
|
||||
|
||||
rr := bytes.NewBuffer([]byte{})
|
||||
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||
rr.Write(rdataBytes)
|
||||
return rr.Bytes()
|
||||
}
|
||||
|
||||
type Unknown struct {
|
||||
ResourceRecordHeader
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (u Unknown) EncodeResource() []byte {
|
||||
rr := bytes.NewBuffer([]byte{})
|
||||
rr.Write(u.ResourceRecordHeader.EncodeHeader())
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data))))
|
||||
rr.Write(u.Data)
|
||||
return rr.Bytes()
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
ID uint16
|
||||
Flags Flags
|
||||
QD []Question
|
||||
AN []ResourceRecord
|
||||
NS []ResourceRecord
|
||||
AR []ResourceRecord
|
||||
}
|
||||
|
||||
func (m Message) Encode() []byte {
|
||||
rr := bytes.NewBuffer([]byte{})
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID))
|
||||
rr.Write(m.Flags.Encode())
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD))))
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN))))
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS))))
|
||||
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR))))
|
||||
for _, q := range m.QD {
|
||||
rr.Write(q.EncodeQuestion())
|
||||
}
|
||||
for _, a := range m.AN {
|
||||
rr.Write(a.EncodeResource())
|
||||
}
|
||||
for _, ns := range m.NS {
|
||||
rr.Write(ns.EncodeResource())
|
||||
}
|
||||
for _, a := range m.AR {
|
||||
rr.Write(a.EncodeResource())
|
||||
}
|
||||
return rr.Bytes()
|
||||
}
|
225
dns-proxy/types_test.go
Normal file
225
dns-proxy/types_test.go
Normal file
@ -0,0 +1,225 @@
|
||||
package dnsProxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDNSResourceRecordHeaderEncode(t *testing.T) {
|
||||
recordHeader := ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
}
|
||||
recordHeaderEncoded := recordHeader.EncodeHeader()
|
||||
recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0}
|
||||
if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 {
|
||||
t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSNameString(t *testing.T) {
|
||||
dnsName := Name{Parts: []string{"example", "com"}}
|
||||
dnsNameString := dnsName.String()
|
||||
dnsNameStringGood := "example.com"
|
||||
if dnsNameString != dnsNameStringGood {
|
||||
t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSNameEncode(t *testing.T) {
|
||||
dnsName := Name{Parts: []string{"example", "com"}}
|
||||
dnsNameEncoded := dnsName.Encode()
|
||||
dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||
if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 {
|
||||
t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFlagsEncode(t *testing.T) {
|
||||
dnsFlags := Flags{
|
||||
QR: 0x1,
|
||||
Opcode: 0xF,
|
||||
AA: 0x0,
|
||||
TC: 0x0,
|
||||
RD: 0x1,
|
||||
RA: 0x1,
|
||||
Z1: 0x0,
|
||||
Z2: 0x0,
|
||||
Z3: 0x0,
|
||||
RCode: 0xF,
|
||||
}
|
||||
dnsFlagsEncoded := dnsFlags.Encode()
|
||||
dnsFlagsEncodedGood := []byte{0xf9, 0x8f}
|
||||
if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 {
|
||||
t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSQuestionEncode(t *testing.T) {
|
||||
dnsQuestion := Question{
|
||||
QName: Name{Parts: []string{"example", "com"}},
|
||||
QType: 0x001c,
|
||||
QClass: 0x0001,
|
||||
}
|
||||
dnsQuestionEncoded := dnsQuestion.EncodeQuestion()
|
||||
dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01}
|
||||
if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 {
|
||||
t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAddressEncode(t *testing.T) {
|
||||
dnsAddress := Address{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
Address: []byte{192, 168, 1, 1},
|
||||
}
|
||||
dnsAddressEncoded := dnsAddress.EncodeResource()
|
||||
dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1}
|
||||
if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 {
|
||||
t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSNameServerEncode(t *testing.T) {
|
||||
dnsNameServer := NameServer{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
NSDName: Name{Parts: []string{"example", "com"}},
|
||||
}
|
||||
dnsNameServerEncoded := dnsNameServer.EncodeResource()
|
||||
dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||
if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 {
|
||||
t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSCNameEncode(t *testing.T) {
|
||||
dnsCName := CName{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
CName: Name{Parts: []string{"example", "com"}},
|
||||
}
|
||||
dnsCNameEncoded := dnsCName.EncodeResource()
|
||||
dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||
if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 {
|
||||
t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSAuthorityEncode(t *testing.T) {
|
||||
dnsAuthority := Authority{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
MName: Name{Parts: []string{"example", "com"}},
|
||||
RName: Name{Parts: []string{"example", "com"}},
|
||||
Serial: 0x12345678,
|
||||
Refresh: 0x12345678,
|
||||
Retry: 0x12345678,
|
||||
Expire: 0x12345678,
|
||||
Minimum: 0x12345678,
|
||||
}
|
||||
dnsAuthorityEncoded := dnsAuthority.EncodeResource()
|
||||
dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}
|
||||
if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 {
|
||||
t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSUnknownEncode(t *testing.T) {
|
||||
dnsUnknown := Unknown{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
Data: []byte{0x01, 0x02, 0x03},
|
||||
}
|
||||
dnsUnknownEncoded := dnsUnknown.EncodeResource()
|
||||
dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
|
||||
if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 {
|
||||
t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSMessageEncode(t *testing.T) {
|
||||
dnsMessage := Message{
|
||||
ID: 0x00FF,
|
||||
Flags: Flags{
|
||||
QR: 0x1,
|
||||
Opcode: 0xF,
|
||||
AA: 0x0,
|
||||
TC: 0x0,
|
||||
RD: 0x1,
|
||||
RA: 0x1,
|
||||
Z1: 0x0,
|
||||
Z2: 0x0,
|
||||
Z3: 0x0,
|
||||
RCode: 0xF,
|
||||
},
|
||||
QD: []Question{
|
||||
{
|
||||
QName: Name{Parts: []string{"example", "com"}},
|
||||
QType: 0x001c,
|
||||
QClass: 0x0001,
|
||||
},
|
||||
},
|
||||
AN: []ResourceRecord{
|
||||
Unknown{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
Data: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
},
|
||||
NS: []ResourceRecord{
|
||||
Unknown{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
Data: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
},
|
||||
AR: []ResourceRecord{
|
||||
Unknown{
|
||||
ResourceRecordHeader: ResourceRecordHeader{
|
||||
Name: Name{Parts: []string{"example", "com"}},
|
||||
Type: 0xF0,
|
||||
Class: 0xF0,
|
||||
TTL: 0x77770FF0,
|
||||
},
|
||||
Data: []byte{0x01, 0x02, 0x03},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsMessageEncoded := dnsMessage.Encode()
|
||||
dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
|
||||
if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 {
|
||||
t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood)
|
||||
}
|
||||
}
|
189
main.go
189
main.go
@ -1,176 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
dnsProxy "kvas2-go/dns-proxy"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ListenPort = 7548
|
||||
ListenPort = uint16(7548)
|
||||
UsableDNSServerAddress = "127.0.0.1"
|
||||
UsableDNSServerPort = 53
|
||||
DNSMaxPackageSize = 4096
|
||||
UsableDNSServerPort = uint16(53)
|
||||
)
|
||||
|
||||
func parseName(response []byte, pos int) (string, int) {
|
||||
var nameParts []string
|
||||
var jumped bool
|
||||
var outPos int
|
||||
responseLen := len(response)
|
||||
|
||||
for {
|
||||
length := int(response[pos])
|
||||
pos++
|
||||
if length == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if length&0xC0 == 0xC0 {
|
||||
if !jumped {
|
||||
outPos = pos + 1
|
||||
}
|
||||
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
|
||||
jumped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if pos+length > responseLen {
|
||||
break
|
||||
}
|
||||
|
||||
nameParts = append(nameParts, string(response[pos:pos+length]))
|
||||
pos += length
|
||||
}
|
||||
|
||||
if !jumped {
|
||||
outPos = pos
|
||||
}
|
||||
return strings.Join(nameParts, "."), outPos
|
||||
}
|
||||
|
||||
func sendToUpstream(upstreamAddr string, request []byte) ([]byte, error) {
|
||||
conn, err := net.Dial("udp", upstreamAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial upstream DNS: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
_, err = conn.Write(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err)
|
||||
}
|
||||
|
||||
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to set timeout: %w", err)
|
||||
}
|
||||
|
||||
response := make([]byte, DNSMaxPackageSize)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err)
|
||||
}
|
||||
|
||||
return response[:n], nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
addr := fmt.Sprintf(":%d", ListenPort)
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to resolve address: %v", err)
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to listen on UDP: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
fmt.Printf("DNS server is running on %s...\n", addr)
|
||||
|
||||
for {
|
||||
buffer := make([]byte, DNSMaxPackageSize)
|
||||
n, clientAddr, err := conn.ReadFromUDP(buffer)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read from UDP: %v", err)
|
||||
continue
|
||||
proxy := dnsProxy.New("", ListenPort, UsableDNSServerAddress, UsableDNSServerPort)
|
||||
proxy.MsgHandler = func(msg *dnsProxy.Message) {
|
||||
for _, q := range msg.QD {
|
||||
fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String())
|
||||
}
|
||||
for _, a := range msg.AN {
|
||||
switch v := a.(type) {
|
||||
case dnsProxy.Address:
|
||||
fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
|
||||
case dnsProxy.CName:
|
||||
fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
|
||||
default:
|
||||
fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
|
||||
}
|
||||
}
|
||||
for _, a := range msg.NS {
|
||||
fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource())
|
||||
}
|
||||
for _, a := range msg.AR {
|
||||
fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource())
|
||||
}
|
||||
|
||||
go handleDNSRequest(conn, clientAddr, buffer[:n])
|
||||
}
|
||||
}
|
||||
|
||||
func process(response []byte) {
|
||||
responseLen := len(response)
|
||||
if responseLen <= 12 {
|
||||
return
|
||||
}
|
||||
|
||||
qCount := int(binary.LittleEndian.Uint16(response[5:7]))
|
||||
aCount := int(binary.LittleEndian.Uint16(response[7:9]))
|
||||
|
||||
pos := 12
|
||||
|
||||
for i := 0; i < qCount; i++ {
|
||||
var name string
|
||||
name, pos = parseName(response, pos)
|
||||
fmt.Printf("Requested name: %s\n", name)
|
||||
pos += 4
|
||||
}
|
||||
|
||||
for i := 0; i < aCount; i++ {
|
||||
name, newPos := parseName(response, pos)
|
||||
pos = newPos
|
||||
|
||||
if pos+10 > responseLen {
|
||||
break
|
||||
}
|
||||
|
||||
qtype := binary.BigEndian.Uint16(response[pos : pos+2])
|
||||
pos += 2
|
||||
|
||||
qclass := binary.BigEndian.Uint16(response[pos : pos+2])
|
||||
pos += 2
|
||||
|
||||
ttl := binary.BigEndian.Uint32(response[pos : pos+4])
|
||||
pos += 4
|
||||
|
||||
rdlength := binary.BigEndian.Uint16(response[pos : pos+2])
|
||||
pos += 2
|
||||
|
||||
if pos+int(rdlength) > responseLen {
|
||||
break
|
||||
}
|
||||
|
||||
if qtype == 1 && qclass == 1 && rdlength == 4 {
|
||||
ip := net.IPv4(response[pos], response[pos+1], response[pos+2], response[pos+3])
|
||||
fmt.Printf("Parsed A record: %s -> %s, TTL: %d\n", name, ip, ttl)
|
||||
}
|
||||
|
||||
pos += int(rdlength)
|
||||
}
|
||||
}
|
||||
|
||||
func handleDNSRequest(conn *net.UDPConn, clientAddr *net.UDPAddr, buffer []byte) {
|
||||
upstreamAddr := fmt.Sprintf("%s:%d", UsableDNSServerAddress, UsableDNSServerPort)
|
||||
|
||||
upstreamResponse, err := sendToUpstream(upstreamAddr, buffer)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get response from upstream DNS: %v", err)
|
||||
return
|
||||
}
|
||||
log.Printf("Response: %s", hex.EncodeToString(upstreamResponse))
|
||||
|
||||
process(upstreamResponse)
|
||||
|
||||
_, err = conn.WriteToUDP(upstreamResponse, clientAddr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to send DNS response: %v", err)
|
||||
}
|
||||
err := proxy.Listen()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer proxy.Close()
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user