diff --git a/dns-proxy/parser.go b/dns-proxy/parser.go index 5d31c8d..dfd961a 100644 --- a/dns-proxy/parser.go +++ b/dns-proxy/parser.go @@ -4,22 +4,23 @@ import ( "encoding/binary" "errors" "fmt" + "io" ) var ( - ErrInvalidDNSMessageHeader = errors.New("invalid DNS message header") - ErrInvalidDNSResourceRecordHeader = errors.New("invalid DNS resource record header") - ErrInvalidDNSResourceRecordData = errors.New("invalid DNS resource record data") - ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data") + ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data") ) -func parseName(response []byte, pos int) (Name, int) { +func parseName(response []byte, pos int) (*Name, int, error) { var nameParts []string var jumped bool var outPos int responseLen := len(response) for { + if responseLen < pos+1 { + return nil, pos, io.EOF + } length := int(response[pos]) pos++ if length == 0 { @@ -27,6 +28,9 @@ func parseName(response []byte, pos int) (Name, int) { } if length&0xC0 == 0xC0 { + if responseLen < pos+1 { + return nil, pos, io.EOF + } if !jumped { outPos = pos + 1 } @@ -35,8 +39,8 @@ func parseName(response []byte, pos int) (Name, int) { continue } - if pos+length > responseLen { - break + if responseLen < pos+length { + return nil, pos, io.EOF } nameParts = append(nameParts, string(response[pos:pos+length])) @@ -46,75 +50,82 @@ func parseName(response []byte, pos int) (Name, int) { if !jumped { outPos = pos } - return Name{Parts: nameParts}, outPos + return &Name{Parts: nameParts}, outPos, nil } func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) { responseLen := len(response) - var rhname Name - rhname, pos = parseName(response, pos) + rhName, pos, err := parseName(response, pos) + if err != nil { + return nil, pos, fmt.Errorf("error while parsing DNS name: %w", err) + } if responseLen < pos+10 { - return nil, pos, ErrInvalidDNSResourceRecordHeader + return nil, pos, io.EOF } rh := ResourceRecordHeader{ - Name: rhname, + Name: *rhName, Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]), Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]), TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]), } - rdlength := int(binary.BigEndian.Uint16(response[pos+8 : pos+10])) + rdLen := int(binary.BigEndian.Uint16(response[pos+8 : pos+10])) pos += 10 - if pos+rdlength > responseLen { - return nil, pos, ErrInvalidDNSResourceRecordData + if responseLen < pos+rdLen { + return nil, pos, io.EOF } switch rh.Type { case 1: - if rdlength == 4 { - return Address{ - ResourceRecordHeader: rh, - Address: response[pos+0 : pos+4], - }, pos + 4, nil - } else { + if rdLen != 4 { return nil, pos, ErrInvalidDNSAddressResourceData } + return Address{ + ResourceRecordHeader: rh, + Address: response[pos+0 : pos+4], + }, pos + 4, nil case 2: - var ns Name - ns, pos = parseName(response, pos) + var ns *Name + ns, pos, err = parseName(response, pos) + if err != nil { + return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err) + } return NameServer{ ResourceRecordHeader: rh, - NSDName: ns, + NSDName: *ns, }, pos, nil case 5: - var cname Name - cname, pos = parseName(response, pos) + var cname *Name + cname, pos, err = parseName(response, pos) + if err != nil { + return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err) + } return CName{ ResourceRecordHeader: rh, - CName: cname, + CName: *cname, }, pos, nil } return Unknown{ ResourceRecordHeader: rh, - Data: response[pos+0 : pos+rdlength], - }, pos + rdlength, nil + Data: response[pos+0 : pos+rdLen], + }, pos + rdLen, nil } func ParseResponse(response []byte) (*Message, error) { var err error + msg := new(Message) + responseLen := len(response) if responseLen < 12 { - return nil, ErrInvalidDNSMessageHeader + return msg, io.EOF } - msg := new(Message) - msg.ID = binary.BigEndian.Uint16(response[0:2]) flagsRAW := binary.BigEndian.Uint16(response[2:4]) @@ -140,10 +151,16 @@ func ParseResponse(response []byte) (*Message, error) { msg.QD = make([]Question, qdCount) for i := 0; i < qdCount; i++ { - var name Name - name, pos = parseName(response, pos) + var name *Name + name, pos, err = parseName(response, pos) + if err != nil { + return msg, fmt.Errorf("error while parsing DNS name: %w", err) + } + if responseLen < pos+4 { + return msg, io.EOF + } msg.QD[i] = Question{ - QName: name, + QName: *name, QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]), QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]), } @@ -154,7 +171,7 @@ func ParseResponse(response []byte) (*Message, error) { for i := 0; i < anCount; i++ { msg.AN[i], pos, err = parseResourceRecord(response, pos) if err != nil { - return nil, fmt.Errorf("error while parsing AN record: %w", err) + return msg, fmt.Errorf("error while parsing AN record: %w", err) } } @@ -162,7 +179,7 @@ func ParseResponse(response []byte) (*Message, error) { for i := 0; i < nsCount; i++ { msg.NS[i], pos, err = parseResourceRecord(response, pos) if err != nil { - return nil, fmt.Errorf("error while parsing NS record: %w", err) + return msg, fmt.Errorf("error while parsing NS record: %w", err) } } @@ -170,7 +187,7 @@ func ParseResponse(response []byte) (*Message, error) { for i := 0; i < arCount; i++ { msg.AR[i], pos, err = parseResourceRecord(response, pos) if err != nil { - return nil, fmt.Errorf("error while parsing AR record: %w", err) + return msg, fmt.Errorf("error while parsing AR record: %w", err) } }