dns-proxy refactoring
This commit is contained in:
parent
7ea3ec1b70
commit
27e8086c50
120
dns-proxy/dns-proxy.go
Normal file
120
dns-proxy/dns-proxy.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DNSMaxUDPPackageSize = 4096
|
||||||
|
DNSMaxTCPPackageSize = 65536
|
||||||
|
)
|
||||||
|
|
||||||
|
type DNSProxy struct {
|
||||||
|
listenAddr string
|
||||||
|
upstreamAddr string
|
||||||
|
|
||||||
|
udpConn *net.UDPConn
|
||||||
|
|
||||||
|
MsgHandler func(*Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p DNSProxy) Close() error {
|
||||||
|
return p.udpConn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p DNSProxy) sendToUpstream(isTCP bool, request []byte) ([]byte, error) {
|
||||||
|
protocol := "udp"
|
||||||
|
if isTCP {
|
||||||
|
protocol = "tcp"
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial(protocol, p.upstreamAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to dial upstream DNS: %w", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err = conn.Write(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to set timeout: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var response []byte
|
||||||
|
if !isTCP {
|
||||||
|
response = make([]byte, DNSMaxUDPPackageSize)
|
||||||
|
} else {
|
||||||
|
response = make([]byte, DNSMaxTCPPackageSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := conn.Read(response)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return response[:n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) {
|
||||||
|
upstreamResponse, err := p.sendToUpstream(false, buffer)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get response from upstream DNS: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Response: %s", hex.EncodeToString(upstreamResponse))
|
||||||
|
|
||||||
|
msg, err := ParseResponse(upstreamResponse)
|
||||||
|
if err == nil {
|
||||||
|
if p.MsgHandler != nil {
|
||||||
|
p.MsgHandler(msg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("error while parsing response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = p.udpConn.WriteToUDP(upstreamResponse, clientAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to send DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p DNSProxy) Listen() error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", p.listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve UDP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.udpConn, err = net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to listen on UDP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
buffer := make([]byte, DNSMaxUDPPackageSize)
|
||||||
|
n, clientAddr, err := p.udpConn.ReadFromUDP(buffer)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read from UDP: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.handleDNSRequest(clientAddr, buffer[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(listenAddr string, listenPort uint16, upstreamAddr string, upstreamPort uint16) *DNSProxy {
|
||||||
|
return &DNSProxy{
|
||||||
|
listenAddr: fmt.Sprintf("%s:%d", listenAddr, listenPort),
|
||||||
|
upstreamAddr: fmt.Sprintf("%s:%d", upstreamAddr, upstreamPort),
|
||||||
|
}
|
||||||
|
}
|
178
dns-proxy/parser.go
Normal file
178
dns-proxy/parser.go
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidDNSMessageHeader = errors.New("invalid DNS message header")
|
||||||
|
ErrInvalidDNSResourceRecordHeader = errors.New("invalid DNS resource record header")
|
||||||
|
ErrInvalidDNSResourceRecordData = errors.New("invalid DNS resource record data")
|
||||||
|
ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseName(response []byte, pos int) (Name, int) {
|
||||||
|
var nameParts []string
|
||||||
|
var jumped bool
|
||||||
|
var outPos int
|
||||||
|
responseLen := len(response)
|
||||||
|
|
||||||
|
for {
|
||||||
|
length := int(response[pos])
|
||||||
|
pos++
|
||||||
|
if length == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if length&0xC0 == 0xC0 {
|
||||||
|
if !jumped {
|
||||||
|
outPos = pos + 1
|
||||||
|
}
|
||||||
|
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
|
||||||
|
jumped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if pos+length > responseLen {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
nameParts = append(nameParts, string(response[pos:pos+length]))
|
||||||
|
pos += length
|
||||||
|
}
|
||||||
|
|
||||||
|
if !jumped {
|
||||||
|
outPos = pos
|
||||||
|
}
|
||||||
|
return Name{Parts: nameParts}, outPos
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) {
|
||||||
|
responseLen := len(response)
|
||||||
|
|
||||||
|
var rhname Name
|
||||||
|
rhname, pos = parseName(response, pos)
|
||||||
|
|
||||||
|
if responseLen < pos+10 {
|
||||||
|
return nil, pos, ErrInvalidDNSResourceRecordHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
rh := ResourceRecordHeader{
|
||||||
|
Name: rhname,
|
||||||
|
Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
|
||||||
|
Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
|
||||||
|
TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]),
|
||||||
|
}
|
||||||
|
rdlength := int(binary.BigEndian.Uint16(response[pos+8 : pos+10]))
|
||||||
|
|
||||||
|
pos += 10
|
||||||
|
|
||||||
|
if pos+rdlength > responseLen {
|
||||||
|
return nil, pos, ErrInvalidDNSResourceRecordData
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rh.Type {
|
||||||
|
case 1:
|
||||||
|
if rdlength == 4 {
|
||||||
|
return Address{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
Address: response[pos+0 : pos+4],
|
||||||
|
}, pos + 4, nil
|
||||||
|
} else {
|
||||||
|
return nil, pos, ErrInvalidDNSAddressResourceData
|
||||||
|
}
|
||||||
|
case 2:
|
||||||
|
var ns Name
|
||||||
|
ns, pos = parseName(response, pos)
|
||||||
|
return NameServer{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
NSDName: ns,
|
||||||
|
}, pos, nil
|
||||||
|
case 5:
|
||||||
|
var cname Name
|
||||||
|
cname, pos = parseName(response, pos)
|
||||||
|
return CName{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
CName: cname,
|
||||||
|
}, pos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return Unknown{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
Data: response[pos+0 : pos+rdlength],
|
||||||
|
}, pos + rdlength, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseResponse(response []byte) (*Message, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
responseLen := len(response)
|
||||||
|
if responseLen < 12 {
|
||||||
|
return nil, ErrInvalidDNSMessageHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := new(Message)
|
||||||
|
|
||||||
|
msg.ID = binary.BigEndian.Uint16(response[0:2])
|
||||||
|
|
||||||
|
flagsRAW := binary.BigEndian.Uint16(response[2:4])
|
||||||
|
msg.Flags = Flags{
|
||||||
|
QR: uint8(flagsRAW >> 15 & 0x1),
|
||||||
|
Opcode: uint8(flagsRAW >> 11 & 0xF),
|
||||||
|
AA: uint8(flagsRAW >> 10 & 0x1),
|
||||||
|
TC: uint8(flagsRAW >> 9 & 0x1),
|
||||||
|
RD: uint8(flagsRAW >> 8 & 0x1),
|
||||||
|
RA: uint8(flagsRAW >> 7 & 0x1),
|
||||||
|
Z1: uint8(flagsRAW >> 6 & 0x1),
|
||||||
|
Z2: uint8(flagsRAW >> 5 & 0x1),
|
||||||
|
Z3: uint8(flagsRAW >> 4 & 0x1),
|
||||||
|
RCode: uint8(flagsRAW >> 0 & 0xF),
|
||||||
|
}
|
||||||
|
|
||||||
|
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
|
||||||
|
anCount := int(binary.BigEndian.Uint16(response[6:8]))
|
||||||
|
nsCount := int(binary.BigEndian.Uint16(response[8:10]))
|
||||||
|
arCount := int(binary.BigEndian.Uint16(response[10:12]))
|
||||||
|
|
||||||
|
pos := 12
|
||||||
|
|
||||||
|
msg.QD = make([]Question, qdCount)
|
||||||
|
for i := 0; i < qdCount; i++ {
|
||||||
|
var name Name
|
||||||
|
name, pos = parseName(response, pos)
|
||||||
|
msg.QD[i] = Question{
|
||||||
|
QName: name,
|
||||||
|
QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
|
||||||
|
QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
|
||||||
|
}
|
||||||
|
pos += 4
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.AN = make([]ResourceRecord, anCount)
|
||||||
|
for i := 0; i < anCount; i++ {
|
||||||
|
msg.AN[i], pos, err = parseResourceRecord(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error while parsing AN record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.NS = make([]ResourceRecord, nsCount)
|
||||||
|
for i := 0; i < nsCount; i++ {
|
||||||
|
msg.NS[i], pos, err = parseResourceRecord(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error while parsing NS record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.AR = make([]ResourceRecord, arCount)
|
||||||
|
for i := 0; i < arCount; i++ {
|
||||||
|
msg.AR[i], pos, err = parseResourceRecord(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error while parsing AR record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
196
dns-proxy/types.go
Normal file
196
dns-proxy/types.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResourceRecord interface {
|
||||||
|
EncodeResource() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResourceRecordHeader struct {
|
||||||
|
Name Name
|
||||||
|
Type uint16
|
||||||
|
Class uint16
|
||||||
|
TTL uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q ResourceRecordHeader) EncodeHeader() []byte {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
buf.Write(q.Name.Encode())
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type))
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class))
|
||||||
|
buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Name struct {
|
||||||
|
Parts []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Name) String() string {
|
||||||
|
return strings.Join(n.Parts, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Name) Encode() []byte {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
for _, part := range n.Parts {
|
||||||
|
partLen := byte(len(part)) & 0x3F
|
||||||
|
buf.WriteByte(partLen)
|
||||||
|
buf.Write([]byte(part)[0:partLen])
|
||||||
|
}
|
||||||
|
buf.WriteByte(0)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Flags struct {
|
||||||
|
QR uint8
|
||||||
|
Opcode uint8
|
||||||
|
AA uint8
|
||||||
|
TC uint8
|
||||||
|
RD uint8
|
||||||
|
RA uint8
|
||||||
|
Z1 uint8
|
||||||
|
Z2 uint8
|
||||||
|
Z3 uint8
|
||||||
|
RCode uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flags) Encode() []byte {
|
||||||
|
return []byte{
|
||||||
|
f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0,
|
||||||
|
f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Question struct {
|
||||||
|
QName Name
|
||||||
|
QType uint16
|
||||||
|
QClass uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q Question) EncodeQuestion() []byte {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
buf.Write(q.QName.Encode())
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType))
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Address struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
Address net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Address) EncodeResource() []byte {
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write([]byte{0x00, 0x04})
|
||||||
|
rr.Write(a.Address[:])
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type NameServer struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
NSDName Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a NameServer) EncodeResource() []byte {
|
||||||
|
rdataBytes := a.NSDName.Encode()
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||||
|
rr.Write(rdataBytes)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type CName struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
CName Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a CName) EncodeResource() []byte {
|
||||||
|
rdataBytes := a.CName.Encode()
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||||
|
rr.Write(rdataBytes)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Authority struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
MName Name
|
||||||
|
RName Name
|
||||||
|
Serial uint32
|
||||||
|
Refresh uint32
|
||||||
|
Retry uint32
|
||||||
|
Expire uint32
|
||||||
|
Minimum uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Authority) EncodeResource() []byte {
|
||||||
|
rdata := bytes.NewBuffer([]byte{})
|
||||||
|
rdata.Write(a.MName.Encode())
|
||||||
|
rdata.Write(a.RName.Encode())
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum))
|
||||||
|
rdataBytes := rdata.Bytes()
|
||||||
|
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||||
|
rr.Write(rdataBytes)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Unknown struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u Unknown) EncodeResource() []byte {
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(u.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data))))
|
||||||
|
rr.Write(u.Data)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID uint16
|
||||||
|
Flags Flags
|
||||||
|
QD []Question
|
||||||
|
AN []ResourceRecord
|
||||||
|
NS []ResourceRecord
|
||||||
|
AR []ResourceRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Message) Encode() []byte {
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID))
|
||||||
|
rr.Write(m.Flags.Encode())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD))))
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN))))
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS))))
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR))))
|
||||||
|
for _, q := range m.QD {
|
||||||
|
rr.Write(q.EncodeQuestion())
|
||||||
|
}
|
||||||
|
for _, a := range m.AN {
|
||||||
|
rr.Write(a.EncodeResource())
|
||||||
|
}
|
||||||
|
for _, ns := range m.NS {
|
||||||
|
rr.Write(ns.EncodeResource())
|
||||||
|
}
|
||||||
|
for _, a := range m.AR {
|
||||||
|
rr.Write(a.EncodeResource())
|
||||||
|
}
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
225
dns-proxy/types_test.go
Normal file
225
dns-proxy/types_test.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDNSResourceRecordHeaderEncode(t *testing.T) {
|
||||||
|
recordHeader := ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
}
|
||||||
|
recordHeaderEncoded := recordHeader.EncodeHeader()
|
||||||
|
recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0}
|
||||||
|
if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSNameString(t *testing.T) {
|
||||||
|
dnsName := Name{Parts: []string{"example", "com"}}
|
||||||
|
dnsNameString := dnsName.String()
|
||||||
|
dnsNameStringGood := "example.com"
|
||||||
|
if dnsNameString != dnsNameStringGood {
|
||||||
|
t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSNameEncode(t *testing.T) {
|
||||||
|
dnsName := Name{Parts: []string{"example", "com"}}
|
||||||
|
dnsNameEncoded := dnsName.Encode()
|
||||||
|
dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||||
|
if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSFlagsEncode(t *testing.T) {
|
||||||
|
dnsFlags := Flags{
|
||||||
|
QR: 0x1,
|
||||||
|
Opcode: 0xF,
|
||||||
|
AA: 0x0,
|
||||||
|
TC: 0x0,
|
||||||
|
RD: 0x1,
|
||||||
|
RA: 0x1,
|
||||||
|
Z1: 0x0,
|
||||||
|
Z2: 0x0,
|
||||||
|
Z3: 0x0,
|
||||||
|
RCode: 0xF,
|
||||||
|
}
|
||||||
|
dnsFlagsEncoded := dnsFlags.Encode()
|
||||||
|
dnsFlagsEncodedGood := []byte{0xf9, 0x8f}
|
||||||
|
if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSQuestionEncode(t *testing.T) {
|
||||||
|
dnsQuestion := Question{
|
||||||
|
QName: Name{Parts: []string{"example", "com"}},
|
||||||
|
QType: 0x001c,
|
||||||
|
QClass: 0x0001,
|
||||||
|
}
|
||||||
|
dnsQuestionEncoded := dnsQuestion.EncodeQuestion()
|
||||||
|
dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01}
|
||||||
|
if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSAddressEncode(t *testing.T) {
|
||||||
|
dnsAddress := Address{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Address: []byte{192, 168, 1, 1},
|
||||||
|
}
|
||||||
|
dnsAddressEncoded := dnsAddress.EncodeResource()
|
||||||
|
dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1}
|
||||||
|
if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSNameServerEncode(t *testing.T) {
|
||||||
|
dnsNameServer := NameServer{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
NSDName: Name{Parts: []string{"example", "com"}},
|
||||||
|
}
|
||||||
|
dnsNameServerEncoded := dnsNameServer.EncodeResource()
|
||||||
|
dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||||
|
if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSCNameEncode(t *testing.T) {
|
||||||
|
dnsCName := CName{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
CName: Name{Parts: []string{"example", "com"}},
|
||||||
|
}
|
||||||
|
dnsCNameEncoded := dnsCName.EncodeResource()
|
||||||
|
dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||||
|
if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSAuthorityEncode(t *testing.T) {
|
||||||
|
dnsAuthority := Authority{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
MName: Name{Parts: []string{"example", "com"}},
|
||||||
|
RName: Name{Parts: []string{"example", "com"}},
|
||||||
|
Serial: 0x12345678,
|
||||||
|
Refresh: 0x12345678,
|
||||||
|
Retry: 0x12345678,
|
||||||
|
Expire: 0x12345678,
|
||||||
|
Minimum: 0x12345678,
|
||||||
|
}
|
||||||
|
dnsAuthorityEncoded := dnsAuthority.EncodeResource()
|
||||||
|
dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}
|
||||||
|
if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSUnknownEncode(t *testing.T) {
|
||||||
|
dnsUnknown := Unknown{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
}
|
||||||
|
dnsUnknownEncoded := dnsUnknown.EncodeResource()
|
||||||
|
dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
|
||||||
|
if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSMessageEncode(t *testing.T) {
|
||||||
|
dnsMessage := Message{
|
||||||
|
ID: 0x00FF,
|
||||||
|
Flags: Flags{
|
||||||
|
QR: 0x1,
|
||||||
|
Opcode: 0xF,
|
||||||
|
AA: 0x0,
|
||||||
|
TC: 0x0,
|
||||||
|
RD: 0x1,
|
||||||
|
RA: 0x1,
|
||||||
|
Z1: 0x0,
|
||||||
|
Z2: 0x0,
|
||||||
|
Z3: 0x0,
|
||||||
|
RCode: 0xF,
|
||||||
|
},
|
||||||
|
QD: []Question{
|
||||||
|
{
|
||||||
|
QName: Name{Parts: []string{"example", "com"}},
|
||||||
|
QType: 0x001c,
|
||||||
|
QClass: 0x0001,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AN: []ResourceRecord{
|
||||||
|
Unknown{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NS: []ResourceRecord{
|
||||||
|
Unknown{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
AR: []ResourceRecord{
|
||||||
|
Unknown{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsMessageEncoded := dnsMessage.Encode()
|
||||||
|
dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
|
||||||
|
if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
189
main.go
189
main.go
@ -1,176 +1,43 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
dnsProxy "kvas2-go/dns-proxy"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ListenPort = 7548
|
ListenPort = uint16(7548)
|
||||||
UsableDNSServerAddress = "127.0.0.1"
|
UsableDNSServerAddress = "127.0.0.1"
|
||||||
UsableDNSServerPort = 53
|
UsableDNSServerPort = uint16(53)
|
||||||
DNSMaxPackageSize = 4096
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func parseName(response []byte, pos int) (string, int) {
|
|
||||||
var nameParts []string
|
|
||||||
var jumped bool
|
|
||||||
var outPos int
|
|
||||||
responseLen := len(response)
|
|
||||||
|
|
||||||
for {
|
|
||||||
length := int(response[pos])
|
|
||||||
pos++
|
|
||||||
if length == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if length&0xC0 == 0xC0 {
|
|
||||||
if !jumped {
|
|
||||||
outPos = pos + 1
|
|
||||||
}
|
|
||||||
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
|
|
||||||
jumped = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if pos+length > responseLen {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
nameParts = append(nameParts, string(response[pos:pos+length]))
|
|
||||||
pos += length
|
|
||||||
}
|
|
||||||
|
|
||||||
if !jumped {
|
|
||||||
outPos = pos
|
|
||||||
}
|
|
||||||
return strings.Join(nameParts, "."), outPos
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendToUpstream(upstreamAddr string, request []byte) ([]byte, error) {
|
|
||||||
conn, err := net.Dial("udp", upstreamAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to dial upstream DNS: %w", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
_, err = conn.Write(request)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set timeout: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
response := make([]byte, DNSMaxPackageSize)
|
|
||||||
n, err := conn.Read(response)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return response[:n], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
addr := fmt.Sprintf(":%d", ListenPort)
|
proxy := dnsProxy.New("", ListenPort, UsableDNSServerAddress, UsableDNSServerPort)
|
||||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
proxy.MsgHandler = func(msg *dnsProxy.Message) {
|
||||||
|
for _, q := range msg.QD {
|
||||||
|
fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String())
|
||||||
|
}
|
||||||
|
for _, a := range msg.AN {
|
||||||
|
switch v := a.(type) {
|
||||||
|
case dnsProxy.Address:
|
||||||
|
fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL)
|
||||||
|
case dnsProxy.CName:
|
||||||
|
fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName)
|
||||||
|
default:
|
||||||
|
fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, a := range msg.NS {
|
||||||
|
fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource())
|
||||||
|
}
|
||||||
|
for _, a := range msg.AR {
|
||||||
|
fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := proxy.Listen()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to resolve address: %v", err)
|
log.Fatal(err)
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", udpAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Failed to listen on UDP: %v", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
fmt.Printf("DNS server is running on %s...\n", addr)
|
|
||||||
|
|
||||||
for {
|
|
||||||
buffer := make([]byte, DNSMaxPackageSize)
|
|
||||||
n, clientAddr, err := conn.ReadFromUDP(buffer)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to read from UDP: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go handleDNSRequest(conn, clientAddr, buffer[:n])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func process(response []byte) {
|
|
||||||
responseLen := len(response)
|
|
||||||
if responseLen <= 12 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
qCount := int(binary.LittleEndian.Uint16(response[5:7]))
|
|
||||||
aCount := int(binary.LittleEndian.Uint16(response[7:9]))
|
|
||||||
|
|
||||||
pos := 12
|
|
||||||
|
|
||||||
for i := 0; i < qCount; i++ {
|
|
||||||
var name string
|
|
||||||
name, pos = parseName(response, pos)
|
|
||||||
fmt.Printf("Requested name: %s\n", name)
|
|
||||||
pos += 4
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < aCount; i++ {
|
|
||||||
name, newPos := parseName(response, pos)
|
|
||||||
pos = newPos
|
|
||||||
|
|
||||||
if pos+10 > responseLen {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
qtype := binary.BigEndian.Uint16(response[pos : pos+2])
|
|
||||||
pos += 2
|
|
||||||
|
|
||||||
qclass := binary.BigEndian.Uint16(response[pos : pos+2])
|
|
||||||
pos += 2
|
|
||||||
|
|
||||||
ttl := binary.BigEndian.Uint32(response[pos : pos+4])
|
|
||||||
pos += 4
|
|
||||||
|
|
||||||
rdlength := binary.BigEndian.Uint16(response[pos : pos+2])
|
|
||||||
pos += 2
|
|
||||||
|
|
||||||
if pos+int(rdlength) > responseLen {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if qtype == 1 && qclass == 1 && rdlength == 4 {
|
|
||||||
ip := net.IPv4(response[pos], response[pos+1], response[pos+2], response[pos+3])
|
|
||||||
fmt.Printf("Parsed A record: %s -> %s, TTL: %d\n", name, ip, ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
pos += int(rdlength)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleDNSRequest(conn *net.UDPConn, clientAddr *net.UDPAddr, buffer []byte) {
|
|
||||||
upstreamAddr := fmt.Sprintf("%s:%d", UsableDNSServerAddress, UsableDNSServerPort)
|
|
||||||
|
|
||||||
upstreamResponse, err := sendToUpstream(upstreamAddr, buffer)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get response from upstream DNS: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("Response: %s", hex.EncodeToString(upstreamResponse))
|
|
||||||
|
|
||||||
process(upstreamResponse)
|
|
||||||
|
|
||||||
_, err = conn.WriteToUDP(upstreamResponse, clientAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to send DNS response: %v", err)
|
|
||||||
}
|
}
|
||||||
|
defer proxy.Close()
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user