From cf078c330c80afbf73ee300a6c19624432a5867a Mon Sep 17 00:00:00 2001 From: Vladimir Avtsenov Date: Sat, 8 Feb 2025 06:23:36 +0300 Subject: [PATCH] refactoring, dns over tcp support, moving to uuid --- README.md | 2 +- dns-mitm/mitm.go | 239 +++++++++++++++++++++++++++++++++ dns-proxy/dns-proxy.go | 119 ---------------- dns-proxy/parser.go | 195 --------------------------- dns-proxy/types.go | 196 --------------------------- dns-proxy/types_test.go | 225 ------------------------------- go.mod | 8 +- group.go | 20 +-- kvas2.go | 239 ++++++++++++++++++++++----------- main.go | 11 +- models/domain.go | 33 ----- models/domain_test.go | 42 ------ models/group.go | 6 +- models/rule.go | 33 +++++ models/rule_test.go | 42 ++++++ netfilter-helper/port-remap.go | 5 + 16 files changed, 509 insertions(+), 906 deletions(-) create mode 100644 dns-mitm/mitm.go delete mode 100644 dns-proxy/dns-proxy.go delete mode 100644 dns-proxy/parser.go delete mode 100644 dns-proxy/types.go delete mode 100644 dns-proxy/types_test.go delete mode 100644 models/domain.go delete mode 100644 models/domain_test.go create mode 100644 models/rule.go create mode 100644 models/rule_test.go diff --git a/README.md b/README.md index 1e357e6..3b583ea 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Better implementation of [KVAS](https://github.com/qzeleza/kvas) Realized features: - [x] DNS Proxy (UDP) -- [ ] DNS Proxy (TCP) +- [x] DNS Proxy (TCP) - [x] Records memory - [x] IPTables rules for rebind DNS server port - [X] IPSet integration diff --git a/dns-mitm/mitm.go b/dns-mitm/mitm.go new file mode 100644 index 0000000..d785990 --- /dev/null +++ b/dns-mitm/mitm.go @@ -0,0 +1,239 @@ +package dnsMitm + +import ( + "context" + "encoding/binary" + "fmt" + "net" + "strconv" + "time" + + "github.com/miekg/dns" + "github.com/rs/zerolog/log" +) + +type DNSMITM struct { + ListenPort uint16 + TargetDNSServerAddress string + TargetDNSServerPort uint16 + + RequestHook func(net.Addr, dns.Msg, string) (*dns.Msg, *dns.Msg, error) + ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error) +} + +func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) { + serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort)) + if err != nil { + return nil, fmt.Errorf("failed to dial DNS server: %w", err) + } + defer func() { _ = serverConn.Close() }() + + err = serverConn.SetDeadline(time.Now().Add(time.Second * 5)) + if err != nil { + return nil, fmt.Errorf("failed to set deadline: %w", err) + } + + if network == "tcp" { + err = binary.Write(serverConn, binary.BigEndian, uint16(len(req))) + if err != nil { + return nil, fmt.Errorf("failed to write length: %w", err) + } + } + + n, err := serverConn.Write(req) + if err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + var resp []byte + if network == "tcp" { + var respLen uint16 + err = binary.Read(serverConn, binary.BigEndian, &respLen) + if err != nil { + return nil, fmt.Errorf("failed to read length: %w", err) + } + resp = make([]byte, respLen) + } else { + resp = make([]byte, 512) + } + + n, err = serverConn.Read(resp) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + return resp[:n], nil +} + +func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) { + var reqMsg dns.Msg + if p.RequestHook != nil || p.ResponseHook != nil { + err := reqMsg.Unpack(req) + if err != nil { + return nil, fmt.Errorf("failed to parse request: %w", err) + } + } + + if p.RequestHook != nil { + modifiedReq, modifiedResp, err := p.RequestHook(clientAddr, reqMsg, network) + if err != nil { + return nil, fmt.Errorf("request hook error: %w", err) + } + if modifiedResp != nil { + resp, err := modifiedResp.Pack() + if err != nil { + return nil, fmt.Errorf("failed to send modified response: %w", err) + } + return resp, nil + } + if modifiedReq != nil { + reqMsg = *modifiedReq + req, err = reqMsg.Pack() + if err != nil { + return nil, fmt.Errorf("failed to pack modified request: %w", err) + } + } + } + + resp, err := p.requestDNS(req, network) + if err != nil { + return nil, fmt.Errorf("failed to send request") + } + + if p.ResponseHook != nil { + var respMsg dns.Msg + err = respMsg.Unpack(resp) + if err != nil { + return nil, fmt.Errorf("failed to parse response") + } + + modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network) + if err != nil { + return nil, fmt.Errorf("response hook error: %w", err) + } + if modifiedResp != nil { + resp, err = modifiedResp.Pack() + if err != nil { + return nil, fmt.Errorf("failed to send modified response: %w", err) + } + return resp, nil + } + } + + return resp, nil +} + +func (p DNSMITM) ListenTCP(ctx context.Context) error { + addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(p.ListenPort))) + if err != nil { + return fmt.Errorf("failed to resolve tcp address: %v", err) + } + + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen tcp port: %v", err) + } + defer func() { _ = listener.Close() }() + + for { + // Exit if context is done + if ctx.Err() != nil { + return nil + } + + conn, err := listener.Accept() + if err != nil { + log.Error().Err(err).Msg("tcp connection error") + continue + } + + go func(clientConn net.Conn) { + defer func() { _ = clientConn.Close() }() + + var respLen uint16 + err = binary.Read(clientConn, binary.BigEndian, &respLen) + if err != nil { + log.Error().Err(err).Msg("failed to read length") + return + } + + req := make([]byte, int(respLen)) + _, err = clientConn.Read(req) + if err != nil { + log.Error().Err(err).Msg("failed to read tcp request") + return + } + + resp, err := p.processReq(clientConn.RemoteAddr(), req, "tcp") + if err != nil { + log.Error().Err(err).Msg("failed to process request") + return + } + + err = binary.Write(clientConn, binary.BigEndian, uint16(len(resp))) + if err != nil { + log.Error().Err(err).Msg("failed to send length") + return + } + _, err = clientConn.Write(resp) + if err != nil { + log.Error().Err(err).Msg("failed to send response") + return + } + }(conn) + } +} + +func (p DNSMITM) ListenUDP(ctx context.Context) error { + addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(p.ListenPort))) + if err != nil { + return fmt.Errorf("failed to resolve udp address: %v", err) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return fmt.Errorf("failed to listen udp port: %v", err) + } + defer func() { _ = conn.Close() }() + + for { + // Exit if context is done + if ctx.Err() != nil { + return nil + } + + req := make([]byte, 512) + n, clientAddr, err := conn.ReadFromUDP(req) + if err != nil { + log.Error().Err(err).Msg("failed to read udp request") + continue + } + req = req[:n] + + go func(clientConn *net.UDPConn, clientAddr *net.UDPAddr) { + resp, err := p.processReq(clientAddr, req, "udp") + if err != nil { + log.Error().Err(err).Msg("failed to process request") + return + } + + _, err = clientConn.WriteToUDP(resp, clientAddr) + if err != nil { + log.Error().Err(err).Msg("failed to send response") + return + } + }(conn, clientAddr) + } +} + +func New(listenPort uint16, targetDNSServerAddress string, targetDNSServerPort ...uint16) *DNSMITM { + dnsMitm := &DNSMITM{ + ListenPort: listenPort, + TargetDNSServerAddress: targetDNSServerAddress, + TargetDNSServerPort: 53, + } + if len(targetDNSServerPort) > 0 { + dnsMitm.TargetDNSServerPort = targetDNSServerPort[0] + } + return dnsMitm +} diff --git a/dns-proxy/dns-proxy.go b/dns-proxy/dns-proxy.go deleted file mode 100644 index ce3e189..0000000 --- a/dns-proxy/dns-proxy.go +++ /dev/null @@ -1,119 +0,0 @@ -package dnsProxy - -import ( - "context" - "errors" - "fmt" - "net" - "os" - "time" - - "github.com/rs/zerolog/log" -) - -const ( - DNSMaxUDPPackageSize = 4096 -) - -type DNSProxy struct { - udpConn *net.UDPConn - listenPort uint16 - - targetDNSServerAddress string - - MsgHandler func(*Message) -} - -func (p DNSProxy) Listen(ctx context.Context) error { - var err error - - udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", p.listenPort)) - if err != nil { - return fmt.Errorf("failed to resolve UDP address: %v", err) - } - - p.udpConn, err = net.ListenUDP("udp", udpAddr) - if err != nil { - return fmt.Errorf("failed to listen UDP address: %v", err) - } - - defer func() { - if p.udpConn != nil { - err := p.udpConn.Close() - if err != nil { - log.Error().Err(err).Msg("failed to close UDP connection") - } - } - }() - - for { - select { - case <-ctx.Done(): - return nil - default: - buffer := make([]byte, DNSMaxUDPPackageSize) - n, clientAddr, err := p.udpConn.ReadFromUDP(buffer) - if err != nil { - log.Error().Err(err).Msg("failed to read UDP packet") - continue - } - - go p.handleDNSRequest(clientAddr, buffer[:n]) - } - } -} - -func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) { - conn, err := net.Dial("udp", p.targetDNSServerAddress) - if err != nil { - log.Error().Err(err).Msg("failed to dial target DNS") - return - } - defer conn.Close() - - _, err = conn.Write(buffer) - if err != nil { - log.Error().Err(err).Msg("failed to send request to target DNS") - return - } - - err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) - if err != nil { - log.Error().Err(err).Msg("failed to set read deadline") - return - } - - response := make([]byte, DNSMaxUDPPackageSize) - n, err := conn.Read(response) - if err != nil { - if errors.Is(err, os.ErrDeadlineExceeded) { - // Just skip it - return - } - - log.Error().Err(err).Msg("failed to read response from target DNS") - return - } - - msg, err := ParseResponse(response[:n]) - if err == nil { - if p.MsgHandler != nil { - p.MsgHandler(msg) - } - } else { - log.Warn().Err(err).Msg("error while parsing DNS message") - } - - _, err = p.udpConn.WriteToUDP(response[:n], clientAddr) - if err != nil { - log.Error().Err(err).Msg("failed to send DNS message") - return - } -} - -func New(listenPort uint16, targetDNSServerAddress string) *DNSProxy { - return &DNSProxy{ - listenPort: listenPort, - targetDNSServerAddress: targetDNSServerAddress, - } -} diff --git a/dns-proxy/parser.go b/dns-proxy/parser.go deleted file mode 100644 index c7b97a9..0000000 --- a/dns-proxy/parser.go +++ /dev/null @@ -1,195 +0,0 @@ -package dnsProxy - -import ( - "encoding/binary" - "errors" - "fmt" - "io" -) - -var ( - ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data") -) - -func parseName(response []byte, pos int) (*Name, int, error) { - var nameParts []string - var jumped bool - var outPos int - responseLen := len(response) - - for { - if responseLen < pos+1 { - return nil, pos, io.EOF - } - length := int(response[pos]) - pos++ - if length == 0 { - break - } - - if length&0xC0 != 0 { - if responseLen < pos+1 { - return nil, pos, io.EOF - } - if !jumped { - outPos = pos + 1 - } - pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF) - jumped = true - continue - } - - if responseLen < pos+length { - return nil, pos, io.EOF - } - - nameParts = append(nameParts, string(response[pos:pos+length])) - pos += length - } - - if !jumped { - outPos = pos - } - return &Name{Parts: nameParts}, outPos, nil -} - -func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) { - responseLen := len(response) - - rhName, pos, err := parseName(response, pos) - if err != nil { - return nil, pos, fmt.Errorf("error while parsing DNS name: %w", err) - } - - if responseLen < pos+10 { - return nil, pos, io.EOF - } - - rh := ResourceRecordHeader{ - Name: *rhName, - Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]), - Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]), - TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]), - } - rdLen := int(binary.BigEndian.Uint16(response[pos+8 : pos+10])) - - pos += 10 - - if responseLen < pos+rdLen { - return nil, pos, io.EOF - } - - switch rh.Type { - case 1: - if rdLen != 4 { - return nil, pos, ErrInvalidDNSAddressResourceData - } - return Address{ - ResourceRecordHeader: rh, - Address: response[pos+0 : pos+4], - }, pos + 4, nil - case 2: - var ns *Name - ns, pos, err = parseName(response, pos) - if err != nil { - return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err) - } - return NameServer{ - ResourceRecordHeader: rh, - NSDName: *ns, - }, pos, nil - case 5: - var cname *Name - cname, pos, err = parseName(response, pos) - if err != nil { - return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err) - } - return CName{ - ResourceRecordHeader: rh, - CName: *cname, - }, pos, nil - } - - return Unknown{ - ResourceRecordHeader: rh, - Data: response[pos+0 : pos+rdLen], - }, pos + rdLen, nil -} - -func ParseResponse(response []byte) (*Message, error) { - var err error - - msg := new(Message) - - responseLen := len(response) - if responseLen < 12 { - return msg, io.EOF - } - - msg.ID = binary.BigEndian.Uint16(response[0:2]) - - flagsRAW := binary.BigEndian.Uint16(response[2:4]) - msg.Flags = Flags{ - QR: uint8(flagsRAW >> 15 & 0x1), - Opcode: uint8(flagsRAW >> 11 & 0xF), - AA: uint8(flagsRAW >> 10 & 0x1), - TC: uint8(flagsRAW >> 9 & 0x1), - RD: uint8(flagsRAW >> 8 & 0x1), - RA: uint8(flagsRAW >> 7 & 0x1), - Z1: uint8(flagsRAW >> 6 & 0x1), - Z2: uint8(flagsRAW >> 5 & 0x1), - Z3: uint8(flagsRAW >> 4 & 0x1), - RCode: uint8(flagsRAW >> 0 & 0xF), - } - - qdCount := int(binary.BigEndian.Uint16(response[4:6])) - anCount := int(binary.BigEndian.Uint16(response[6:8])) - nsCount := int(binary.BigEndian.Uint16(response[8:10])) - arCount := int(binary.BigEndian.Uint16(response[10:12])) - - pos := 12 - - msg.QD = make([]Question, qdCount) - for i := 0; i < qdCount; i++ { - var name *Name - name, pos, err = parseName(response, pos) - if err != nil { - return msg, fmt.Errorf("error while parsing DNS name: %w", err) - } - if responseLen < pos+4 { - return msg, io.EOF - } - msg.QD[i] = Question{ - QName: *name, - QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]), - QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]), - } - pos += 4 - } - - msg.AN = make([]ResourceRecord, anCount) - for i := 0; i < anCount; i++ { - msg.AN[i], pos, err = parseResourceRecord(response, pos) - if err != nil { - return msg, fmt.Errorf("error while parsing AN record: %w", err) - } - } - - msg.NS = make([]ResourceRecord, nsCount) - for i := 0; i < nsCount; i++ { - msg.NS[i], pos, err = parseResourceRecord(response, pos) - if err != nil { - return msg, fmt.Errorf("error while parsing NS record: %w", err) - } - } - - msg.AR = make([]ResourceRecord, arCount) - for i := 0; i < arCount; i++ { - msg.AR[i], pos, err = parseResourceRecord(response, pos) - if err != nil { - return msg, fmt.Errorf("error while parsing AR record: %w", err) - } - } - - return msg, nil -} diff --git a/dns-proxy/types.go b/dns-proxy/types.go deleted file mode 100644 index 294b5f4..0000000 --- a/dns-proxy/types.go +++ /dev/null @@ -1,196 +0,0 @@ -package dnsProxy - -import ( - "bytes" - "encoding/binary" - "net" - "strings" -) - -type ResourceRecord interface { - EncodeResource() []byte -} - -type ResourceRecordHeader struct { - Name Name - Type uint16 - Class uint16 - TTL uint32 -} - -func (q ResourceRecordHeader) EncodeHeader() []byte { - buf := bytes.NewBuffer([]byte{}) - buf.Write(q.Name.Encode()) - buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type)) - buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class)) - buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL)) - return buf.Bytes() -} - -type Name struct { - Parts []string -} - -func (n Name) String() string { - return strings.Join(n.Parts, ".") -} - -func (n Name) Encode() []byte { - buf := bytes.NewBuffer([]byte{}) - for _, part := range n.Parts { - partLen := byte(len(part)) & 0x3F - buf.WriteByte(partLen) - buf.Write([]byte(part)[0:partLen]) - } - buf.WriteByte(0) - return buf.Bytes() -} - -type Flags struct { - QR uint8 - Opcode uint8 - AA uint8 - TC uint8 - RD uint8 - RA uint8 - Z1 uint8 - Z2 uint8 - Z3 uint8 - RCode uint8 -} - -func (f Flags) Encode() []byte { - return []byte{ - f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0, - f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0, - } -} - -type Question struct { - QName Name - QType uint16 - QClass uint16 -} - -func (q Question) EncodeQuestion() []byte { - buf := bytes.NewBuffer([]byte{}) - buf.Write(q.QName.Encode()) - buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType)) - buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass)) - return buf.Bytes() -} - -type Address struct { - ResourceRecordHeader - Address net.IP -} - -func (a Address) EncodeResource() []byte { - rr := bytes.NewBuffer([]byte{}) - rr.Write(a.ResourceRecordHeader.EncodeHeader()) - rr.Write([]byte{0x00, 0x04}) - rr.Write(a.Address[:]) - return rr.Bytes() -} - -type NameServer struct { - ResourceRecordHeader - NSDName Name -} - -func (a NameServer) EncodeResource() []byte { - rdataBytes := a.NSDName.Encode() - rr := bytes.NewBuffer([]byte{}) - rr.Write(a.ResourceRecordHeader.EncodeHeader()) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes)))) - rr.Write(rdataBytes) - return rr.Bytes() -} - -type CName struct { - ResourceRecordHeader - CName Name -} - -func (a CName) EncodeResource() []byte { - rdataBytes := a.CName.Encode() - rr := bytes.NewBuffer([]byte{}) - rr.Write(a.ResourceRecordHeader.EncodeHeader()) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes)))) - rr.Write(rdataBytes) - return rr.Bytes() -} - -type Authority struct { - ResourceRecordHeader - MName Name - RName Name - Serial uint32 - Refresh uint32 - Retry uint32 - Expire uint32 - Minimum uint32 -} - -func (a Authority) EncodeResource() []byte { - rdata := bytes.NewBuffer([]byte{}) - rdata.Write(a.MName.Encode()) - rdata.Write(a.RName.Encode()) - rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial)) - rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh)) - rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry)) - rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire)) - rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum)) - rdataBytes := rdata.Bytes() - - rr := bytes.NewBuffer([]byte{}) - rr.Write(a.ResourceRecordHeader.EncodeHeader()) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes)))) - rr.Write(rdataBytes) - return rr.Bytes() -} - -type Unknown struct { - ResourceRecordHeader - Data []byte -} - -func (u Unknown) EncodeResource() []byte { - rr := bytes.NewBuffer([]byte{}) - rr.Write(u.ResourceRecordHeader.EncodeHeader()) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data)))) - rr.Write(u.Data) - return rr.Bytes() -} - -type Message struct { - ID uint16 - Flags Flags - QD []Question - AN []ResourceRecord - NS []ResourceRecord - AR []ResourceRecord -} - -func (m Message) Encode() []byte { - rr := bytes.NewBuffer([]byte{}) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID)) - rr.Write(m.Flags.Encode()) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD)))) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN)))) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS)))) - rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR)))) - for _, q := range m.QD { - rr.Write(q.EncodeQuestion()) - } - for _, a := range m.AN { - rr.Write(a.EncodeResource()) - } - for _, ns := range m.NS { - rr.Write(ns.EncodeResource()) - } - for _, a := range m.AR { - rr.Write(a.EncodeResource()) - } - return rr.Bytes() -} diff --git a/dns-proxy/types_test.go b/dns-proxy/types_test.go deleted file mode 100644 index c285575..0000000 --- a/dns-proxy/types_test.go +++ /dev/null @@ -1,225 +0,0 @@ -package dnsProxy - -import ( - "bytes" - "testing" -) - -func TestDNSResourceRecordHeaderEncode(t *testing.T) { - recordHeader := ResourceRecordHeader{ - Name: Name{Parts: []string{"example", "com"}}, - Type: 0xF0, - Class: 0xF0, - TTL: 0x77770FF0, - } - recordHeaderEncoded := recordHeader.EncodeHeader() - recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0} - if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 { - t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood) - } -} - -func TestDNSNameString(t *testing.T) { - dnsName := Name{Parts: []string{"example", "com"}} - dnsNameString := dnsName.String() - dnsNameStringGood := "example.com" - if dnsNameString != dnsNameStringGood { - t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood) - } -} - -func TestDNSNameEncode(t *testing.T) { - dnsName := Name{Parts: []string{"example", "com"}} - dnsNameEncoded := dnsName.Encode() - dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00} - if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 { - t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood) - } -} - -func TestDNSFlagsEncode(t *testing.T) { - dnsFlags := Flags{ - QR: 0x1, - Opcode: 0xF, - AA: 0x0, - TC: 0x0, - RD: 0x1, - RA: 0x1, - Z1: 0x0, - Z2: 0x0, - Z3: 0x0, - RCode: 0xF, - } - dnsFlagsEncoded := dnsFlags.Encode() - dnsFlagsEncodedGood := []byte{0xf9, 0x8f} - if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 { - t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood) - } -} - -func TestDNSQuestionEncode(t *testing.T) { - dnsQuestion := Question{ - QName: Name{Parts: []string{"example", "com"}}, - QType: 0x001c, - QClass: 0x0001, - } - dnsQuestionEncoded := dnsQuestion.EncodeQuestion() - dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01} - if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 { - t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood) - } -} - -func TestDNSAddressEncode(t *testing.T) { - dnsAddress := Address{ - ResourceRecordHeader: ResourceRecordHeader{ - Name: Name{Parts: []string{"example", "com"}}, - Type: 0xF0, - Class: 0xF0, - TTL: 0x77770FF0, - }, - Address: []byte{192, 168, 1, 1}, - } - dnsAddressEncoded := dnsAddress.EncodeResource() - dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1} - if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 { - t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood) - } -} - -func TestDNSNameServerEncode(t *testing.T) { - dnsNameServer := NameServer{ - ResourceRecordHeader: ResourceRecordHeader{ - Name: Name{Parts: []string{"example", "com"}}, - Type: 0xF0, - Class: 0xF0, - TTL: 0x77770FF0, - }, - NSDName: Name{Parts: []string{"example", "com"}}, - } - dnsNameServerEncoded := dnsNameServer.EncodeResource() - dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00} - if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 { - t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood) - } -} - -func TestDNSCNameEncode(t *testing.T) { - dnsCName := CName{ - ResourceRecordHeader: ResourceRecordHeader{ - Name: Name{Parts: []string{"example", "com"}}, - Type: 0xF0, - Class: 0xF0, - TTL: 0x77770FF0, - }, - CName: Name{Parts: []string{"example", "com"}}, - } - dnsCNameEncoded := dnsCName.EncodeResource() - dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00} - if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 { - t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood) - } -} - -func TestDNSAuthorityEncode(t *testing.T) { - dnsAuthority := Authority{ - ResourceRecordHeader: ResourceRecordHeader{ - Name: Name{Parts: []string{"example", "com"}}, - Type: 0xF0, - Class: 0xF0, - TTL: 0x77770FF0, - }, - MName: Name{Parts: []string{"example", "com"}}, - RName: Name{Parts: []string{"example", "com"}}, - Serial: 0x12345678, - Refresh: 0x12345678, - Retry: 0x12345678, - Expire: 0x12345678, - Minimum: 0x12345678, - } - dnsAuthorityEncoded := dnsAuthority.EncodeResource() - dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78} - if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 { - t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood) - } -} - -func TestDNSUnknownEncode(t *testing.T) { - dnsUnknown := Unknown{ - ResourceRecordHeader: ResourceRecordHeader{ - Name: Name{Parts: []string{"example", "com"}}, - Type: 0xF0, - Class: 0xF0, - TTL: 0x77770FF0, - }, - Data: []byte{0x01, 0x02, 0x03}, - } - dnsUnknownEncoded := dnsUnknown.EncodeResource() - dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03} - if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 { - t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood) - } -} - -//func TestDNSMessageEncode(t *testing.T) { -// dnsMessage := Message{ -// ID: 0x00FF, -// Flags: Flags{ -// QR: 0x1, -// Opcode: 0xF, -// AA: 0x0, -// TC: 0x0, -// RD: 0x1, -// RA: 0x1, -// Z1: 0x0, -// Z2: 0x0, -// Z3: 0x0, -// RCode: 0xF, -// }, -// QD: []Question{ -// { -// QName: Name{Parts: []string{"example", "com"}}, -// QType: 0x001c, -// QClass: 0x0001, -// }, -// }, -// AN: []ResourceRecord{ -// Unknown{ -// ResourceRecordHeader: ResourceRecordHeader{ -// Name: Name{Parts: []string{"example", "com"}}, -// Type: 0xF0, -// Class: 0xF0, -// TTL: 0x77770FF0, -// }, -// Data: []byte{0x01, 0x02, 0x03}, -// }, -// }, -// NS: []ResourceRecord{ -// Unknown{ -// ResourceRecordHeader: ResourceRecordHeader{ -// Name: Name{Parts: []string{"example", "com"}}, -// Type: 0xF0, -// Class: 0xF0, -// TTL: 0x77770FF0, -// }, -// Data: []byte{0x01, 0x02, 0x03}, -// }, -// }, -// AR: []ResourceRecord{ -// Unknown{ -// ResourceRecordHeader: ResourceRecordHeader{ -// Name: Name{Parts: []string{"example", "com"}}, -// Type: 0xF0, -// Class: 0xF0, -// TTL: 0x77770FF0, -// }, -// Data: []byte{0x01, 0x02, 0x03}, -// }, -// }, -// } -// dnsMessageEncoded := dnsMessage.Encode() -// dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03} -// if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 { -// t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood) -// } -//} diff --git a/go.mod b/go.mod index 391fc4d..38bbac0 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.21 require ( github.com/IGLOU-EU/go-wildcard/v2 v2.0.2 github.com/coreos/go-iptables v0.7.0 + github.com/google/uuid v1.6.0 + github.com/miekg/dns v1.1.63 github.com/rs/zerolog v1.33.0 github.com/vishvananda/netlink v1.3.0 ) @@ -13,5 +15,9 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/vishvananda/netns v0.0.4 // indirect - golang.org/x/sys v0.24.0 // indirect + golang.org/x/mod v0.18.0 // indirect + golang.org/x/net v0.31.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/tools v0.22.0 // indirect ) diff --git a/group.go b/group.go index ff7fd80..7e77ee5 100644 --- a/group.go +++ b/group.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "net" "time" @@ -15,9 +16,9 @@ type Group struct { Enabled bool - iptables *iptables.IPTables - ipset *netfilterHelper.IPSet - ifaceToIPSet *netfilterHelper.IfaceToIPSet + iptables *iptables.IPTables + ipset *netfilterHelper.IPSet + ifaceToIPSetNAT *netfilterHelper.IfaceToIPSet } func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error { @@ -44,10 +45,13 @@ func (g *Group) Enable() error { }() if g.FixProtect { - g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT") + err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT") + if err != nil { + return fmt.Errorf("failed to fix protect: %w", err) + } } - err := g.ifaceToIPSet.Enable() + err := g.ifaceToIPSetNAT.Enable() if err != nil { return err } @@ -64,9 +68,9 @@ func (g *Group) Disable() []error { return nil } - errs2 := g.ifaceToIPSet.Disable() - if errs2 != nil { - errs = append(errs, errs2...) + err := g.ifaceToIPSetNAT.Disable() + if err != nil { + errs = append(errs, err...) } g.Enabled = false diff --git a/kvas2.go b/kvas2.go index 316a029..12faab6 100644 --- a/kvas2.go +++ b/kvas2.go @@ -9,13 +9,16 @@ import ( "strings" "time" - "kvas2-go/dns-proxy" + "kvas2-go/dns-mitm" "kvas2-go/models" "kvas2-go/netfilter-helper" + "github.com/google/uuid" + "github.com/miekg/dns" "github.com/rs/zerolog/log" "github.com/vishvananda/netlink" "github.com/vishvananda/netlink/nl" + "golang.org/x/sys/unix" ) var ( @@ -36,11 +39,11 @@ type Config struct { type App struct { Config Config - DNSProxy *dnsProxy.DNSProxy + DNSMITM *dnsMitm.DNSMITM NetfilterHelper4 *netfilterHelper.NetfilterHelper NetfilterHelper6 *netfilterHelper.NetfilterHelper Records *Records - Groups map[int]*Group + Groups map[uuid.UUID]*Group Link netlink.Link @@ -51,19 +54,23 @@ type App struct { func (a *App) handleLink(event netlink.LinkUpdate) { switch event.Change { - case 0x00000001: + case unix.IFF_UP: log.Debug(). Str("interface", event.Link.Attrs().Name). Str("operstatestr", event.Attrs().OperState.String()). Int("operstate", int(event.Attrs().OperState)). Msg("interface change") - if event.Attrs().OperState != netlink.OperDown { + switch event.Attrs().OperState { + case netlink.OperUp: + ifaceName := event.Link.Attrs().Name for _, group := range a.Groups { - if group.Interface == event.Link.Attrs().Name { - err := group.ifaceToIPSet.IfaceHandle() - if err != nil { - log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up") - } + if group.Interface != ifaceName { + continue + } + + err := group.ifaceToIPSetNAT.IfaceHandle() + if err != nil { + log.Error().Str("group", group.ID.String()).Err(err).Msg("error while handling interface up") } } } @@ -83,22 +90,34 @@ func (a *App) handleLink(event netlink.LinkUpdate) { } } -func (a *App) listen(ctx context.Context) (err error) { - errChan := make(chan error) - +func (a *App) start(ctx context.Context) (err error) { newCtx, cancel := context.WithCancel(ctx) defer cancel() + // TODO: Chan err + errChan := make(chan error) + + /* + DNS Proxy + */ + go func() { - err := a.DNSProxy.Listen(newCtx) + err := a.DNSMITM.ListenUDP(newCtx) if err != nil { - errChan <- fmt.Errorf("failed to serve DNS proxy: %v", err) + errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err) + } + }() + + go func() { + err := a.DNSMITM.ListenTCP(newCtx) + if err != nil { + errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err) } }() addrList, err := netlink.AddrList(a.Link, nl.FAMILY_ALL) if err != nil { - return fmt.Errorf("failed to addrList address: %w", err) + return fmt.Errorf("failed to list address of interface: %w", err) } a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList) @@ -121,6 +140,10 @@ func (a *App) listen(ctx context.Context) (err error) { _ = a.dnsOverrider6.Disable() }() + /* + Groups + */ + for _, group := range a.Groups { err = group.Enable() if err != nil { @@ -134,6 +157,9 @@ func (a *App) listen(ctx context.Context) (err error) { } }() + /* + Socket (for netfilter.d events) + */ socketPath := "/opt/var/run/kvas2-go.sock" err = os.Remove(socketPath) if err != nil && !errors.Is(err, os.ErrNotExist) { @@ -181,8 +207,8 @@ func (a *App) listen(ctx context.Context) (err error) { } } for _, group := range a.Groups { - if group.ifaceToIPSet.Enabled { - err := group.ifaceToIPSet.PutIPTable(args[2]) + if group.ifaceToIPSetNAT.Enabled { + err := group.ifaceToIPSetNAT.PutIPTable(args[2]) if err != nil { log.Error().Err(err).Msg("error while fixing iptables after netfilter.d") } @@ -193,21 +219,28 @@ func (a *App) listen(ctx context.Context) (err error) { } }() - link := make(chan netlink.LinkUpdate) - done := make(chan struct{}) - err = netlink.LinkSubscribe(link, done) + /* + Interface updates + */ + linkUpdateChannel := make(chan netlink.LinkUpdate) + linkUpdateDone := make(chan struct{}) + err = netlink.LinkSubscribe(linkUpdateChannel, linkUpdateDone) if err != nil { return fmt.Errorf("failed to subscribe to link updates: %w", err) } defer func() { - close(done) + close(linkUpdateDone) }() + /* + Global loop + */ for { select { - case event := <-link: + case event := <-linkUpdateChannel: a.handleLink(event) case err := <-errChan: + close(errChan) return err case <-ctx.Done(): return nil @@ -215,7 +248,7 @@ func (a *App) listen(ctx context.Context) (err error) { } } -func (a *App) Listen(ctx context.Context) (err error) { +func (a *App) Start(ctx context.Context) (err error) { if a.isRunning { return ErrAlreadyRunning } @@ -226,20 +259,16 @@ func (a *App) Listen(ctx context.Context) (err error) { defer func() { if r := recover(); r != nil { - var recoveredError error var ok bool - if recoveredError, ok = r.(error); !ok { - recoveredError = fmt.Errorf("%v", r) + if err, ok = r.(error); !ok { + err = fmt.Errorf("%v", r) } - err = fmt.Errorf("recovered error: %w", recoveredError) + err = fmt.Errorf("recovered error: %w", err) } }() - appErr := a.listen(ctx) - if appErr != nil { - return appErr - } + err = a.start(ctx) return err } @@ -249,19 +278,19 @@ func (a *App) AddGroup(group *models.Group) error { return ErrGroupIDConflict } - ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPrefix, group.ID) + ipsetName := fmt.Sprintf("%s%8x", a.Config.IpSetPrefix, group.ID.ID()) ipset, err := a.NetfilterHelper4.IPSet(ipsetName) if err != nil { return fmt.Errorf("failed to initialize ipset: %w", err) } grp := &Group{ - Group: group, - iptables: a.NetfilterHelper4.IPTables, - ipset: ipset, - ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false), + Group: group, + iptables: a.NetfilterHelper4.IPTables, + ipset: ipset, + ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false), } - a.Groups[group.ID] = grp + a.Groups[grp.ID] = grp return a.SyncGroup(grp) } @@ -276,7 +305,7 @@ func (a *App) SyncGroup(group *Group) error { } knownDomains := a.Records.ListKnownDomains() - for _, domain := range group.Domains { + for _, domain := range group.Rules { if !domain.IsEnabled() { continue } @@ -359,24 +388,24 @@ func (a *App) ListInterfaces() ([]net.Interface, error) { return interfaceNames, nil } -func (a *App) processARecord(aRecord dnsProxy.Address) { +func (a *App) processARecord(aRecord dns.A) { log.Trace(). - Str("name", aRecord.Name.String()). - Str("address", aRecord.Address.String()). - Int("ttl", int(aRecord.TTL)). + Str("name", aRecord.Hdr.Name). + Str("address", aRecord.A.String()). + Int("ttl", int(aRecord.Hdr.Ttl)). Msg("processing a record") - ttlDuration := time.Duration(aRecord.TTL) * time.Second + ttlDuration := time.Duration(aRecord.Hdr.Ttl) * time.Second if ttlDuration < a.Config.MinimalTTL { ttlDuration = a.Config.MinimalTTL } - a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration) + a.Records.AddARecord(aRecord.Hdr.Name, aRecord.A, ttlDuration) - names := a.Records.GetCNameRecords(aRecord.Name.String(), true) + names := a.Records.GetCNameRecords(aRecord.Hdr.Name, true) for _, group := range a.Groups { - Domain: - for _, domain := range group.Domains { + Rule: + for _, domain := range group.Rules { if !domain.IsEnabled() { continue } @@ -384,47 +413,47 @@ func (a *App) processARecord(aRecord dnsProxy.Address) { if !domain.IsMatch(name) { continue } - err := group.AddIPv4(aRecord.Address, ttlDuration) + err := group.AddIPv4(aRecord.A, ttlDuration) if err != nil { log.Error(). - Str("address", aRecord.Address.String()). + Str("address", aRecord.A.String()). Err(err). Msg("failed to add address") } else { log.Trace(). - Str("address", aRecord.Address.String()). - Str("aRecordDomain", aRecord.Name.String()). + Str("address", aRecord.A.String()). + Str("aRecordDomain", aRecord.Hdr.Name). Str("cNameDomain", name). Err(err). Msg("add address") } - break Domain + break Rule } } } } -func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { +func (a *App) processCNameRecord(cNameRecord dns.CNAME) { log.Trace(). - Str("name", cNameRecord.Name.String()). - Str("cname", cNameRecord.CName.String()). - Int("ttl", int(cNameRecord.TTL)). + Str("name", cNameRecord.Hdr.Name). + Str("cname", cNameRecord.Target). + Int("ttl", int(cNameRecord.Hdr.Ttl)). Msg("processing cname record") - ttlDuration := time.Duration(cNameRecord.TTL) * time.Second + ttlDuration := time.Duration(cNameRecord.Hdr.Ttl) * time.Second if ttlDuration < a.Config.MinimalTTL { ttlDuration = a.Config.MinimalTTL } - a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration) + a.Records.AddCNameRecord(cNameRecord.Hdr.Name, cNameRecord.Target, ttlDuration) // TODO: Optimization now := time.Now() - aRecords := a.Records.GetARecords(cNameRecord.Name.String()) - names := a.Records.GetCNameRecords(cNameRecord.Name.String(), true) + aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name) + names := a.Records.GetCNameRecords(cNameRecord.Hdr.Name, true) for _, group := range a.Groups { - Domain: - for _, domain := range group.Domains { + Rule: + for _, domain := range group.Rules { if !domain.IsEnabled() { continue } @@ -447,31 +476,30 @@ func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) { Msg("add address") } } - continue Domain + continue Rule } } } } -func (a *App) handleRecord(rr dnsProxy.ResourceRecord) { +func (a *App) handleRecord(rr dns.RR) { switch v := rr.(type) { - case dnsProxy.Address: - // TODO: Optimize equals domain A records - a.processARecord(v) - case dnsProxy.CName: - a.processCNameRecord(v) + case *dns.A: + a.processARecord(*v) + case *dns.CNAME: + a.processCNameRecord(*v) default: } } -func (a *App) handleMessage(msg *dnsProxy.Message) { - for _, rr := range msg.AN { +func (a *App) handleMessage(msg dns.Msg) { + for _, rr := range msg.Answer { a.handleRecord(rr) } - for _, rr := range msg.NS { + for _, rr := range msg.Ns { a.handleRecord(rr) } - for _, rr := range msg.AR { + for _, rr := range msg.Extra { a.handleRecord(rr) } } @@ -483,17 +511,68 @@ func New(config Config) (*App, error) { app.Config = config + app.DNSMITM = dnsMitm.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress) + app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) { + log.Debug(). + Str("network", network). + Str("clientAddr", clientAddr.String()). + Str("name", reqMsg.Question[0].Name). + Msg("received DNS request") + + // TODO: Need to understand why it not works in proxy mode + if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR { + respMsg := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: reqMsg.Id, + Response: true, + RecursionAvailable: true, + Rcode: dns.RcodeNameError, + }, + Question: reqMsg.Question, + } + log.Debug(). + Str("network", network). + Str("clientAddr", clientAddr.String()). + Msg("sending DNS response") + return nil, respMsg, nil + } + + return nil, nil, nil + } + app.DNSMITM.ResponseHook = func(clientAddr net.Addr, reqMsg dns.Msg, respMsg dns.Msg, network string) (*dns.Msg, error) { + // TODO: Make it optional + var idx int + for _, a := range respMsg.Answer { + if a.Header().Rrtype == dns.TypeAAAA { + continue + } + respMsg.Answer[idx] = a + idx++ + } + respMsg.Answer = respMsg.Answer[:idx] + + if len(respMsg.Answer) != 0 { + log.Debug(). + Str("network", network). + Str("clientAddr", clientAddr.String()). + Str("respMsg", respMsg.Answer[0].Header().Name). + Msg("sending DNS response") + } + + app.handleMessage(respMsg) + + return &respMsg, nil + } + + app.Records = NewRecords() + app.Groups = make(map[uuid.UUID]*Group, 0) + link, err := netlink.LinkByName(app.Config.LinkName) if err != nil { return nil, fmt.Errorf("failed to find link %s: %w", app.Config.LinkName, err) } app.Link = link - app.DNSProxy = dnsProxy.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress) - app.DNSProxy.MsgHandler = app.handleMessage - - app.Records = NewRecords() - nh4, err := netfilterHelper.New(false) if err != nil { return nil, fmt.Errorf("netfilter helper init fail: %w", err) @@ -514,7 +593,7 @@ func New(config Config) (*App, error) { return nil, fmt.Errorf("failed to clear iptables: %w", err) } - app.Groups = make(map[int]*Group) + app.Groups = make(map[uuid.UUID]*Group) return app, nil } diff --git a/main.go b/main.go index 3c69268..8c3ff35 100644 --- a/main.go +++ b/main.go @@ -19,7 +19,7 @@ func main() { ChainPrefix: "KVAS2_", IpSetPrefix: "kvas2_", LinkName: "br0", - TargetDNSServerAddress: "127.0.0.1:53", + TargetDNSServerAddress: "127.0.0.1", ListenDNSPort: 7553, }) if err != nil { @@ -28,13 +28,16 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) + log.Info().Msg("starting service") + + /* + Starting app with graceful shutdown + */ appResult := make(chan error) go func() { - appResult <- app.Listen(ctx) + appResult <- app.Start(ctx) }() - log.Info().Msg("starting service") - c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) diff --git a/models/domain.go b/models/domain.go deleted file mode 100644 index c1338dc..0000000 --- a/models/domain.go +++ /dev/null @@ -1,33 +0,0 @@ -package models - -import ( - "regexp" - - "github.com/IGLOU-EU/go-wildcard/v2" -) - -type Domain struct { - ID int - Group *Group - Type string - Domain string - Enable bool - Comment string -} - -func (d *Domain) IsEnabled() bool { - return d.Enable -} - -func (d *Domain) IsMatch(domainName string) bool { - switch d.Type { - case "wildcard": - return wildcard.Match(d.Domain, domainName) - case "regex": - ok, _ := regexp.MatchString(d.Domain, domainName) - return ok - case "plaintext": - return domainName == d.Domain - } - return false -} diff --git a/models/domain_test.go b/models/domain_test.go deleted file mode 100644 index 0b682dc..0000000 --- a/models/domain_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package models - -import "testing" - -func TestDomain_IsMatch_Plaintext(t *testing.T) { - domain := &Domain{ - Type: "plaintext", - Domain: "example.com", - } - if !domain.IsMatch("example.com") { - t.Fatal("&Domain{Type: \"plaintext\", Domain: \"example.com\"}.IsMatch(\"example.com\") returns false") - } - if domain.IsMatch("noexample.com") { - t.Fatal("&Domain{Type: \"plaintext\", Domain: \"example.com\"}.IsMatch(\"noexample.com\") returns true") - } -} - -func TestDomain_IsMatch_Wildcard(t *testing.T) { - domain := &Domain{ - Type: "wildcard", - Domain: "ex*le.com", - } - if !domain.IsMatch("example.com") { - t.Fatal("&Domain{Type: \"wildcard\", Domain: \"ex*le.com\"}.IsMatch(\"example.com\") returns false") - } - if domain.IsMatch("noexample.com") { - t.Fatal("&Domain{Type: \"wildcard\", Domain: \"ex*le.com\"}.IsMatch(\"noexample.com\") returns true") - } -} - -func TestDomain_IsMatch_RegEx(t *testing.T) { - domain := &Domain{ - Type: "regex", - Domain: "^ex[apm]{3}le.com$", - } - if !domain.IsMatch("example.com") { - t.Fatal("&Domain{Type: \"regex\", Domain: \"^ex[apm]{3}le.com$\"}.IsMatch(\"example.com\") returns false") - } - if domain.IsMatch("noexample.com") { - t.Fatal("&Domain{Type: \"regex\", Domain: \"^ex[apm]{3}le.com$\"}.IsMatch(\"noexample.com\") returns true") - } -} diff --git a/models/group.go b/models/group.go index 81dbd11..d810e52 100644 --- a/models/group.go +++ b/models/group.go @@ -1,9 +1,11 @@ package models +import "github.com/google/uuid" + type Group struct { - ID int + ID uuid.UUID Name string Interface string + Rules []*Rule FixProtect bool - Domains []*Domain } diff --git a/models/rule.go b/models/rule.go new file mode 100644 index 0000000..3c45fb8 --- /dev/null +++ b/models/rule.go @@ -0,0 +1,33 @@ +package models + +import ( + "regexp" + + "github.com/IGLOU-EU/go-wildcard/v2" + "github.com/google/uuid" +) + +type Rule struct { + ID uuid.UUID + Name string + Type string + Rule string + Enable bool +} + +func (d *Rule) IsEnabled() bool { + return d.Enable +} + +func (d *Rule) IsMatch(domainName string) bool { + switch d.Type { + case "wildcard": + return wildcard.Match(d.Rule, domainName) + case "regex": + ok, _ := regexp.MatchString(d.Rule, domainName) + return ok + case "plaintext": + return domainName == d.Rule + } + return false +} diff --git a/models/rule_test.go b/models/rule_test.go new file mode 100644 index 0000000..e83c9f2 --- /dev/null +++ b/models/rule_test.go @@ -0,0 +1,42 @@ +package models + +import "testing" + +func TestDomain_IsMatch_Plaintext(t *testing.T) { + rule := &Rule{ + Type: "plaintext", + Rule: "example.com", + } + if !rule.IsMatch("example.com") { + t.Fatal("&Rule{Type: \"plaintext\", Rule: \"example.com\"}.IsMatch(\"example.com\") returns false") + } + if rule.IsMatch("noexample.com") { + t.Fatal("&Rule{Type: \"plaintext\", Rule: \"example.com\"}.IsMatch(\"noexample.com\") returns true") + } +} + +func TestDomain_IsMatch_Wildcard(t *testing.T) { + rule := &Rule{ + Type: "wildcard", + Rule: "ex*le.com", + } + if !rule.IsMatch("example.com") { + t.Fatal("&Rule{Type: \"wildcard\", Rule: \"ex*le.com\"}.IsMatch(\"example.com\") returns false") + } + if rule.IsMatch("noexample.com") { + t.Fatal("&Rule{Type: \"wildcard\", Rule: \"ex*le.com\"}.IsMatch(\"noexample.com\") returns true") + } +} + +func TestDomain_IsMatch_RegEx(t *testing.T) { + rule := &Rule{ + Type: "regex", + Rule: "^ex[apm]{3}le.com$", + } + if !rule.IsMatch("example.com") { + t.Fatal("&Rule{Type: \"regex\", Rule: \"^ex[apm]{3}le.com$\"}.IsMatch(\"example.com\") returns false") + } + if rule.IsMatch("noexample.com") { + t.Fatal("&Rule{Type: \"regex\", Rule: \"^ex[apm]{3}le.com$\"}.IsMatch(\"noexample.com\") returns true") + } +} diff --git a/netfilter-helper/port-remap.go b/netfilter-helper/port-remap.go index 0e3725d..b0a2787 100644 --- a/netfilter-helper/port-remap.go +++ b/netfilter-helper/port-remap.go @@ -40,6 +40,11 @@ func (r *PortRemap) PutIPTable(table string) error { if err != nil { return fmt.Errorf("failed to create rule: %w", err) } + + err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To)) + if err != nil { + return fmt.Errorf("failed to create rule: %w", err) + } } err = r.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", r.ChainName)