refactoring, dns over tcp support, moving to uuid

This commit is contained in:
Vladimir Avtsenov 2025-02-08 06:23:36 +03:00
parent acce1b8bcc
commit cf078c330c
16 changed files with 509 additions and 906 deletions

View File

@ -4,7 +4,7 @@ Better implementation of [KVAS](https://github.com/qzeleza/kvas)
Realized features:
- [x] DNS Proxy (UDP)
- [ ] DNS Proxy (TCP)
- [x] DNS Proxy (TCP)
- [x] Records memory
- [x] IPTables rules for rebind DNS server port
- [X] IPSet integration

239
dns-mitm/mitm.go Normal file
View File

@ -0,0 +1,239 @@
package dnsMitm
import (
"context"
"encoding/binary"
"fmt"
"net"
"strconv"
"time"
"github.com/miekg/dns"
"github.com/rs/zerolog/log"
)
type DNSMITM struct {
ListenPort uint16
TargetDNSServerAddress string
TargetDNSServerPort uint16
RequestHook func(net.Addr, dns.Msg, string) (*dns.Msg, *dns.Msg, error)
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
}
func (p DNSMITM) requestDNS(req []byte, network string) ([]byte, error) {
serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort))
if err != nil {
return nil, fmt.Errorf("failed to dial DNS server: %w", err)
}
defer func() { _ = serverConn.Close() }()
err = serverConn.SetDeadline(time.Now().Add(time.Second * 5))
if err != nil {
return nil, fmt.Errorf("failed to set deadline: %w", err)
}
if network == "tcp" {
err = binary.Write(serverConn, binary.BigEndian, uint16(len(req)))
if err != nil {
return nil, fmt.Errorf("failed to write length: %w", err)
}
}
n, err := serverConn.Write(req)
if err != nil {
return nil, fmt.Errorf("failed to write request: %w", err)
}
var resp []byte
if network == "tcp" {
var respLen uint16
err = binary.Read(serverConn, binary.BigEndian, &respLen)
if err != nil {
return nil, fmt.Errorf("failed to read length: %w", err)
}
resp = make([]byte, respLen)
} else {
resp = make([]byte, 512)
}
n, err = serverConn.Read(resp)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
return resp[:n], nil
}
func (p DNSMITM) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
var reqMsg dns.Msg
if p.RequestHook != nil || p.ResponseHook != nil {
err := reqMsg.Unpack(req)
if err != nil {
return nil, fmt.Errorf("failed to parse request: %w", err)
}
}
if p.RequestHook != nil {
modifiedReq, modifiedResp, err := p.RequestHook(clientAddr, reqMsg, network)
if err != nil {
return nil, fmt.Errorf("request hook error: %w", err)
}
if modifiedResp != nil {
resp, err := modifiedResp.Pack()
if err != nil {
return nil, fmt.Errorf("failed to send modified response: %w", err)
}
return resp, nil
}
if modifiedReq != nil {
reqMsg = *modifiedReq
req, err = reqMsg.Pack()
if err != nil {
return nil, fmt.Errorf("failed to pack modified request: %w", err)
}
}
}
resp, err := p.requestDNS(req, network)
if err != nil {
return nil, fmt.Errorf("failed to send request")
}
if p.ResponseHook != nil {
var respMsg dns.Msg
err = respMsg.Unpack(resp)
if err != nil {
return nil, fmt.Errorf("failed to parse response")
}
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
if err != nil {
return nil, fmt.Errorf("response hook error: %w", err)
}
if modifiedResp != nil {
resp, err = modifiedResp.Pack()
if err != nil {
return nil, fmt.Errorf("failed to send modified response: %w", err)
}
return resp, nil
}
}
return resp, nil
}
func (p DNSMITM) ListenTCP(ctx context.Context) error {
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
if err != nil {
return fmt.Errorf("failed to resolve tcp address: %v", err)
}
listener, err := net.ListenTCP("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen tcp port: %v", err)
}
defer func() { _ = listener.Close() }()
for {
// Exit if context is done
if ctx.Err() != nil {
return nil
}
conn, err := listener.Accept()
if err != nil {
log.Error().Err(err).Msg("tcp connection error")
continue
}
go func(clientConn net.Conn) {
defer func() { _ = clientConn.Close() }()
var respLen uint16
err = binary.Read(clientConn, binary.BigEndian, &respLen)
if err != nil {
log.Error().Err(err).Msg("failed to read length")
return
}
req := make([]byte, int(respLen))
_, err = clientConn.Read(req)
if err != nil {
log.Error().Err(err).Msg("failed to read tcp request")
return
}
resp, err := p.processReq(clientConn.RemoteAddr(), req, "tcp")
if err != nil {
log.Error().Err(err).Msg("failed to process request")
return
}
err = binary.Write(clientConn, binary.BigEndian, uint16(len(resp)))
if err != nil {
log.Error().Err(err).Msg("failed to send length")
return
}
_, err = clientConn.Write(resp)
if err != nil {
log.Error().Err(err).Msg("failed to send response")
return
}
}(conn)
}
}
func (p DNSMITM) ListenUDP(ctx context.Context) error {
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(p.ListenPort)))
if err != nil {
return fmt.Errorf("failed to resolve udp address: %v", err)
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("failed to listen udp port: %v", err)
}
defer func() { _ = conn.Close() }()
for {
// Exit if context is done
if ctx.Err() != nil {
return nil
}
req := make([]byte, 512)
n, clientAddr, err := conn.ReadFromUDP(req)
if err != nil {
log.Error().Err(err).Msg("failed to read udp request")
continue
}
req = req[:n]
go func(clientConn *net.UDPConn, clientAddr *net.UDPAddr) {
resp, err := p.processReq(clientAddr, req, "udp")
if err != nil {
log.Error().Err(err).Msg("failed to process request")
return
}
_, err = clientConn.WriteToUDP(resp, clientAddr)
if err != nil {
log.Error().Err(err).Msg("failed to send response")
return
}
}(conn, clientAddr)
}
}
func New(listenPort uint16, targetDNSServerAddress string, targetDNSServerPort ...uint16) *DNSMITM {
dnsMitm := &DNSMITM{
ListenPort: listenPort,
TargetDNSServerAddress: targetDNSServerAddress,
TargetDNSServerPort: 53,
}
if len(targetDNSServerPort) > 0 {
dnsMitm.TargetDNSServerPort = targetDNSServerPort[0]
}
return dnsMitm
}

View File

@ -1,119 +0,0 @@
package dnsProxy
import (
"context"
"errors"
"fmt"
"net"
"os"
"time"
"github.com/rs/zerolog/log"
)
const (
DNSMaxUDPPackageSize = 4096
)
type DNSProxy struct {
udpConn *net.UDPConn
listenPort uint16
targetDNSServerAddress string
MsgHandler func(*Message)
}
func (p DNSProxy) Listen(ctx context.Context) error {
var err error
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", p.listenPort))
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %v", err)
}
p.udpConn, err = net.ListenUDP("udp", udpAddr)
if err != nil {
return fmt.Errorf("failed to listen UDP address: %v", err)
}
defer func() {
if p.udpConn != nil {
err := p.udpConn.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close UDP connection")
}
}
}()
for {
select {
case <-ctx.Done():
return nil
default:
buffer := make([]byte, DNSMaxUDPPackageSize)
n, clientAddr, err := p.udpConn.ReadFromUDP(buffer)
if err != nil {
log.Error().Err(err).Msg("failed to read UDP packet")
continue
}
go p.handleDNSRequest(clientAddr, buffer[:n])
}
}
}
func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) {
conn, err := net.Dial("udp", p.targetDNSServerAddress)
if err != nil {
log.Error().Err(err).Msg("failed to dial target DNS")
return
}
defer conn.Close()
_, err = conn.Write(buffer)
if err != nil {
log.Error().Err(err).Msg("failed to send request to target DNS")
return
}
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if err != nil {
log.Error().Err(err).Msg("failed to set read deadline")
return
}
response := make([]byte, DNSMaxUDPPackageSize)
n, err := conn.Read(response)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
// Just skip it
return
}
log.Error().Err(err).Msg("failed to read response from target DNS")
return
}
msg, err := ParseResponse(response[:n])
if err == nil {
if p.MsgHandler != nil {
p.MsgHandler(msg)
}
} else {
log.Warn().Err(err).Msg("error while parsing DNS message")
}
_, err = p.udpConn.WriteToUDP(response[:n], clientAddr)
if err != nil {
log.Error().Err(err).Msg("failed to send DNS message")
return
}
}
func New(listenPort uint16, targetDNSServerAddress string) *DNSProxy {
return &DNSProxy{
listenPort: listenPort,
targetDNSServerAddress: targetDNSServerAddress,
}
}

View File

@ -1,195 +0,0 @@
package dnsProxy
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
var (
ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data")
)
func parseName(response []byte, pos int) (*Name, int, error) {
var nameParts []string
var jumped bool
var outPos int
responseLen := len(response)
for {
if responseLen < pos+1 {
return nil, pos, io.EOF
}
length := int(response[pos])
pos++
if length == 0 {
break
}
if length&0xC0 != 0 {
if responseLen < pos+1 {
return nil, pos, io.EOF
}
if !jumped {
outPos = pos + 1
}
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
jumped = true
continue
}
if responseLen < pos+length {
return nil, pos, io.EOF
}
nameParts = append(nameParts, string(response[pos:pos+length]))
pos += length
}
if !jumped {
outPos = pos
}
return &Name{Parts: nameParts}, outPos, nil
}
func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) {
responseLen := len(response)
rhName, pos, err := parseName(response, pos)
if err != nil {
return nil, pos, fmt.Errorf("error while parsing DNS name: %w", err)
}
if responseLen < pos+10 {
return nil, pos, io.EOF
}
rh := ResourceRecordHeader{
Name: *rhName,
Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]),
}
rdLen := int(binary.BigEndian.Uint16(response[pos+8 : pos+10]))
pos += 10
if responseLen < pos+rdLen {
return nil, pos, io.EOF
}
switch rh.Type {
case 1:
if rdLen != 4 {
return nil, pos, ErrInvalidDNSAddressResourceData
}
return Address{
ResourceRecordHeader: rh,
Address: response[pos+0 : pos+4],
}, pos + 4, nil
case 2:
var ns *Name
ns, pos, err = parseName(response, pos)
if err != nil {
return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err)
}
return NameServer{
ResourceRecordHeader: rh,
NSDName: *ns,
}, pos, nil
case 5:
var cname *Name
cname, pos, err = parseName(response, pos)
if err != nil {
return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err)
}
return CName{
ResourceRecordHeader: rh,
CName: *cname,
}, pos, nil
}
return Unknown{
ResourceRecordHeader: rh,
Data: response[pos+0 : pos+rdLen],
}, pos + rdLen, nil
}
func ParseResponse(response []byte) (*Message, error) {
var err error
msg := new(Message)
responseLen := len(response)
if responseLen < 12 {
return msg, io.EOF
}
msg.ID = binary.BigEndian.Uint16(response[0:2])
flagsRAW := binary.BigEndian.Uint16(response[2:4])
msg.Flags = Flags{
QR: uint8(flagsRAW >> 15 & 0x1),
Opcode: uint8(flagsRAW >> 11 & 0xF),
AA: uint8(flagsRAW >> 10 & 0x1),
TC: uint8(flagsRAW >> 9 & 0x1),
RD: uint8(flagsRAW >> 8 & 0x1),
RA: uint8(flagsRAW >> 7 & 0x1),
Z1: uint8(flagsRAW >> 6 & 0x1),
Z2: uint8(flagsRAW >> 5 & 0x1),
Z3: uint8(flagsRAW >> 4 & 0x1),
RCode: uint8(flagsRAW >> 0 & 0xF),
}
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
anCount := int(binary.BigEndian.Uint16(response[6:8]))
nsCount := int(binary.BigEndian.Uint16(response[8:10]))
arCount := int(binary.BigEndian.Uint16(response[10:12]))
pos := 12
msg.QD = make([]Question, qdCount)
for i := 0; i < qdCount; i++ {
var name *Name
name, pos, err = parseName(response, pos)
if err != nil {
return msg, fmt.Errorf("error while parsing DNS name: %w", err)
}
if responseLen < pos+4 {
return msg, io.EOF
}
msg.QD[i] = Question{
QName: *name,
QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
}
pos += 4
}
msg.AN = make([]ResourceRecord, anCount)
for i := 0; i < anCount; i++ {
msg.AN[i], pos, err = parseResourceRecord(response, pos)
if err != nil {
return msg, fmt.Errorf("error while parsing AN record: %w", err)
}
}
msg.NS = make([]ResourceRecord, nsCount)
for i := 0; i < nsCount; i++ {
msg.NS[i], pos, err = parseResourceRecord(response, pos)
if err != nil {
return msg, fmt.Errorf("error while parsing NS record: %w", err)
}
}
msg.AR = make([]ResourceRecord, arCount)
for i := 0; i < arCount; i++ {
msg.AR[i], pos, err = parseResourceRecord(response, pos)
if err != nil {
return msg, fmt.Errorf("error while parsing AR record: %w", err)
}
}
return msg, nil
}

View File

@ -1,196 +0,0 @@
package dnsProxy
import (
"bytes"
"encoding/binary"
"net"
"strings"
)
type ResourceRecord interface {
EncodeResource() []byte
}
type ResourceRecordHeader struct {
Name Name
Type uint16
Class uint16
TTL uint32
}
func (q ResourceRecordHeader) EncodeHeader() []byte {
buf := bytes.NewBuffer([]byte{})
buf.Write(q.Name.Encode())
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type))
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class))
buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL))
return buf.Bytes()
}
type Name struct {
Parts []string
}
func (n Name) String() string {
return strings.Join(n.Parts, ".")
}
func (n Name) Encode() []byte {
buf := bytes.NewBuffer([]byte{})
for _, part := range n.Parts {
partLen := byte(len(part)) & 0x3F
buf.WriteByte(partLen)
buf.Write([]byte(part)[0:partLen])
}
buf.WriteByte(0)
return buf.Bytes()
}
type Flags struct {
QR uint8
Opcode uint8
AA uint8
TC uint8
RD uint8
RA uint8
Z1 uint8
Z2 uint8
Z3 uint8
RCode uint8
}
func (f Flags) Encode() []byte {
return []byte{
f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0,
f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0,
}
}
type Question struct {
QName Name
QType uint16
QClass uint16
}
func (q Question) EncodeQuestion() []byte {
buf := bytes.NewBuffer([]byte{})
buf.Write(q.QName.Encode())
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType))
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass))
return buf.Bytes()
}
type Address struct {
ResourceRecordHeader
Address net.IP
}
func (a Address) EncodeResource() []byte {
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write([]byte{0x00, 0x04})
rr.Write(a.Address[:])
return rr.Bytes()
}
type NameServer struct {
ResourceRecordHeader
NSDName Name
}
func (a NameServer) EncodeResource() []byte {
rdataBytes := a.NSDName.Encode()
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
rr.Write(rdataBytes)
return rr.Bytes()
}
type CName struct {
ResourceRecordHeader
CName Name
}
func (a CName) EncodeResource() []byte {
rdataBytes := a.CName.Encode()
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
rr.Write(rdataBytes)
return rr.Bytes()
}
type Authority struct {
ResourceRecordHeader
MName Name
RName Name
Serial uint32
Refresh uint32
Retry uint32
Expire uint32
Minimum uint32
}
func (a Authority) EncodeResource() []byte {
rdata := bytes.NewBuffer([]byte{})
rdata.Write(a.MName.Encode())
rdata.Write(a.RName.Encode())
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire))
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum))
rdataBytes := rdata.Bytes()
rr := bytes.NewBuffer([]byte{})
rr.Write(a.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
rr.Write(rdataBytes)
return rr.Bytes()
}
type Unknown struct {
ResourceRecordHeader
Data []byte
}
func (u Unknown) EncodeResource() []byte {
rr := bytes.NewBuffer([]byte{})
rr.Write(u.ResourceRecordHeader.EncodeHeader())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data))))
rr.Write(u.Data)
return rr.Bytes()
}
type Message struct {
ID uint16
Flags Flags
QD []Question
AN []ResourceRecord
NS []ResourceRecord
AR []ResourceRecord
}
func (m Message) Encode() []byte {
rr := bytes.NewBuffer([]byte{})
rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID))
rr.Write(m.Flags.Encode())
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD))))
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN))))
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS))))
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR))))
for _, q := range m.QD {
rr.Write(q.EncodeQuestion())
}
for _, a := range m.AN {
rr.Write(a.EncodeResource())
}
for _, ns := range m.NS {
rr.Write(ns.EncodeResource())
}
for _, a := range m.AR {
rr.Write(a.EncodeResource())
}
return rr.Bytes()
}

View File

@ -1,225 +0,0 @@
package dnsProxy
import (
"bytes"
"testing"
)
func TestDNSResourceRecordHeaderEncode(t *testing.T) {
recordHeader := ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
}
recordHeaderEncoded := recordHeader.EncodeHeader()
recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0}
if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 {
t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood)
}
}
func TestDNSNameString(t *testing.T) {
dnsName := Name{Parts: []string{"example", "com"}}
dnsNameString := dnsName.String()
dnsNameStringGood := "example.com"
if dnsNameString != dnsNameStringGood {
t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood)
}
}
func TestDNSNameEncode(t *testing.T) {
dnsName := Name{Parts: []string{"example", "com"}}
dnsNameEncoded := dnsName.Encode()
dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 {
t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood)
}
}
func TestDNSFlagsEncode(t *testing.T) {
dnsFlags := Flags{
QR: 0x1,
Opcode: 0xF,
AA: 0x0,
TC: 0x0,
RD: 0x1,
RA: 0x1,
Z1: 0x0,
Z2: 0x0,
Z3: 0x0,
RCode: 0xF,
}
dnsFlagsEncoded := dnsFlags.Encode()
dnsFlagsEncodedGood := []byte{0xf9, 0x8f}
if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 {
t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood)
}
}
func TestDNSQuestionEncode(t *testing.T) {
dnsQuestion := Question{
QName: Name{Parts: []string{"example", "com"}},
QType: 0x001c,
QClass: 0x0001,
}
dnsQuestionEncoded := dnsQuestion.EncodeQuestion()
dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01}
if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 {
t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood)
}
}
func TestDNSAddressEncode(t *testing.T) {
dnsAddress := Address{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Address: []byte{192, 168, 1, 1},
}
dnsAddressEncoded := dnsAddress.EncodeResource()
dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1}
if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 {
t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood)
}
}
func TestDNSNameServerEncode(t *testing.T) {
dnsNameServer := NameServer{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
NSDName: Name{Parts: []string{"example", "com"}},
}
dnsNameServerEncoded := dnsNameServer.EncodeResource()
dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 {
t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood)
}
}
func TestDNSCNameEncode(t *testing.T) {
dnsCName := CName{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
CName: Name{Parts: []string{"example", "com"}},
}
dnsCNameEncoded := dnsCName.EncodeResource()
dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 {
t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood)
}
}
func TestDNSAuthorityEncode(t *testing.T) {
dnsAuthority := Authority{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
MName: Name{Parts: []string{"example", "com"}},
RName: Name{Parts: []string{"example", "com"}},
Serial: 0x12345678,
Refresh: 0x12345678,
Retry: 0x12345678,
Expire: 0x12345678,
Minimum: 0x12345678,
}
dnsAuthorityEncoded := dnsAuthority.EncodeResource()
dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}
if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 {
t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood)
}
}
func TestDNSUnknownEncode(t *testing.T) {
dnsUnknown := Unknown{
ResourceRecordHeader: ResourceRecordHeader{
Name: Name{Parts: []string{"example", "com"}},
Type: 0xF0,
Class: 0xF0,
TTL: 0x77770FF0,
},
Data: []byte{0x01, 0x02, 0x03},
}
dnsUnknownEncoded := dnsUnknown.EncodeResource()
dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 {
t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood)
}
}
//func TestDNSMessageEncode(t *testing.T) {
// dnsMessage := Message{
// ID: 0x00FF,
// Flags: Flags{
// QR: 0x1,
// Opcode: 0xF,
// AA: 0x0,
// TC: 0x0,
// RD: 0x1,
// RA: 0x1,
// Z1: 0x0,
// Z2: 0x0,
// Z3: 0x0,
// RCode: 0xF,
// },
// QD: []Question{
// {
// QName: Name{Parts: []string{"example", "com"}},
// QType: 0x001c,
// QClass: 0x0001,
// },
// },
// AN: []ResourceRecord{
// Unknown{
// ResourceRecordHeader: ResourceRecordHeader{
// Name: Name{Parts: []string{"example", "com"}},
// Type: 0xF0,
// Class: 0xF0,
// TTL: 0x77770FF0,
// },
// Data: []byte{0x01, 0x02, 0x03},
// },
// },
// NS: []ResourceRecord{
// Unknown{
// ResourceRecordHeader: ResourceRecordHeader{
// Name: Name{Parts: []string{"example", "com"}},
// Type: 0xF0,
// Class: 0xF0,
// TTL: 0x77770FF0,
// },
// Data: []byte{0x01, 0x02, 0x03},
// },
// },
// AR: []ResourceRecord{
// Unknown{
// ResourceRecordHeader: ResourceRecordHeader{
// Name: Name{Parts: []string{"example", "com"}},
// Type: 0xF0,
// Class: 0xF0,
// TTL: 0x77770FF0,
// },
// Data: []byte{0x01, 0x02, 0x03},
// },
// },
// }
// dnsMessageEncoded := dnsMessage.Encode()
// dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
// if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 {
// t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood)
// }
//}

8
go.mod
View File

@ -5,6 +5,8 @@ go 1.21
require (
github.com/IGLOU-EU/go-wildcard/v2 v2.0.2
github.com/coreos/go-iptables v0.7.0
github.com/google/uuid v1.6.0
github.com/miekg/dns v1.1.63
github.com/rs/zerolog v1.33.0
github.com/vishvananda/netlink v1.3.0
)
@ -13,5 +15,9 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/sys v0.24.0 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/net v0.31.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.27.0 // indirect
golang.org/x/tools v0.22.0 // indirect
)

View File

@ -1,6 +1,7 @@
package main
import (
"fmt"
"net"
"time"
@ -15,9 +16,9 @@ type Group struct {
Enabled bool
iptables *iptables.IPTables
ipset *netfilterHelper.IPSet
ifaceToIPSet *netfilterHelper.IfaceToIPSet
iptables *iptables.IPTables
ipset *netfilterHelper.IPSet
ifaceToIPSetNAT *netfilterHelper.IfaceToIPSet
}
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
@ -44,10 +45,13 @@ func (g *Group) Enable() error {
}()
if g.FixProtect {
g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
if err != nil {
return fmt.Errorf("failed to fix protect: %w", err)
}
}
err := g.ifaceToIPSet.Enable()
err := g.ifaceToIPSetNAT.Enable()
if err != nil {
return err
}
@ -64,9 +68,9 @@ func (g *Group) Disable() []error {
return nil
}
errs2 := g.ifaceToIPSet.Disable()
if errs2 != nil {
errs = append(errs, errs2...)
err := g.ifaceToIPSetNAT.Disable()
if err != nil {
errs = append(errs, err...)
}
g.Enabled = false

239
kvas2.go
View File

@ -9,13 +9,16 @@ import (
"strings"
"time"
"kvas2-go/dns-proxy"
"kvas2-go/dns-mitm"
"kvas2-go/models"
"kvas2-go/netfilter-helper"
"github.com/google/uuid"
"github.com/miekg/dns"
"github.com/rs/zerolog/log"
"github.com/vishvananda/netlink"
"github.com/vishvananda/netlink/nl"
"golang.org/x/sys/unix"
)
var (
@ -36,11 +39,11 @@ type Config struct {
type App struct {
Config Config
DNSProxy *dnsProxy.DNSProxy
DNSMITM *dnsMitm.DNSMITM
NetfilterHelper4 *netfilterHelper.NetfilterHelper
NetfilterHelper6 *netfilterHelper.NetfilterHelper
Records *Records
Groups map[int]*Group
Groups map[uuid.UUID]*Group
Link netlink.Link
@ -51,19 +54,23 @@ type App struct {
func (a *App) handleLink(event netlink.LinkUpdate) {
switch event.Change {
case 0x00000001:
case unix.IFF_UP:
log.Debug().
Str("interface", event.Link.Attrs().Name).
Str("operstatestr", event.Attrs().OperState.String()).
Int("operstate", int(event.Attrs().OperState)).
Msg("interface change")
if event.Attrs().OperState != netlink.OperDown {
switch event.Attrs().OperState {
case netlink.OperUp:
ifaceName := event.Link.Attrs().Name
for _, group := range a.Groups {
if group.Interface == event.Link.Attrs().Name {
err := group.ifaceToIPSet.IfaceHandle()
if err != nil {
log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up")
}
if group.Interface != ifaceName {
continue
}
err := group.ifaceToIPSetNAT.IfaceHandle()
if err != nil {
log.Error().Str("group", group.ID.String()).Err(err).Msg("error while handling interface up")
}
}
}
@ -83,22 +90,34 @@ func (a *App) handleLink(event netlink.LinkUpdate) {
}
}
func (a *App) listen(ctx context.Context) (err error) {
errChan := make(chan error)
func (a *App) start(ctx context.Context) (err error) {
newCtx, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: Chan err
errChan := make(chan error)
/*
DNS Proxy
*/
go func() {
err := a.DNSProxy.Listen(newCtx)
err := a.DNSMITM.ListenUDP(newCtx)
if err != nil {
errChan <- fmt.Errorf("failed to serve DNS proxy: %v", err)
errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err)
}
}()
go func() {
err := a.DNSMITM.ListenTCP(newCtx)
if err != nil {
errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err)
}
}()
addrList, err := netlink.AddrList(a.Link, nl.FAMILY_ALL)
if err != nil {
return fmt.Errorf("failed to addrList address: %w", err)
return fmt.Errorf("failed to list address of interface: %w", err)
}
a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
@ -121,6 +140,10 @@ func (a *App) listen(ctx context.Context) (err error) {
_ = a.dnsOverrider6.Disable()
}()
/*
Groups
*/
for _, group := range a.Groups {
err = group.Enable()
if err != nil {
@ -134,6 +157,9 @@ func (a *App) listen(ctx context.Context) (err error) {
}
}()
/*
Socket (for netfilter.d events)
*/
socketPath := "/opt/var/run/kvas2-go.sock"
err = os.Remove(socketPath)
if err != nil && !errors.Is(err, os.ErrNotExist) {
@ -181,8 +207,8 @@ func (a *App) listen(ctx context.Context) (err error) {
}
}
for _, group := range a.Groups {
if group.ifaceToIPSet.Enabled {
err := group.ifaceToIPSet.PutIPTable(args[2])
if group.ifaceToIPSetNAT.Enabled {
err := group.ifaceToIPSetNAT.PutIPTable(args[2])
if err != nil {
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
}
@ -193,21 +219,28 @@ func (a *App) listen(ctx context.Context) (err error) {
}
}()
link := make(chan netlink.LinkUpdate)
done := make(chan struct{})
err = netlink.LinkSubscribe(link, done)
/*
Interface updates
*/
linkUpdateChannel := make(chan netlink.LinkUpdate)
linkUpdateDone := make(chan struct{})
err = netlink.LinkSubscribe(linkUpdateChannel, linkUpdateDone)
if err != nil {
return fmt.Errorf("failed to subscribe to link updates: %w", err)
}
defer func() {
close(done)
close(linkUpdateDone)
}()
/*
Global loop
*/
for {
select {
case event := <-link:
case event := <-linkUpdateChannel:
a.handleLink(event)
case err := <-errChan:
close(errChan)
return err
case <-ctx.Done():
return nil
@ -215,7 +248,7 @@ func (a *App) listen(ctx context.Context) (err error) {
}
}
func (a *App) Listen(ctx context.Context) (err error) {
func (a *App) Start(ctx context.Context) (err error) {
if a.isRunning {
return ErrAlreadyRunning
}
@ -226,20 +259,16 @@ func (a *App) Listen(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
var recoveredError error
var ok bool
if recoveredError, ok = r.(error); !ok {
recoveredError = fmt.Errorf("%v", r)
if err, ok = r.(error); !ok {
err = fmt.Errorf("%v", r)
}
err = fmt.Errorf("recovered error: %w", recoveredError)
err = fmt.Errorf("recovered error: %w", err)
}
}()
appErr := a.listen(ctx)
if appErr != nil {
return appErr
}
err = a.start(ctx)
return err
}
@ -249,19 +278,19 @@ func (a *App) AddGroup(group *models.Group) error {
return ErrGroupIDConflict
}
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPrefix, group.ID)
ipsetName := fmt.Sprintf("%s%8x", a.Config.IpSetPrefix, group.ID.ID())
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
if err != nil {
return fmt.Errorf("failed to initialize ipset: %w", err)
}
grp := &Group{
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: ipset,
ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false),
Group: group,
iptables: a.NetfilterHelper4.IPTables,
ipset: ipset,
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPrefix, group.ID), group.Interface, ipsetName, false),
}
a.Groups[group.ID] = grp
a.Groups[grp.ID] = grp
return a.SyncGroup(grp)
}
@ -276,7 +305,7 @@ func (a *App) SyncGroup(group *Group) error {
}
knownDomains := a.Records.ListKnownDomains()
for _, domain := range group.Domains {
for _, domain := range group.Rules {
if !domain.IsEnabled() {
continue
}
@ -359,24 +388,24 @@ func (a *App) ListInterfaces() ([]net.Interface, error) {
return interfaceNames, nil
}
func (a *App) processARecord(aRecord dnsProxy.Address) {
func (a *App) processARecord(aRecord dns.A) {
log.Trace().
Str("name", aRecord.Name.String()).
Str("address", aRecord.Address.String()).
Int("ttl", int(aRecord.TTL)).
Str("name", aRecord.Hdr.Name).
Str("address", aRecord.A.String()).
Int("ttl", int(aRecord.Hdr.Ttl)).
Msg("processing a record")
ttlDuration := time.Duration(aRecord.TTL) * time.Second
ttlDuration := time.Duration(aRecord.Hdr.Ttl) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
a.Records.AddARecord(aRecord.Hdr.Name, aRecord.A, ttlDuration)
names := a.Records.GetCNameRecords(aRecord.Name.String(), true)
names := a.Records.GetCNameRecords(aRecord.Hdr.Name, true)
for _, group := range a.Groups {
Domain:
for _, domain := range group.Domains {
Rule:
for _, domain := range group.Rules {
if !domain.IsEnabled() {
continue
}
@ -384,47 +413,47 @@ func (a *App) processARecord(aRecord dnsProxy.Address) {
if !domain.IsMatch(name) {
continue
}
err := group.AddIPv4(aRecord.Address, ttlDuration)
err := group.AddIPv4(aRecord.A, ttlDuration)
if err != nil {
log.Error().
Str("address", aRecord.Address.String()).
Str("address", aRecord.A.String()).
Err(err).
Msg("failed to add address")
} else {
log.Trace().
Str("address", aRecord.Address.String()).
Str("aRecordDomain", aRecord.Name.String()).
Str("address", aRecord.A.String()).
Str("aRecordDomain", aRecord.Hdr.Name).
Str("cNameDomain", name).
Err(err).
Msg("add address")
}
break Domain
break Rule
}
}
}
}
func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
log.Trace().
Str("name", cNameRecord.Name.String()).
Str("cname", cNameRecord.CName.String()).
Int("ttl", int(cNameRecord.TTL)).
Str("name", cNameRecord.Hdr.Name).
Str("cname", cNameRecord.Target).
Int("ttl", int(cNameRecord.Hdr.Ttl)).
Msg("processing cname record")
ttlDuration := time.Duration(cNameRecord.TTL) * time.Second
ttlDuration := time.Duration(cNameRecord.Hdr.Ttl) * time.Second
if ttlDuration < a.Config.MinimalTTL {
ttlDuration = a.Config.MinimalTTL
}
a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration)
a.Records.AddCNameRecord(cNameRecord.Hdr.Name, cNameRecord.Target, ttlDuration)
// TODO: Optimization
now := time.Now()
aRecords := a.Records.GetARecords(cNameRecord.Name.String())
names := a.Records.GetCNameRecords(cNameRecord.Name.String(), true)
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name)
names := a.Records.GetCNameRecords(cNameRecord.Hdr.Name, true)
for _, group := range a.Groups {
Domain:
for _, domain := range group.Domains {
Rule:
for _, domain := range group.Rules {
if !domain.IsEnabled() {
continue
}
@ -447,31 +476,30 @@ func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
Msg("add address")
}
}
continue Domain
continue Rule
}
}
}
}
func (a *App) handleRecord(rr dnsProxy.ResourceRecord) {
func (a *App) handleRecord(rr dns.RR) {
switch v := rr.(type) {
case dnsProxy.Address:
// TODO: Optimize equals domain A records
a.processARecord(v)
case dnsProxy.CName:
a.processCNameRecord(v)
case *dns.A:
a.processARecord(*v)
case *dns.CNAME:
a.processCNameRecord(*v)
default:
}
}
func (a *App) handleMessage(msg *dnsProxy.Message) {
for _, rr := range msg.AN {
func (a *App) handleMessage(msg dns.Msg) {
for _, rr := range msg.Answer {
a.handleRecord(rr)
}
for _, rr := range msg.NS {
for _, rr := range msg.Ns {
a.handleRecord(rr)
}
for _, rr := range msg.AR {
for _, rr := range msg.Extra {
a.handleRecord(rr)
}
}
@ -483,17 +511,68 @@ func New(config Config) (*App, error) {
app.Config = config
app.DNSMITM = dnsMitm.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress)
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
log.Debug().
Str("network", network).
Str("clientAddr", clientAddr.String()).
Str("name", reqMsg.Question[0].Name).
Msg("received DNS request")
// TODO: Need to understand why it not works in proxy mode
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
respMsg := &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: reqMsg.Id,
Response: true,
RecursionAvailable: true,
Rcode: dns.RcodeNameError,
},
Question: reqMsg.Question,
}
log.Debug().
Str("network", network).
Str("clientAddr", clientAddr.String()).
Msg("sending DNS response")
return nil, respMsg, nil
}
return nil, nil, nil
}
app.DNSMITM.ResponseHook = func(clientAddr net.Addr, reqMsg dns.Msg, respMsg dns.Msg, network string) (*dns.Msg, error) {
// TODO: Make it optional
var idx int
for _, a := range respMsg.Answer {
if a.Header().Rrtype == dns.TypeAAAA {
continue
}
respMsg.Answer[idx] = a
idx++
}
respMsg.Answer = respMsg.Answer[:idx]
if len(respMsg.Answer) != 0 {
log.Debug().
Str("network", network).
Str("clientAddr", clientAddr.String()).
Str("respMsg", respMsg.Answer[0].Header().Name).
Msg("sending DNS response")
}
app.handleMessage(respMsg)
return &respMsg, nil
}
app.Records = NewRecords()
app.Groups = make(map[uuid.UUID]*Group, 0)
link, err := netlink.LinkByName(app.Config.LinkName)
if err != nil {
return nil, fmt.Errorf("failed to find link %s: %w", app.Config.LinkName, err)
}
app.Link = link
app.DNSProxy = dnsProxy.New(app.Config.ListenDNSPort, app.Config.TargetDNSServerAddress)
app.DNSProxy.MsgHandler = app.handleMessage
app.Records = NewRecords()
nh4, err := netfilterHelper.New(false)
if err != nil {
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
@ -514,7 +593,7 @@ func New(config Config) (*App, error) {
return nil, fmt.Errorf("failed to clear iptables: %w", err)
}
app.Groups = make(map[int]*Group)
app.Groups = make(map[uuid.UUID]*Group)
return app, nil
}

11
main.go
View File

@ -19,7 +19,7 @@ func main() {
ChainPrefix: "KVAS2_",
IpSetPrefix: "kvas2_",
LinkName: "br0",
TargetDNSServerAddress: "127.0.0.1:53",
TargetDNSServerAddress: "127.0.0.1",
ListenDNSPort: 7553,
})
if err != nil {
@ -28,13 +28,16 @@ func main() {
ctx, cancel := context.WithCancel(context.Background())
log.Info().Msg("starting service")
/*
Starting app with graceful shutdown
*/
appResult := make(chan error)
go func() {
appResult <- app.Listen(ctx)
appResult <- app.Start(ctx)
}()
log.Info().Msg("starting service")
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)

View File

@ -1,33 +0,0 @@
package models
import (
"regexp"
"github.com/IGLOU-EU/go-wildcard/v2"
)
type Domain struct {
ID int
Group *Group
Type string
Domain string
Enable bool
Comment string
}
func (d *Domain) IsEnabled() bool {
return d.Enable
}
func (d *Domain) IsMatch(domainName string) bool {
switch d.Type {
case "wildcard":
return wildcard.Match(d.Domain, domainName)
case "regex":
ok, _ := regexp.MatchString(d.Domain, domainName)
return ok
case "plaintext":
return domainName == d.Domain
}
return false
}

View File

@ -1,42 +0,0 @@
package models
import "testing"
func TestDomain_IsMatch_Plaintext(t *testing.T) {
domain := &Domain{
Type: "plaintext",
Domain: "example.com",
}
if !domain.IsMatch("example.com") {
t.Fatal("&Domain{Type: \"plaintext\", Domain: \"example.com\"}.IsMatch(\"example.com\") returns false")
}
if domain.IsMatch("noexample.com") {
t.Fatal("&Domain{Type: \"plaintext\", Domain: \"example.com\"}.IsMatch(\"noexample.com\") returns true")
}
}
func TestDomain_IsMatch_Wildcard(t *testing.T) {
domain := &Domain{
Type: "wildcard",
Domain: "ex*le.com",
}
if !domain.IsMatch("example.com") {
t.Fatal("&Domain{Type: \"wildcard\", Domain: \"ex*le.com\"}.IsMatch(\"example.com\") returns false")
}
if domain.IsMatch("noexample.com") {
t.Fatal("&Domain{Type: \"wildcard\", Domain: \"ex*le.com\"}.IsMatch(\"noexample.com\") returns true")
}
}
func TestDomain_IsMatch_RegEx(t *testing.T) {
domain := &Domain{
Type: "regex",
Domain: "^ex[apm]{3}le.com$",
}
if !domain.IsMatch("example.com") {
t.Fatal("&Domain{Type: \"regex\", Domain: \"^ex[apm]{3}le.com$\"}.IsMatch(\"example.com\") returns false")
}
if domain.IsMatch("noexample.com") {
t.Fatal("&Domain{Type: \"regex\", Domain: \"^ex[apm]{3}le.com$\"}.IsMatch(\"noexample.com\") returns true")
}
}

View File

@ -1,9 +1,11 @@
package models
import "github.com/google/uuid"
type Group struct {
ID int
ID uuid.UUID
Name string
Interface string
Rules []*Rule
FixProtect bool
Domains []*Domain
}

33
models/rule.go Normal file
View File

@ -0,0 +1,33 @@
package models
import (
"regexp"
"github.com/IGLOU-EU/go-wildcard/v2"
"github.com/google/uuid"
)
type Rule struct {
ID uuid.UUID
Name string
Type string
Rule string
Enable bool
}
func (d *Rule) IsEnabled() bool {
return d.Enable
}
func (d *Rule) IsMatch(domainName string) bool {
switch d.Type {
case "wildcard":
return wildcard.Match(d.Rule, domainName)
case "regex":
ok, _ := regexp.MatchString(d.Rule, domainName)
return ok
case "plaintext":
return domainName == d.Rule
}
return false
}

42
models/rule_test.go Normal file
View File

@ -0,0 +1,42 @@
package models
import "testing"
func TestDomain_IsMatch_Plaintext(t *testing.T) {
rule := &Rule{
Type: "plaintext",
Rule: "example.com",
}
if !rule.IsMatch("example.com") {
t.Fatal("&Rule{Type: \"plaintext\", Rule: \"example.com\"}.IsMatch(\"example.com\") returns false")
}
if rule.IsMatch("noexample.com") {
t.Fatal("&Rule{Type: \"plaintext\", Rule: \"example.com\"}.IsMatch(\"noexample.com\") returns true")
}
}
func TestDomain_IsMatch_Wildcard(t *testing.T) {
rule := &Rule{
Type: "wildcard",
Rule: "ex*le.com",
}
if !rule.IsMatch("example.com") {
t.Fatal("&Rule{Type: \"wildcard\", Rule: \"ex*le.com\"}.IsMatch(\"example.com\") returns false")
}
if rule.IsMatch("noexample.com") {
t.Fatal("&Rule{Type: \"wildcard\", Rule: \"ex*le.com\"}.IsMatch(\"noexample.com\") returns true")
}
}
func TestDomain_IsMatch_RegEx(t *testing.T) {
rule := &Rule{
Type: "regex",
Rule: "^ex[apm]{3}le.com$",
}
if !rule.IsMatch("example.com") {
t.Fatal("&Rule{Type: \"regex\", Rule: \"^ex[apm]{3}le.com$\"}.IsMatch(\"example.com\") returns false")
}
if rule.IsMatch("noexample.com") {
t.Fatal("&Rule{Type: \"regex\", Rule: \"^ex[apm]{3}le.com$\"}.IsMatch(\"noexample.com\") returns true")
}
}

View File

@ -40,6 +40,11 @@ func (r *PortRemap) PutIPTable(table string) error {
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
}
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
if err != nil {
return fmt.Errorf("failed to create rule: %w", err)
}
}
err = r.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", r.ChainName)