2025-02-11 02:50:13 +03:00
|
|
|
package dnsMitmProxy
|
2025-02-08 06:23:36 +03:00
|
|
|
|
|
|
|
import (
	"context"
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"time"

	"github.com/miekg/dns"
	"github.com/rs/zerolog/log"
)
|
|
|
|
|
2025-02-11 02:50:13 +03:00
|
|
|
type DNSMITMProxy struct {
|
2025-02-14 03:17:43 +03:00
|
|
|
UpstreamDNSAddress string
|
|
|
|
UpstreamDNSPort uint16
|
2025-02-08 06:23:36 +03:00
|
|
|
|
|
|
|
RequestHook func(net.Addr, dns.Msg, string) (*dns.Msg, *dns.Msg, error)
|
|
|
|
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
|
|
|
|
}
|
|
|
|
|
2025-02-11 02:50:13 +03:00
|
|
|
func (p DNSMITMProxy) requestDNS(req []byte, network string) ([]byte, error) {
|
2025-02-14 03:17:43 +03:00
|
|
|
upstreamConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.UpstreamDNSAddress, p.UpstreamDNSPort))
|
2025-02-08 06:23:36 +03:00
|
|
|
if err != nil {
|
2025-02-14 03:17:43 +03:00
|
|
|
return nil, fmt.Errorf("failed to dial DNS upstream: %w", err)
|
2025-02-08 06:23:36 +03:00
|
|
|
}
|
2025-02-14 03:17:43 +03:00
|
|
|
defer func() { _ = upstreamConn.Close() }()
|
2025-02-08 06:23:36 +03:00
|
|
|
|
2025-02-14 03:17:43 +03:00
|
|
|
err = upstreamConn.SetDeadline(time.Now().Add(time.Second * 5))
|
2025-02-08 06:23:36 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to set deadline: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if network == "tcp" {
|
2025-02-14 03:17:43 +03:00
|
|
|
err = binary.Write(upstreamConn, binary.BigEndian, uint16(len(req)))
|
2025-02-08 06:23:36 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to write length: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-02-14 03:17:43 +03:00
|
|
|
n, err := upstreamConn.Write(req)
|
2025-02-08 06:23:36 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to write request: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
var resp []byte
|
|
|
|
if network == "tcp" {
|
|
|
|
var respLen uint16
|
2025-02-14 03:17:43 +03:00
|
|
|
err = binary.Read(upstreamConn, binary.BigEndian, &respLen)
|
2025-02-08 06:23:36 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to read length: %w", err)
|
|
|
|
}
|
|
|
|
resp = make([]byte, respLen)
|
|
|
|
} else {
|
|
|
|
resp = make([]byte, 512)
|
|
|
|
}
|
|
|
|
|
2025-02-14 03:17:43 +03:00
|
|
|
n, err = upstreamConn.Read(resp)
|
2025-02-08 06:23:36 +03:00
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp[:n], nil
|
|
|
|
}
|
|
|
|
|
2025-02-11 02:50:13 +03:00
|
|
|
func (p DNSMITMProxy) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
|
2025-02-08 06:23:36 +03:00
|
|
|
var reqMsg dns.Msg
|
|
|
|
if p.RequestHook != nil || p.ResponseHook != nil {
|
|
|
|
err := reqMsg.Unpack(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to parse request: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if p.RequestHook != nil {
|
|
|
|
modifiedReq, modifiedResp, err := p.RequestHook(clientAddr, reqMsg, network)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("request hook error: %w", err)
|
|
|
|
}
|
|
|
|
if modifiedResp != nil {
|
|
|
|
resp, err := modifiedResp.Pack()
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to send modified response: %w", err)
|
|
|
|
}
|
|
|
|
return resp, nil
|
|
|
|
}
|
|
|
|
if modifiedReq != nil {
|
|
|
|
reqMsg = *modifiedReq
|
|
|
|
req, err = reqMsg.Pack()
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to pack modified request: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := p.requestDNS(req, network)
|
|
|
|
if err != nil {
|
2025-02-11 02:50:13 +03:00
|
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
2025-02-08 06:23:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
if p.ResponseHook != nil {
|
|
|
|
var respMsg dns.Msg
|
|
|
|
err = respMsg.Unpack(resp)
|
|
|
|
if err != nil {
|
2025-02-11 02:50:13 +03:00
|
|
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
2025-02-08 06:23:36 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("response hook error: %w", err)
|
|
|
|
}
|
|
|
|
if modifiedResp != nil {
|
|
|
|
resp, err = modifiedResp.Pack()
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("failed to send modified response: %w", err)
|
|
|
|
}
|
|
|
|
return resp, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return resp, nil
|
|
|
|
}
|
|
|
|
|
2025-02-11 02:50:13 +03:00
|
|
|
func (p DNSMITMProxy) ListenTCP(ctx context.Context, addr *net.TCPAddr) error {
|
2025-02-08 06:23:36 +03:00
|
|
|
listener, err := net.ListenTCP("tcp", addr)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to listen tcp port: %v", err)
|
|
|
|
}
|
|
|
|
defer func() { _ = listener.Close() }()
|
|
|
|
|
|
|
|
for {
|
|
|
|
// Exit if context is done
|
|
|
|
if ctx.Err() != nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
conn, err := listener.Accept()
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("tcp connection error")
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
go func(clientConn net.Conn) {
|
|
|
|
defer func() { _ = clientConn.Close() }()
|
|
|
|
|
|
|
|
var respLen uint16
|
|
|
|
err = binary.Read(clientConn, binary.BigEndian, &respLen)
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to read length")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
req := make([]byte, int(respLen))
|
|
|
|
_, err = clientConn.Read(req)
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to read tcp request")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := p.processReq(clientConn.RemoteAddr(), req, "tcp")
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to process request")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
err = binary.Write(clientConn, binary.BigEndian, uint16(len(resp)))
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to send length")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
_, err = clientConn.Write(resp)
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to send response")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}(conn)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-02-11 02:50:13 +03:00
|
|
|
func (p DNSMITMProxy) ListenUDP(ctx context.Context, addr *net.UDPAddr) error {
|
2025-02-08 06:23:36 +03:00
|
|
|
conn, err := net.ListenUDP("udp", addr)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed to listen udp port: %v", err)
|
|
|
|
}
|
|
|
|
defer func() { _ = conn.Close() }()
|
|
|
|
|
|
|
|
for {
|
|
|
|
// Exit if context is done
|
|
|
|
if ctx.Err() != nil {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
req := make([]byte, 512)
|
|
|
|
n, clientAddr, err := conn.ReadFromUDP(req)
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to read udp request")
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
req = req[:n]
|
|
|
|
|
|
|
|
go func(clientConn *net.UDPConn, clientAddr *net.UDPAddr) {
|
|
|
|
resp, err := p.processReq(clientAddr, req, "udp")
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to process request")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
_, err = clientConn.WriteToUDP(resp, clientAddr)
|
|
|
|
if err != nil {
|
|
|
|
log.Error().Err(err).Msg("failed to send response")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}(conn, clientAddr)
|
|
|
|
}
|
|
|
|
}
|