diff --git a/dns-proxy/dns-proxy.go b/dns-proxy/dns-proxy.go new file mode 100644 index 0000000..3bc109e --- /dev/null +++ b/dns-proxy/dns-proxy.go @@ -0,0 +1,120 @@ +package dnsProxy + +import ( + "encoding/hex" + "fmt" + "log" + "net" + "time" +) + +const ( + DNSMaxUDPPackageSize = 4096 + DNSMaxTCPPackageSize = 65536 +) + +type DNSProxy struct { + listenAddr string + upstreamAddr string + + udpConn *net.UDPConn + + MsgHandler func(*Message) +} + +func (p DNSProxy) Close() error { + return p.udpConn.Close() +} + +func (p DNSProxy) sendToUpstream(isTCP bool, request []byte) ([]byte, error) { + protocol := "udp" + if isTCP { + protocol = "tcp" + } + + conn, err := net.Dial(protocol, p.upstreamAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial upstream DNS: %w", err) + } + defer conn.Close() + + _, err = conn.Write(request) + if err != nil { + return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err) + } + + err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to set timeout: %w", err) + } + + var response []byte + if !isTCP { + response = make([]byte, DNSMaxUDPPackageSize) + } else { + response = make([]byte, DNSMaxTCPPackageSize) + } + + n, err := conn.Read(response) + if err != nil { + return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err) + } + + return response[:n], nil +} + +func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) { + upstreamResponse, err := p.sendToUpstream(false, buffer) + if err != nil { + log.Printf("Failed to get response from upstream DNS: %v", err) + return + } + + log.Printf("Response: %s", hex.EncodeToString(upstreamResponse)) + + msg, err := ParseResponse(upstreamResponse) + if err == nil { + if p.MsgHandler != nil { + p.MsgHandler(msg) + } + } else { + log.Printf("error while parsing response: %v", err) + } + + _, err = p.udpConn.WriteToUDP(upstreamResponse, clientAddr) + if err != nil { + log.Printf("Failed to send DNS response: %v", err) + } +} + +func (p DNSProxy) Listen() error { + var err error + + udpAddr, err := net.ResolveUDPAddr("udp", p.listenAddr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %v", err) + } + + p.udpConn, err = net.ListenUDP("udp", udpAddr) + if err != nil { + return fmt.Errorf("failed to listen on UDP: %v", err) + } + + for { + buffer := make([]byte, DNSMaxUDPPackageSize) + n, clientAddr, err := p.udpConn.ReadFromUDP(buffer) + if err != nil { + log.Printf("Failed to read from UDP: %v", err) + continue + } + + go p.handleDNSRequest(clientAddr, buffer[:n]) + } +} + +func New(listenAddr string, listenPort uint16, upstreamAddr string, upstreamPort uint16) *DNSProxy { + return &DNSProxy{ + listenAddr: fmt.Sprintf("%s:%d", listenAddr, listenPort), + upstreamAddr: fmt.Sprintf("%s:%d", upstreamAddr, upstreamPort), + } +} diff --git a/dns-proxy/parser.go b/dns-proxy/parser.go new file mode 100644 index 0000000..5d31c8d --- /dev/null +++ b/dns-proxy/parser.go @@ -0,0 +1,178 @@ +package dnsProxy + +import ( + "encoding/binary" + "errors" + "fmt" +) + +var ( + ErrInvalidDNSMessageHeader = errors.New("invalid DNS message header") + ErrInvalidDNSResourceRecordHeader = errors.New("invalid DNS resource record header") + ErrInvalidDNSResourceRecordData = errors.New("invalid DNS resource record data") + ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data") +) + +func parseName(response []byte, pos int) (Name, int) { + var nameParts []string + var jumped bool + var outPos int + responseLen := len(response) + + for { + length := int(response[pos]) + pos++ + if length == 0 { + break + } + + if length&0xC0 == 0xC0 { + if !jumped { + outPos = pos + 1 + } + pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF) + jumped = true + continue + } + + if pos+length > responseLen { + break + } + + nameParts = append(nameParts, string(response[pos:pos+length])) + pos += length + } + + if !jumped { + outPos = pos + } + return Name{Parts: nameParts}, outPos +} + +func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) { + responseLen := len(response) + + var rhname Name + rhname, pos = parseName(response, pos) + + if responseLen < pos+10 { + return nil, pos, ErrInvalidDNSResourceRecordHeader + } + + rh := ResourceRecordHeader{ + Name: rhname, + Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]), + Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]), + TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]), + } + rdlength := int(binary.BigEndian.Uint16(response[pos+8 : pos+10])) + + pos += 10 + + if pos+rdlength > responseLen { + return nil, pos, ErrInvalidDNSResourceRecordData + } + + switch rh.Type { + case 1: + if rdlength == 4 { + return Address{ + ResourceRecordHeader: rh, + Address: response[pos+0 : pos+4], + }, pos + 4, nil + } else { + return nil, pos, ErrInvalidDNSAddressResourceData + } + case 2: + var ns Name + ns, pos = parseName(response, pos) + return NameServer{ + ResourceRecordHeader: rh, + NSDName: ns, + }, pos, nil + case 5: + var cname Name + cname, pos = parseName(response, pos) + return CName{ + ResourceRecordHeader: rh, + CName: cname, + }, pos, nil + } + + return Unknown{ + ResourceRecordHeader: rh, + Data: response[pos+0 : pos+rdlength], + }, pos + rdlength, nil +} + +func ParseResponse(response []byte) (*Message, error) { + var err error + + responseLen := len(response) + if responseLen < 12 { + return nil, ErrInvalidDNSMessageHeader + } + + msg := new(Message) + + msg.ID = binary.BigEndian.Uint16(response[0:2]) + + flagsRAW := binary.BigEndian.Uint16(response[2:4]) + msg.Flags = Flags{ + QR: uint8(flagsRAW >> 15 & 0x1), + Opcode: uint8(flagsRAW >> 11 & 0xF), + AA: uint8(flagsRAW >> 10 & 0x1), + TC: uint8(flagsRAW >> 9 & 0x1), + RD: uint8(flagsRAW >> 8 & 0x1), + RA: uint8(flagsRAW >> 7 & 0x1), + Z1: uint8(flagsRAW >> 6 & 0x1), + Z2: uint8(flagsRAW >> 5 & 0x1), + Z3: uint8(flagsRAW >> 4 & 0x1), + RCode: uint8(flagsRAW >> 0 & 0xF), + } + + qdCount := int(binary.BigEndian.Uint16(response[4:6])) + anCount := int(binary.BigEndian.Uint16(response[6:8])) + nsCount := int(binary.BigEndian.Uint16(response[8:10])) + arCount := int(binary.BigEndian.Uint16(response[10:12])) + + pos := 12 + + msg.QD = make([]Question, qdCount) + for i := 0; i < qdCount; i++ { + var name Name + name, pos = parseName(response, pos) + msg.QD[i] = Question{ + QName: name, + QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]), + QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]), + } + pos += 4 + } + + msg.AN = make([]ResourceRecord, anCount) + for i := 0; i < anCount; i++ { + msg.AN[i], pos, err = parseResourceRecord(response, pos) + if err != nil { + return nil, fmt.Errorf("error while parsing AN record: %w", err) + } + } + + msg.NS = make([]ResourceRecord, nsCount) + for i := 0; i < nsCount; i++ { + msg.NS[i], pos, err = parseResourceRecord(response, pos) + if err != nil { + return nil, fmt.Errorf("error while parsing NS record: %w", err) + } + } + + msg.AR = make([]ResourceRecord, arCount) + for i := 0; i < arCount; i++ { + msg.AR[i], pos, err = parseResourceRecord(response, pos) + if err != nil { + return nil, fmt.Errorf("error while parsing AR record: %w", err) + } + } + + return msg, nil +} diff --git a/dns-proxy/types.go b/dns-proxy/types.go new file mode 100644 index 0000000..294b5f4 --- /dev/null +++ b/dns-proxy/types.go @@ -0,0 +1,196 @@ +package dnsProxy + +import ( + "bytes" + "encoding/binary" + "net" + "strings" +) + +type ResourceRecord interface { + EncodeResource() []byte +} + +type ResourceRecordHeader struct { + Name Name + Type uint16 + Class uint16 + TTL uint32 +} + +func (q ResourceRecordHeader) EncodeHeader() []byte { + buf := bytes.NewBuffer([]byte{}) + buf.Write(q.Name.Encode()) + buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type)) + buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class)) + buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL)) + return buf.Bytes() +} + +type Name struct { + Parts []string +} + +func (n Name) String() string { + return strings.Join(n.Parts, ".") +} + +func (n Name) Encode() []byte { + buf := bytes.NewBuffer([]byte{}) + for _, part := range n.Parts { + partLen := byte(len(part)) & 0x3F + buf.WriteByte(partLen) + buf.Write([]byte(part)[0:partLen]) + } + buf.WriteByte(0) + return buf.Bytes() +} + +type Flags struct { + QR uint8 + Opcode uint8 + AA uint8 + TC uint8 + RD uint8 + RA uint8 + Z1 uint8 + Z2 uint8 + Z3 uint8 + RCode uint8 +} + +func (f Flags) Encode() []byte { + return []byte{ + f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0, + f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0, + } +} + +type Question struct { + QName Name + QType uint16 + QClass uint16 +} + +func (q Question) EncodeQuestion() []byte { + buf := bytes.NewBuffer([]byte{}) + buf.Write(q.QName.Encode()) + buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType)) + buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass)) + return buf.Bytes() +} + +type Address struct { + ResourceRecordHeader + Address net.IP +} + +func (a Address) EncodeResource() []byte { + rr := bytes.NewBuffer([]byte{}) + rr.Write(a.ResourceRecordHeader.EncodeHeader()) + rr.Write([]byte{0x00, 0x04}) + rr.Write(a.Address[:]) + return rr.Bytes() +} + +type NameServer struct { + ResourceRecordHeader + NSDName Name +} + +func (a NameServer) EncodeResource() []byte { + rdataBytes := a.NSDName.Encode() + rr := bytes.NewBuffer([]byte{}) + rr.Write(a.ResourceRecordHeader.EncodeHeader()) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes)))) + rr.Write(rdataBytes) + return rr.Bytes() +} + +type CName struct { + ResourceRecordHeader + CName Name +} + +func (a CName) EncodeResource() []byte { + rdataBytes := a.CName.Encode() + rr := bytes.NewBuffer([]byte{}) + rr.Write(a.ResourceRecordHeader.EncodeHeader()) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes)))) + rr.Write(rdataBytes) + return rr.Bytes() +} + +type Authority struct { + ResourceRecordHeader + MName Name + RName Name + Serial uint32 + Refresh uint32 + Retry uint32 + Expire uint32 + Minimum uint32 +} + +func (a Authority) EncodeResource() []byte { + rdata := bytes.NewBuffer([]byte{}) + rdata.Write(a.MName.Encode()) + rdata.Write(a.RName.Encode()) + rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial)) + rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh)) + rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry)) + rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire)) + rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum)) + rdataBytes := rdata.Bytes() + + rr := bytes.NewBuffer([]byte{}) + rr.Write(a.ResourceRecordHeader.EncodeHeader()) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes)))) + rr.Write(rdataBytes) + return rr.Bytes() +} + +type Unknown struct { + ResourceRecordHeader + Data []byte +} + +func (u Unknown) EncodeResource() []byte { + rr := bytes.NewBuffer([]byte{}) + rr.Write(u.ResourceRecordHeader.EncodeHeader()) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data)))) + rr.Write(u.Data) + return rr.Bytes() +} + +type Message struct { + ID uint16 + Flags Flags + QD []Question + AN []ResourceRecord + NS []ResourceRecord + AR []ResourceRecord +} + +func (m Message) Encode() []byte { + rr := bytes.NewBuffer([]byte{}) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID)) + rr.Write(m.Flags.Encode()) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD)))) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN)))) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS)))) + rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR)))) + for _, q := range m.QD { + rr.Write(q.EncodeQuestion()) + } + for _, a := range m.AN { + rr.Write(a.EncodeResource()) + } + for _, ns := range m.NS { + rr.Write(ns.EncodeResource()) + } + for _, a := range m.AR { + rr.Write(a.EncodeResource()) + } + return rr.Bytes() +} diff --git a/dns-proxy/types_test.go b/dns-proxy/types_test.go new file mode 100644 index 0000000..9b4d41c --- /dev/null +++ b/dns-proxy/types_test.go @@ -0,0 +1,225 @@ +package dnsProxy + +import ( + "bytes" + "testing" +) + +func TestDNSResourceRecordHeaderEncode(t *testing.T) { + recordHeader := ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + } + recordHeaderEncoded := recordHeader.EncodeHeader() + recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0} + if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 { + t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood) + } +} + +func TestDNSNameString(t *testing.T) { + dnsName := Name{Parts: []string{"example", "com"}} + dnsNameString := dnsName.String() + dnsNameStringGood := "example.com" + if dnsNameString != dnsNameStringGood { + t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood) + } +} + +func TestDNSNameEncode(t *testing.T) { + dnsName := Name{Parts: []string{"example", "com"}} + dnsNameEncoded := dnsName.Encode() + dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00} + if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 { + t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood) + } +} + +func TestDNSFlagsEncode(t *testing.T) { + dnsFlags := Flags{ + QR: 0x1, + Opcode: 0xF, + AA: 0x0, + TC: 0x0, + RD: 0x1, + RA: 0x1, + Z1: 0x0, + Z2: 0x0, + Z3: 0x0, + RCode: 0xF, + } + dnsFlagsEncoded := dnsFlags.Encode() + dnsFlagsEncodedGood := []byte{0xf9, 0x8f} + if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 { + t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood) + } +} + +func TestDNSQuestionEncode(t *testing.T) { + dnsQuestion := Question{ + QName: Name{Parts: []string{"example", "com"}}, + QType: 0x001c, + QClass: 0x0001, + } + dnsQuestionEncoded := dnsQuestion.EncodeQuestion() + dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01} + if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 { + t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood) + } +} + +func TestDNSAddressEncode(t *testing.T) { + dnsAddress := Address{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + Address: []byte{192, 168, 1, 1}, + } + dnsAddressEncoded := dnsAddress.EncodeResource() + dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1} + if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 { + t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood) + } +} + +func TestDNSNameServerEncode(t *testing.T) { + dnsNameServer := NameServer{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + NSDName: Name{Parts: []string{"example", "com"}}, + } + dnsNameServerEncoded := dnsNameServer.EncodeResource() + dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00} + if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 { + t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood) + } +} + +func TestDNSCNameEncode(t *testing.T) { + dnsCName := CName{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + CName: Name{Parts: []string{"example", "com"}}, + } + dnsCNameEncoded := dnsCName.EncodeResource() + dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00} + if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 { + t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood) + } +} + +func TestDNSAuthorityEncode(t *testing.T) { + dnsAuthority := Authority{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + MName: Name{Parts: []string{"example", "com"}}, + RName: Name{Parts: []string{"example", "com"}}, + Serial: 0x12345678, + Refresh: 0x12345678, + Retry: 0x12345678, + Expire: 0x12345678, + Minimum: 0x12345678, + } + dnsAuthorityEncoded := dnsAuthority.EncodeResource() + dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78} + if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 { + t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood) + } +} + +func TestDNSUnknownEncode(t *testing.T) { + dnsUnknown := Unknown{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + Data: []byte{0x01, 0x02, 0x03}, + } + dnsUnknownEncoded := dnsUnknown.EncodeResource() + dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03} + if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 { + t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood) + } +} + +func TestDNSMessageEncode(t *testing.T) { + dnsMessage := Message{ + ID: 0x00FF, + Flags: Flags{ + QR: 0x1, + Opcode: 0xF, + AA: 0x0, + TC: 0x0, + RD: 0x1, + RA: 0x1, + Z1: 0x0, + Z2: 0x0, + Z3: 0x0, + RCode: 0xF, + }, + QD: []Question{ + { + QName: Name{Parts: []string{"example", "com"}}, + QType: 0x001c, + QClass: 0x0001, + }, + }, + AN: []ResourceRecord{ + Unknown{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + Data: []byte{0x01, 0x02, 0x03}, + }, + }, + NS: []ResourceRecord{ + Unknown{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + Data: []byte{0x01, 0x02, 0x03}, + }, + }, + AR: []ResourceRecord{ + Unknown{ + ResourceRecordHeader: ResourceRecordHeader{ + Name: Name{Parts: []string{"example", "com"}}, + Type: 0xF0, + Class: 0xF0, + TTL: 0x77770FF0, + }, + Data: []byte{0x01, 0x02, 0x03}, + }, + }, + } + dnsMessageEncoded := dnsMessage.Encode() + dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03} + if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 { + t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood) + } +} diff --git a/main.go b/main.go index 5381213..ee6d08f 100644 --- a/main.go +++ b/main.go @@ -1,176 +1,43 @@ package main import ( - "encoding/binary" - "encoding/hex" "fmt" + dnsProxy "kvas2-go/dns-proxy" "log" - "net" - "strings" - "time" ) var ( - ListenPort = 7548 + ListenPort = uint16(7548) UsableDNSServerAddress = "127.0.0.1" - UsableDNSServerPort = 53 - DNSMaxPackageSize = 4096 + UsableDNSServerPort = uint16(53) ) -func parseName(response []byte, pos int) (string, int) { - var nameParts []string - var jumped bool - var outPos int - responseLen := len(response) - - for { - length := int(response[pos]) - pos++ - if length == 0 { - break - } - - if length&0xC0 == 0xC0 { - if !jumped { - outPos = pos + 1 - } - pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF) - jumped = true - continue - } - - if pos+length > responseLen { - break - } - - nameParts = append(nameParts, string(response[pos:pos+length])) - pos += length - } - - if !jumped { - outPos = pos - } - return strings.Join(nameParts, "."), outPos -} - -func sendToUpstream(upstreamAddr string, request []byte) ([]byte, error) { - conn, err := net.Dial("udp", upstreamAddr) - if err != nil { - return nil, fmt.Errorf("failed to dial upstream DNS: %w", err) - } - defer conn.Close() - - _, err = conn.Write(request) - if err != nil { - return nil, fmt.Errorf("failed to send request to upstream DNS: %w", err) - } - - err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - if err != nil { - return nil, fmt.Errorf("failed to set timeout: %w", err) - } - - response := make([]byte, DNSMaxPackageSize) - n, err := conn.Read(response) - if err != nil { - return nil, fmt.Errorf("failed to read response from upstream DNS: %w", err) - } - - return response[:n], nil -} - func main() { - addr := fmt.Sprintf(":%d", ListenPort) - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - log.Fatalf("Failed to resolve address: %v", err) - } - - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - log.Fatalf("Failed to listen on UDP: %v", err) - } - defer conn.Close() - - fmt.Printf("DNS server is running on %s...\n", addr) - - for { - buffer := make([]byte, DNSMaxPackageSize) - n, clientAddr, err := conn.ReadFromUDP(buffer) - if err != nil { - log.Printf("Failed to read from UDP: %v", err) - continue + proxy := dnsProxy.New("", ListenPort, UsableDNSServerAddress, UsableDNSServerPort) + proxy.MsgHandler = func(msg *dnsProxy.Message) { + for _, q := range msg.QD { + fmt.Printf("%x: <- Request name: %s\n", msg.ID, q.QName.String()) + } + for _, a := range msg.AN { + switch v := a.(type) { + case dnsProxy.Address: + fmt.Printf("%x: -> A: Name: %s; Address: %s; TTL: %d\n", msg.ID, v.Name, v.Address.String(), v.TTL) + case dnsProxy.CName: + fmt.Printf("%x: -> CNAME: Name: %s; CName: %s\n", msg.ID, v.Name, v.CName) + default: + fmt.Printf("%x: -> Unknown: %x\n", msg.ID, v.EncodeResource()) + } + } + for _, a := range msg.NS { + fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource()) + } + for _, a := range msg.AR { + fmt.Printf("%x: -> NS: %x\n", msg.ID, a.EncodeResource()) } - - go handleDNSRequest(conn, clientAddr, buffer[:n]) - } -} - -func process(response []byte) { - responseLen := len(response) - if responseLen <= 12 { - return - } - - qCount := int(binary.LittleEndian.Uint16(response[5:7])) - aCount := int(binary.LittleEndian.Uint16(response[7:9])) - - pos := 12 - - for i := 0; i < qCount; i++ { - var name string - name, pos = parseName(response, pos) - fmt.Printf("Requested name: %s\n", name) - pos += 4 - } - - for i := 0; i < aCount; i++ { - name, newPos := parseName(response, pos) - pos = newPos - - if pos+10 > responseLen { - break - } - - qtype := binary.BigEndian.Uint16(response[pos : pos+2]) - pos += 2 - - qclass := binary.BigEndian.Uint16(response[pos : pos+2]) - pos += 2 - - ttl := binary.BigEndian.Uint32(response[pos : pos+4]) - pos += 4 - - rdlength := binary.BigEndian.Uint16(response[pos : pos+2]) - pos += 2 - - if pos+int(rdlength) > responseLen { - break - } - - if qtype == 1 && qclass == 1 && rdlength == 4 { - ip := net.IPv4(response[pos], response[pos+1], response[pos+2], response[pos+3]) - fmt.Printf("Parsed A record: %s -> %s, TTL: %d\n", name, ip, ttl) - } - - pos += int(rdlength) - } -} - -func handleDNSRequest(conn *net.UDPConn, clientAddr *net.UDPAddr, buffer []byte) { - upstreamAddr := fmt.Sprintf("%s:%d", UsableDNSServerAddress, UsableDNSServerPort) - - upstreamResponse, err := sendToUpstream(upstreamAddr, buffer) - if err != nil { - log.Printf("Failed to get response from upstream DNS: %v", err) - return - } - log.Printf("Response: %s", hex.EncodeToString(upstreamResponse)) - - process(upstreamResponse) - - _, err = conn.WriteToUDP(upstreamResponse, clientAddr) - if err != nil { - log.Printf("Failed to send DNS response: %v", err) } + err := proxy.Listen() + if err != nil { + log.Fatal(err) + } + defer proxy.Close() }