refactor dns record parser

This commit is contained in:
Vladimir Avtsenov 2024-08-25 00:50:34 +03:00
parent 11ddf5aedb
commit 0f716a3c49

View File

@ -4,22 +4,23 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io"
) )
var ( var (
ErrInvalidDNSMessageHeader = errors.New("invalid DNS message header")
ErrInvalidDNSResourceRecordHeader = errors.New("invalid DNS resource record header")
ErrInvalidDNSResourceRecordData = errors.New("invalid DNS resource record data")
ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data") ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data")
) )
func parseName(response []byte, pos int) (Name, int) { func parseName(response []byte, pos int) (*Name, int, error) {
var nameParts []string var nameParts []string
var jumped bool var jumped bool
var outPos int var outPos int
responseLen := len(response) responseLen := len(response)
for { for {
if responseLen < pos+1 {
return nil, pos, io.EOF
}
length := int(response[pos]) length := int(response[pos])
pos++ pos++
if length == 0 { if length == 0 {
@ -27,6 +28,9 @@ func parseName(response []byte, pos int) (Name, int) {
} }
if length&0xC0 == 0xC0 { if length&0xC0 == 0xC0 {
if responseLen < pos+1 {
return nil, pos, io.EOF
}
if !jumped { if !jumped {
outPos = pos + 1 outPos = pos + 1
} }
@ -35,8 +39,8 @@ func parseName(response []byte, pos int) (Name, int) {
continue continue
} }
if pos+length > responseLen { if responseLen < pos+length {
break return nil, pos, io.EOF
} }
nameParts = append(nameParts, string(response[pos:pos+length])) nameParts = append(nameParts, string(response[pos:pos+length]))
@ -46,75 +50,82 @@ func parseName(response []byte, pos int) (Name, int) {
if !jumped { if !jumped {
outPos = pos outPos = pos
} }
return Name{Parts: nameParts}, outPos return &Name{Parts: nameParts}, outPos, nil
} }
func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) { func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) {
responseLen := len(response) responseLen := len(response)
var rhname Name rhName, pos, err := parseName(response, pos)
rhname, pos = parseName(response, pos) if err != nil {
return nil, pos, fmt.Errorf("error while parsing DNS name: %w", err)
}
if responseLen < pos+10 { if responseLen < pos+10 {
return nil, pos, ErrInvalidDNSResourceRecordHeader return nil, pos, io.EOF
} }
rh := ResourceRecordHeader{ rh := ResourceRecordHeader{
Name: rhname, Name: *rhName,
Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]), Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]), Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]), TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]),
} }
rdlength := int(binary.BigEndian.Uint16(response[pos+8 : pos+10])) rdLen := int(binary.BigEndian.Uint16(response[pos+8 : pos+10]))
pos += 10 pos += 10
if pos+rdlength > responseLen { if responseLen < pos+rdLen {
return nil, pos, ErrInvalidDNSResourceRecordData return nil, pos, io.EOF
} }
switch rh.Type { switch rh.Type {
case 1: case 1:
if rdlength == 4 { if rdLen != 4 {
return nil, pos, ErrInvalidDNSAddressResourceData
}
return Address{ return Address{
ResourceRecordHeader: rh, ResourceRecordHeader: rh,
Address: response[pos+0 : pos+4], Address: response[pos+0 : pos+4],
}, pos + 4, nil }, pos + 4, nil
} else {
return nil, pos, ErrInvalidDNSAddressResourceData
}
case 2: case 2:
var ns Name var ns *Name
ns, pos = parseName(response, pos) ns, pos, err = parseName(response, pos)
if err != nil {
return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err)
}
return NameServer{ return NameServer{
ResourceRecordHeader: rh, ResourceRecordHeader: rh,
NSDName: ns, NSDName: *ns,
}, pos, nil }, pos, nil
case 5: case 5:
var cname Name var cname *Name
cname, pos = parseName(response, pos) cname, pos, err = parseName(response, pos)
if err != nil {
return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err)
}
return CName{ return CName{
ResourceRecordHeader: rh, ResourceRecordHeader: rh,
CName: cname, CName: *cname,
}, pos, nil }, pos, nil
} }
return Unknown{ return Unknown{
ResourceRecordHeader: rh, ResourceRecordHeader: rh,
Data: response[pos+0 : pos+rdlength], Data: response[pos+0 : pos+rdLen],
}, pos + rdlength, nil }, pos + rdLen, nil
} }
func ParseResponse(response []byte) (*Message, error) { func ParseResponse(response []byte) (*Message, error) {
var err error var err error
msg := new(Message)
responseLen := len(response) responseLen := len(response)
if responseLen < 12 { if responseLen < 12 {
return nil, ErrInvalidDNSMessageHeader return msg, io.EOF
} }
msg := new(Message)
msg.ID = binary.BigEndian.Uint16(response[0:2]) msg.ID = binary.BigEndian.Uint16(response[0:2])
flagsRAW := binary.BigEndian.Uint16(response[2:4]) flagsRAW := binary.BigEndian.Uint16(response[2:4])
@ -140,10 +151,16 @@ func ParseResponse(response []byte) (*Message, error) {
msg.QD = make([]Question, qdCount) msg.QD = make([]Question, qdCount)
for i := 0; i < qdCount; i++ { for i := 0; i < qdCount; i++ {
var name Name var name *Name
name, pos = parseName(response, pos) name, pos, err = parseName(response, pos)
if err != nil {
return msg, fmt.Errorf("error while parsing DNS name: %w", err)
}
if responseLen < pos+4 {
return msg, io.EOF
}
msg.QD[i] = Question{ msg.QD[i] = Question{
QName: name, QName: *name,
QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]), QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]), QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
} }
@ -154,7 +171,7 @@ func ParseResponse(response []byte) (*Message, error) {
for i := 0; i < anCount; i++ { for i := 0; i < anCount; i++ {
msg.AN[i], pos, err = parseResourceRecord(response, pos) msg.AN[i], pos, err = parseResourceRecord(response, pos)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while parsing AN record: %w", err) return msg, fmt.Errorf("error while parsing AN record: %w", err)
} }
} }
@ -162,7 +179,7 @@ func ParseResponse(response []byte) (*Message, error) {
for i := 0; i < nsCount; i++ { for i := 0; i < nsCount; i++ {
msg.NS[i], pos, err = parseResourceRecord(response, pos) msg.NS[i], pos, err = parseResourceRecord(response, pos)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while parsing NS record: %w", err) return msg, fmt.Errorf("error while parsing NS record: %w", err)
} }
} }
@ -170,7 +187,7 @@ func ParseResponse(response []byte) (*Message, error) {
for i := 0; i < arCount; i++ { for i := 0; i < arCount; i++ {
msg.AR[i], pos, err = parseResourceRecord(response, pos) msg.AR[i], pos, err = parseResourceRecord(response, pos)
if err != nil { if err != nil {
return nil, fmt.Errorf("error while parsing AR record: %w", err) return msg, fmt.Errorf("error while parsing AR record: %w", err)
} }
} }