Compare commits
No commits in common. "5bc0c3b2b47a666fc371e8b2c3b661d67b8469b0" and "53a13d5e90a6acaa6e1235d3dc62de4be942118c" have entirely different histories.
5bc0c3b2b4
...
53a13d5e90
@ -1,5 +0,0 @@
|
|||||||
# Contributors
|
|
||||||
|
|
||||||
## Consultants
|
|
||||||
|
|
||||||
- **nesteroff561** - [GitHub](https://github.com/nesteroff561) - "Help with understanding `iptables`"
|
|
@ -4,7 +4,7 @@ Better implementation of [KVAS](https://github.com/qzeleza/kvas)
|
|||||||
|
|
||||||
Realized features:
|
Realized features:
|
||||||
- [x] DNS Proxy (UDP)
|
- [x] DNS Proxy (UDP)
|
||||||
- [x] DNS Proxy (TCP)
|
- [ ] DNS Proxy (TCP)
|
||||||
- [x] Records memory
|
- [x] Records memory
|
||||||
- [x] IPTables rules for rebind DNS server port
|
- [x] IPTables rules for rebind DNS server port
|
||||||
- [X] IPSet integration
|
- [X] IPSet integration
|
||||||
|
@ -1,221 +0,0 @@
|
|||||||
package dnsMitmProxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/rs/zerolog/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DNSMITMProxy struct {
|
|
||||||
TargetDNSServerAddress string
|
|
||||||
TargetDNSServerPort uint16
|
|
||||||
|
|
||||||
RequestHook func(net.Addr, dns.Msg, string) (*dns.Msg, *dns.Msg, error)
|
|
||||||
ResponseHook func(net.Addr, dns.Msg, dns.Msg, string) (*dns.Msg, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p DNSMITMProxy) requestDNS(req []byte, network string) ([]byte, error) {
|
|
||||||
serverConn, err := net.Dial(network, fmt.Sprintf("%s:%d", p.TargetDNSServerAddress, p.TargetDNSServerPort))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to dial DNS server: %w", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = serverConn.Close() }()
|
|
||||||
|
|
||||||
err = serverConn.SetDeadline(time.Now().Add(time.Second * 5))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set deadline: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if network == "tcp" {
|
|
||||||
err = binary.Write(serverConn, binary.BigEndian, uint16(len(req)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to write length: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := serverConn.Write(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to write request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp []byte
|
|
||||||
if network == "tcp" {
|
|
||||||
var respLen uint16
|
|
||||||
err = binary.Read(serverConn, binary.BigEndian, &respLen)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read length: %w", err)
|
|
||||||
}
|
|
||||||
resp = make([]byte, respLen)
|
|
||||||
} else {
|
|
||||||
resp = make([]byte, 512)
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = serverConn.Read(resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp[:n], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p DNSMITMProxy) processReq(clientAddr net.Addr, req []byte, network string) ([]byte, error) {
|
|
||||||
var reqMsg dns.Msg
|
|
||||||
if p.RequestHook != nil || p.ResponseHook != nil {
|
|
||||||
err := reqMsg.Unpack(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse request: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.RequestHook != nil {
|
|
||||||
modifiedReq, modifiedResp, err := p.RequestHook(clientAddr, reqMsg, network)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("request hook error: %w", err)
|
|
||||||
}
|
|
||||||
if modifiedResp != nil {
|
|
||||||
resp, err := modifiedResp.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send modified response: %w", err)
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
if modifiedReq != nil {
|
|
||||||
reqMsg = *modifiedReq
|
|
||||||
req, err = reqMsg.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to pack modified request: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := p.requestDNS(req, network)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.ResponseHook != nil {
|
|
||||||
var respMsg dns.Msg
|
|
||||||
err = respMsg.Unpack(resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
modifiedResp, err := p.ResponseHook(clientAddr, reqMsg, respMsg, network)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("response hook error: %w", err)
|
|
||||||
}
|
|
||||||
if modifiedResp != nil {
|
|
||||||
resp, err = modifiedResp.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to send modified response: %w", err)
|
|
||||||
}
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p DNSMITMProxy) ListenTCP(ctx context.Context, addr *net.TCPAddr) error {
|
|
||||||
listener, err := net.ListenTCP("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to listen tcp port: %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = listener.Close() }()
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Exit if context is done
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("tcp connection error")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(clientConn net.Conn) {
|
|
||||||
defer func() { _ = clientConn.Close() }()
|
|
||||||
|
|
||||||
var respLen uint16
|
|
||||||
err = binary.Read(clientConn, binary.BigEndian, &respLen)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to read length")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req := make([]byte, int(respLen))
|
|
||||||
_, err = clientConn.Read(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to read tcp request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := p.processReq(clientConn.RemoteAddr(), req, "tcp")
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to process request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = binary.Write(clientConn, binary.BigEndian, uint16(len(resp)))
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to send length")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = clientConn.Write(resp)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to send response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}(conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p DNSMITMProxy) ListenUDP(ctx context.Context, addr *net.UDPAddr) error {
|
|
||||||
conn, err := net.ListenUDP("udp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to listen udp port: %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = conn.Close() }()
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Exit if context is done
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
req := make([]byte, 512)
|
|
||||||
n, clientAddr, err := conn.ReadFromUDP(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to read udp request")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
req = req[:n]
|
|
||||||
|
|
||||||
go func(clientConn *net.UDPConn, clientAddr *net.UDPAddr) {
|
|
||||||
resp, err := p.processReq(clientAddr, req, "udp")
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to process request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = clientConn.WriteToUDP(resp, clientAddr)
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to send response")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}(conn, clientAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() *DNSMITMProxy {
|
|
||||||
return &DNSMITMProxy{
|
|
||||||
TargetDNSServerPort: 53,
|
|
||||||
}
|
|
||||||
}
|
|
119
dns-proxy/dns-proxy.go
Normal file
119
dns-proxy/dns-proxy.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DNSMaxUDPPackageSize = 4096
|
||||||
|
)
|
||||||
|
|
||||||
|
type DNSProxy struct {
|
||||||
|
udpConn *net.UDPConn
|
||||||
|
listenPort uint16
|
||||||
|
|
||||||
|
targetDNSServerAddress string
|
||||||
|
|
||||||
|
MsgHandler func(*Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p DNSProxy) Listen(ctx context.Context) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", p.listenPort))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve UDP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.udpConn, err = net.ListenUDP("udp", udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to listen UDP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if p.udpConn != nil {
|
||||||
|
err := p.udpConn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to close UDP connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
buffer := make([]byte, DNSMaxUDPPackageSize)
|
||||||
|
n, clientAddr, err := p.udpConn.ReadFromUDP(buffer)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to read UDP packet")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go p.handleDNSRequest(clientAddr, buffer[:n])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p DNSProxy) handleDNSRequest(clientAddr *net.UDPAddr, buffer []byte) {
|
||||||
|
conn, err := net.Dial("udp", p.targetDNSServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to dial target DNS")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
_, err = conn.Write(buffer)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to send request to target DNS")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to set read deadline")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, DNSMaxUDPPackageSize)
|
||||||
|
n, err := conn.Read(response)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||||
|
// Just skip it
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error().Err(err).Msg("failed to read response from target DNS")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, err := ParseResponse(response[:n])
|
||||||
|
if err == nil {
|
||||||
|
if p.MsgHandler != nil {
|
||||||
|
p.MsgHandler(msg)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Warn().Err(err).Msg("error while parsing DNS message")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = p.udpConn.WriteToUDP(response[:n], clientAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to send DNS message")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(listenPort uint16, targetDNSServerAddress string) *DNSProxy {
|
||||||
|
return &DNSProxy{
|
||||||
|
listenPort: listenPort,
|
||||||
|
targetDNSServerAddress: targetDNSServerAddress,
|
||||||
|
}
|
||||||
|
}
|
195
dns-proxy/parser.go
Normal file
195
dns-proxy/parser.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidDNSAddressResourceData = errors.New("invalid DNS address resource data")
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseName(response []byte, pos int) (*Name, int, error) {
|
||||||
|
var nameParts []string
|
||||||
|
var jumped bool
|
||||||
|
var outPos int
|
||||||
|
responseLen := len(response)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if responseLen < pos+1 {
|
||||||
|
return nil, pos, io.EOF
|
||||||
|
}
|
||||||
|
length := int(response[pos])
|
||||||
|
pos++
|
||||||
|
if length == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if length&0xC0 != 0 {
|
||||||
|
if responseLen < pos+1 {
|
||||||
|
return nil, pos, io.EOF
|
||||||
|
}
|
||||||
|
if !jumped {
|
||||||
|
outPos = pos + 1
|
||||||
|
}
|
||||||
|
pos = int(binary.BigEndian.Uint16(response[pos-1:pos+1]) & 0x3FFF)
|
||||||
|
jumped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseLen < pos+length {
|
||||||
|
return nil, pos, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
nameParts = append(nameParts, string(response[pos:pos+length]))
|
||||||
|
pos += length
|
||||||
|
}
|
||||||
|
|
||||||
|
if !jumped {
|
||||||
|
outPos = pos
|
||||||
|
}
|
||||||
|
return &Name{Parts: nameParts}, outPos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseResourceRecord(response []byte, pos int) (ResourceRecord, int, error) {
|
||||||
|
responseLen := len(response)
|
||||||
|
|
||||||
|
rhName, pos, err := parseName(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, pos, fmt.Errorf("error while parsing DNS name: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if responseLen < pos+10 {
|
||||||
|
return nil, pos, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
rh := ResourceRecordHeader{
|
||||||
|
Name: *rhName,
|
||||||
|
Type: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
|
||||||
|
Class: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
|
||||||
|
TTL: binary.BigEndian.Uint32(response[pos+4 : pos+8]),
|
||||||
|
}
|
||||||
|
rdLen := int(binary.BigEndian.Uint16(response[pos+8 : pos+10]))
|
||||||
|
|
||||||
|
pos += 10
|
||||||
|
|
||||||
|
if responseLen < pos+rdLen {
|
||||||
|
return nil, pos, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
switch rh.Type {
|
||||||
|
case 1:
|
||||||
|
if rdLen != 4 {
|
||||||
|
return nil, pos, ErrInvalidDNSAddressResourceData
|
||||||
|
}
|
||||||
|
return Address{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
Address: response[pos+0 : pos+4],
|
||||||
|
}, pos + 4, nil
|
||||||
|
case 2:
|
||||||
|
var ns *Name
|
||||||
|
ns, pos, err = parseName(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err)
|
||||||
|
}
|
||||||
|
return NameServer{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
NSDName: *ns,
|
||||||
|
}, pos, nil
|
||||||
|
case 5:
|
||||||
|
var cname *Name
|
||||||
|
cname, pos, err = parseName(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return nil, pos, fmt.Errorf("error while parsing DNS resource record: %w", err)
|
||||||
|
}
|
||||||
|
return CName{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
CName: *cname,
|
||||||
|
}, pos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return Unknown{
|
||||||
|
ResourceRecordHeader: rh,
|
||||||
|
Data: response[pos+0 : pos+rdLen],
|
||||||
|
}, pos + rdLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseResponse(response []byte) (*Message, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
msg := new(Message)
|
||||||
|
|
||||||
|
responseLen := len(response)
|
||||||
|
if responseLen < 12 {
|
||||||
|
return msg, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.ID = binary.BigEndian.Uint16(response[0:2])
|
||||||
|
|
||||||
|
flagsRAW := binary.BigEndian.Uint16(response[2:4])
|
||||||
|
msg.Flags = Flags{
|
||||||
|
QR: uint8(flagsRAW >> 15 & 0x1),
|
||||||
|
Opcode: uint8(flagsRAW >> 11 & 0xF),
|
||||||
|
AA: uint8(flagsRAW >> 10 & 0x1),
|
||||||
|
TC: uint8(flagsRAW >> 9 & 0x1),
|
||||||
|
RD: uint8(flagsRAW >> 8 & 0x1),
|
||||||
|
RA: uint8(flagsRAW >> 7 & 0x1),
|
||||||
|
Z1: uint8(flagsRAW >> 6 & 0x1),
|
||||||
|
Z2: uint8(flagsRAW >> 5 & 0x1),
|
||||||
|
Z3: uint8(flagsRAW >> 4 & 0x1),
|
||||||
|
RCode: uint8(flagsRAW >> 0 & 0xF),
|
||||||
|
}
|
||||||
|
|
||||||
|
qdCount := int(binary.BigEndian.Uint16(response[4:6]))
|
||||||
|
anCount := int(binary.BigEndian.Uint16(response[6:8]))
|
||||||
|
nsCount := int(binary.BigEndian.Uint16(response[8:10]))
|
||||||
|
arCount := int(binary.BigEndian.Uint16(response[10:12]))
|
||||||
|
|
||||||
|
pos := 12
|
||||||
|
|
||||||
|
msg.QD = make([]Question, qdCount)
|
||||||
|
for i := 0; i < qdCount; i++ {
|
||||||
|
var name *Name
|
||||||
|
name, pos, err = parseName(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return msg, fmt.Errorf("error while parsing DNS name: %w", err)
|
||||||
|
}
|
||||||
|
if responseLen < pos+4 {
|
||||||
|
return msg, io.EOF
|
||||||
|
}
|
||||||
|
msg.QD[i] = Question{
|
||||||
|
QName: *name,
|
||||||
|
QType: binary.BigEndian.Uint16(response[pos+0 : pos+2]),
|
||||||
|
QClass: binary.BigEndian.Uint16(response[pos+2 : pos+4]),
|
||||||
|
}
|
||||||
|
pos += 4
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.AN = make([]ResourceRecord, anCount)
|
||||||
|
for i := 0; i < anCount; i++ {
|
||||||
|
msg.AN[i], pos, err = parseResourceRecord(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return msg, fmt.Errorf("error while parsing AN record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.NS = make([]ResourceRecord, nsCount)
|
||||||
|
for i := 0; i < nsCount; i++ {
|
||||||
|
msg.NS[i], pos, err = parseResourceRecord(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return msg, fmt.Errorf("error while parsing NS record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg.AR = make([]ResourceRecord, arCount)
|
||||||
|
for i := 0; i < arCount; i++ {
|
||||||
|
msg.AR[i], pos, err = parseResourceRecord(response, pos)
|
||||||
|
if err != nil {
|
||||||
|
return msg, fmt.Errorf("error while parsing AR record: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
196
dns-proxy/types.go
Normal file
196
dns-proxy/types.go
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResourceRecord interface {
|
||||||
|
EncodeResource() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResourceRecordHeader struct {
|
||||||
|
Name Name
|
||||||
|
Type uint16
|
||||||
|
Class uint16
|
||||||
|
TTL uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q ResourceRecordHeader) EncodeHeader() []byte {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
buf.Write(q.Name.Encode())
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Type))
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.Class))
|
||||||
|
buf.Write(binary.BigEndian.AppendUint32([]byte{}, q.TTL))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Name struct {
|
||||||
|
Parts []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Name) String() string {
|
||||||
|
return strings.Join(n.Parts, ".")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n Name) Encode() []byte {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
for _, part := range n.Parts {
|
||||||
|
partLen := byte(len(part)) & 0x3F
|
||||||
|
buf.WriteByte(partLen)
|
||||||
|
buf.Write([]byte(part)[0:partLen])
|
||||||
|
}
|
||||||
|
buf.WriteByte(0)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Flags struct {
|
||||||
|
QR uint8
|
||||||
|
Opcode uint8
|
||||||
|
AA uint8
|
||||||
|
TC uint8
|
||||||
|
RD uint8
|
||||||
|
RA uint8
|
||||||
|
Z1 uint8
|
||||||
|
Z2 uint8
|
||||||
|
Z3 uint8
|
||||||
|
RCode uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f Flags) Encode() []byte {
|
||||||
|
return []byte{
|
||||||
|
f.QR&0x1<<7 + f.Opcode&0xF<<3 + f.AA&0x1<<2 + f.TC&0x1<<1 + f.RD&0x1<<0,
|
||||||
|
f.RA&0x1<<7 + f.Z1&0x1<<6 + f.Z2&0x1<<5 + f.Z3&0x1<<4 + f.RCode&0xF<<0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Question struct {
|
||||||
|
QName Name
|
||||||
|
QType uint16
|
||||||
|
QClass uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q Question) EncodeQuestion() []byte {
|
||||||
|
buf := bytes.NewBuffer([]byte{})
|
||||||
|
buf.Write(q.QName.Encode())
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QType))
|
||||||
|
buf.Write(binary.BigEndian.AppendUint16([]byte{}, q.QClass))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Address struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
Address net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Address) EncodeResource() []byte {
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write([]byte{0x00, 0x04})
|
||||||
|
rr.Write(a.Address[:])
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type NameServer struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
NSDName Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a NameServer) EncodeResource() []byte {
|
||||||
|
rdataBytes := a.NSDName.Encode()
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||||
|
rr.Write(rdataBytes)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type CName struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
CName Name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a CName) EncodeResource() []byte {
|
||||||
|
rdataBytes := a.CName.Encode()
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||||
|
rr.Write(rdataBytes)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Authority struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
MName Name
|
||||||
|
RName Name
|
||||||
|
Serial uint32
|
||||||
|
Refresh uint32
|
||||||
|
Retry uint32
|
||||||
|
Expire uint32
|
||||||
|
Minimum uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a Authority) EncodeResource() []byte {
|
||||||
|
rdata := bytes.NewBuffer([]byte{})
|
||||||
|
rdata.Write(a.MName.Encode())
|
||||||
|
rdata.Write(a.RName.Encode())
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Serial))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Refresh))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Retry))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Expire))
|
||||||
|
rdata.Write(binary.BigEndian.AppendUint32([]byte{}, a.Minimum))
|
||||||
|
rdataBytes := rdata.Bytes()
|
||||||
|
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(a.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(rdataBytes))))
|
||||||
|
rr.Write(rdataBytes)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Unknown struct {
|
||||||
|
ResourceRecordHeader
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u Unknown) EncodeResource() []byte {
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(u.ResourceRecordHeader.EncodeHeader())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(u.Data))))
|
||||||
|
rr.Write(u.Data)
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
ID uint16
|
||||||
|
Flags Flags
|
||||||
|
QD []Question
|
||||||
|
AN []ResourceRecord
|
||||||
|
NS []ResourceRecord
|
||||||
|
AR []ResourceRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Message) Encode() []byte {
|
||||||
|
rr := bytes.NewBuffer([]byte{})
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, m.ID))
|
||||||
|
rr.Write(m.Flags.Encode())
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.QD))))
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AN))))
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.NS))))
|
||||||
|
rr.Write(binary.BigEndian.AppendUint16([]byte{}, uint16(len(m.AR))))
|
||||||
|
for _, q := range m.QD {
|
||||||
|
rr.Write(q.EncodeQuestion())
|
||||||
|
}
|
||||||
|
for _, a := range m.AN {
|
||||||
|
rr.Write(a.EncodeResource())
|
||||||
|
}
|
||||||
|
for _, ns := range m.NS {
|
||||||
|
rr.Write(ns.EncodeResource())
|
||||||
|
}
|
||||||
|
for _, a := range m.AR {
|
||||||
|
rr.Write(a.EncodeResource())
|
||||||
|
}
|
||||||
|
return rr.Bytes()
|
||||||
|
}
|
225
dns-proxy/types_test.go
Normal file
225
dns-proxy/types_test.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package dnsProxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDNSResourceRecordHeaderEncode(t *testing.T) {
|
||||||
|
recordHeader := ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
}
|
||||||
|
recordHeaderEncoded := recordHeader.EncodeHeader()
|
||||||
|
recordHeaderEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0}
|
||||||
|
if bytes.Compare(recordHeaderEncoded, recordHeaderEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`ResourceRecordHeader.EncodeHeader() = %x, want "%x", error`, recordHeaderEncoded, recordHeaderEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSNameString(t *testing.T) {
|
||||||
|
dnsName := Name{Parts: []string{"example", "com"}}
|
||||||
|
dnsNameString := dnsName.String()
|
||||||
|
dnsNameStringGood := "example.com"
|
||||||
|
if dnsNameString != dnsNameStringGood {
|
||||||
|
t.Fatalf(`Name.String() = %s, want "%s", error`, dnsNameString, dnsNameStringGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSNameEncode(t *testing.T) {
|
||||||
|
dnsName := Name{Parts: []string{"example", "com"}}
|
||||||
|
dnsNameEncoded := dnsName.Encode()
|
||||||
|
dnsNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||||
|
if bytes.Compare(dnsNameEncoded, dnsNameEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Name.Encode() = %x, want "%x", error`, dnsNameEncoded, dnsNameEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSFlagsEncode(t *testing.T) {
|
||||||
|
dnsFlags := Flags{
|
||||||
|
QR: 0x1,
|
||||||
|
Opcode: 0xF,
|
||||||
|
AA: 0x0,
|
||||||
|
TC: 0x0,
|
||||||
|
RD: 0x1,
|
||||||
|
RA: 0x1,
|
||||||
|
Z1: 0x0,
|
||||||
|
Z2: 0x0,
|
||||||
|
Z3: 0x0,
|
||||||
|
RCode: 0xF,
|
||||||
|
}
|
||||||
|
dnsFlagsEncoded := dnsFlags.Encode()
|
||||||
|
dnsFlagsEncodedGood := []byte{0xf9, 0x8f}
|
||||||
|
if bytes.Compare(dnsFlagsEncoded, dnsFlagsEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Flags.Encode() = %x, want "%x", error`, dnsFlagsEncoded, dnsFlagsEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSQuestionEncode(t *testing.T) {
|
||||||
|
dnsQuestion := Question{
|
||||||
|
QName: Name{Parts: []string{"example", "com"}},
|
||||||
|
QType: 0x001c,
|
||||||
|
QClass: 0x0001,
|
||||||
|
}
|
||||||
|
dnsQuestionEncoded := dnsQuestion.EncodeQuestion()
|
||||||
|
dnsQuestionEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0x1c, 0x00, 0x01}
|
||||||
|
if bytes.Compare(dnsQuestionEncoded, dnsQuestionEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Question.EncodeHeader() = %x, want "%x", error`, dnsQuestionEncoded, dnsQuestionEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSAddressEncode(t *testing.T) {
|
||||||
|
dnsAddress := Address{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Address: []byte{192, 168, 1, 1},
|
||||||
|
}
|
||||||
|
dnsAddressEncoded := dnsAddress.EncodeResource()
|
||||||
|
dnsAddressEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x04, 192, 168, 1, 1}
|
||||||
|
if bytes.Compare(dnsAddressEncoded, dnsAddressEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Address.EncodeResource() = %x, want "%x", error`, dnsAddressEncoded, dnsAddressEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSNameServerEncode(t *testing.T) {
|
||||||
|
dnsNameServer := NameServer{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
NSDName: Name{Parts: []string{"example", "com"}},
|
||||||
|
}
|
||||||
|
dnsNameServerEncoded := dnsNameServer.EncodeResource()
|
||||||
|
dnsNameServerEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||||
|
if bytes.Compare(dnsNameServerEncoded, dnsNameServerEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`NameServer.EncodeResource() = %x, want "%x", error`, dnsNameServerEncoded, dnsNameServerEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSCNameEncode(t *testing.T) {
|
||||||
|
dnsCName := CName{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
CName: Name{Parts: []string{"example", "com"}},
|
||||||
|
}
|
||||||
|
dnsCNameEncoded := dnsCName.EncodeResource()
|
||||||
|
dnsCNameEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x0D, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00}
|
||||||
|
if bytes.Compare(dnsCNameEncoded, dnsCNameEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`CName.EncodeResource() = %x, want "%x", error`, dnsCNameEncoded, dnsCNameEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSAuthorityEncode(t *testing.T) {
|
||||||
|
dnsAuthority := Authority{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
MName: Name{Parts: []string{"example", "com"}},
|
||||||
|
RName: Name{Parts: []string{"example", "com"}},
|
||||||
|
Serial: 0x12345678,
|
||||||
|
Refresh: 0x12345678,
|
||||||
|
Retry: 0x12345678,
|
||||||
|
Expire: 0x12345678,
|
||||||
|
Minimum: 0x12345678,
|
||||||
|
}
|
||||||
|
dnsAuthorityEncoded := dnsAuthority.EncodeResource()
|
||||||
|
dnsAuthorityEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x2E, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78, 0x12, 0x34, 0x56, 0x78}
|
||||||
|
if bytes.Compare(dnsAuthorityEncoded, dnsAuthorityEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Authority.EncodeResource() = %x, want "%x", error`, dnsAuthorityEncoded, dnsAuthorityEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDNSUnknownEncode(t *testing.T) {
|
||||||
|
dnsUnknown := Unknown{
|
||||||
|
ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
Type: 0xF0,
|
||||||
|
Class: 0xF0,
|
||||||
|
TTL: 0x77770FF0,
|
||||||
|
},
|
||||||
|
Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
}
|
||||||
|
dnsUnknownEncoded := dnsUnknown.EncodeResource()
|
||||||
|
dnsUnknownEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
|
||||||
|
if bytes.Compare(dnsUnknownEncoded, dnsUnknownEncodedGood) != 0 {
|
||||||
|
t.Fatalf(`Unknown.EncodeResource() = %x, want "%x", error`, dnsUnknownEncoded, dnsUnknownEncodedGood)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//func TestDNSMessageEncode(t *testing.T) {
|
||||||
|
// dnsMessage := Message{
|
||||||
|
// ID: 0x00FF,
|
||||||
|
// Flags: Flags{
|
||||||
|
// QR: 0x1,
|
||||||
|
// Opcode: 0xF,
|
||||||
|
// AA: 0x0,
|
||||||
|
// TC: 0x0,
|
||||||
|
// RD: 0x1,
|
||||||
|
// RA: 0x1,
|
||||||
|
// Z1: 0x0,
|
||||||
|
// Z2: 0x0,
|
||||||
|
// Z3: 0x0,
|
||||||
|
// RCode: 0xF,
|
||||||
|
// },
|
||||||
|
// QD: []Question{
|
||||||
|
// {
|
||||||
|
// QName: Name{Parts: []string{"example", "com"}},
|
||||||
|
// QType: 0x001c,
|
||||||
|
// QClass: 0x0001,
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// AN: []ResourceRecord{
|
||||||
|
// Unknown{
|
||||||
|
// ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
// Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
// Type: 0xF0,
|
||||||
|
// Class: 0xF0,
|
||||||
|
// TTL: 0x77770FF0,
|
||||||
|
// },
|
||||||
|
// Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// NS: []ResourceRecord{
|
||||||
|
// Unknown{
|
||||||
|
// ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
// Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
// Type: 0xF0,
|
||||||
|
// Class: 0xF0,
|
||||||
|
// TTL: 0x77770FF0,
|
||||||
|
// },
|
||||||
|
// Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// AR: []ResourceRecord{
|
||||||
|
// Unknown{
|
||||||
|
// ResourceRecordHeader: ResourceRecordHeader{
|
||||||
|
// Name: Name{Parts: []string{"example", "com"}},
|
||||||
|
// Type: 0xF0,
|
||||||
|
// Class: 0xF0,
|
||||||
|
// TTL: 0x77770FF0,
|
||||||
|
// },
|
||||||
|
// Data: []byte{0x01, 0x02, 0x03},
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
// dnsMessageEncoded := dnsMessage.Encode()
|
||||||
|
// dnsMessageEncodedGood := []byte{0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0x00, 0xF0, 0x00, 0xF0, 0x77, 0x77, 0x0F, 0xF0, 0x00, 0x03, 0x01, 0x02, 0x03}
|
||||||
|
// if bytes.Compare(dnsMessageEncoded, dnsMessageEncodedGood) != 0 {
|
||||||
|
// t.Fatalf(`Message.Encode() = %x, want "%x", error`, dnsMessageEncoded, dnsMessageEncodedGood)
|
||||||
|
// }
|
||||||
|
//}
|
8
go.mod
8
go.mod
@ -5,8 +5,6 @@ go 1.21
|
|||||||
require (
|
require (
|
||||||
github.com/IGLOU-EU/go-wildcard/v2 v2.0.2
|
github.com/IGLOU-EU/go-wildcard/v2 v2.0.2
|
||||||
github.com/coreos/go-iptables v0.7.0
|
github.com/coreos/go-iptables v0.7.0
|
||||||
github.com/google/uuid v1.6.0
|
|
||||||
github.com/miekg/dns v1.1.63
|
|
||||||
github.com/rs/zerolog v1.33.0
|
github.com/rs/zerolog v1.33.0
|
||||||
github.com/vishvananda/netlink v1.3.0
|
github.com/vishvananda/netlink v1.3.0
|
||||||
)
|
)
|
||||||
@ -15,9 +13,5 @@ require (
|
|||||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
github.com/vishvananda/netns v0.0.4 // indirect
|
github.com/vishvananda/netns v0.0.4 // indirect
|
||||||
golang.org/x/mod v0.18.0 // indirect
|
golang.org/x/sys v0.24.0 // indirect
|
||||||
golang.org/x/net v0.31.0 // indirect
|
|
||||||
golang.org/x/sync v0.7.0 // indirect
|
|
||||||
golang.org/x/sys v0.27.0 // indirect
|
|
||||||
golang.org/x/tools v0.22.0 // indirect
|
|
||||||
)
|
)
|
||||||
|
18
group.go
18
group.go
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -18,12 +17,12 @@ type Group struct {
|
|||||||
|
|
||||||
iptables *iptables.IPTables
|
iptables *iptables.IPTables
|
||||||
ipset *netfilterHelper.IPSet
|
ipset *netfilterHelper.IPSet
|
||||||
ifaceToIPSetNAT *netfilterHelper.IfaceToIPSet
|
ifaceToIPSet *netfilterHelper.IfaceToIPSet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
|
func (g *Group) AddIPv4(address net.IP, ttl time.Duration) error {
|
||||||
ttlSeconds := uint32(ttl.Seconds())
|
ttlSeconds := uint32(ttl.Seconds())
|
||||||
return g.ipset.AddIP(address, &ttlSeconds)
|
return g.ipset.Add(address, &ttlSeconds)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *Group) DelIPv4(address net.IP) error {
|
func (g *Group) DelIPv4(address net.IP) error {
|
||||||
@ -45,13 +44,10 @@ func (g *Group) Enable() error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if g.FixProtect {
|
if g.FixProtect {
|
||||||
err := g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
|
g.iptables.AppendUnique("filter", "_NDM_SL_FORWARD", "-o", g.Interface, "-m", "state", "--state", "NEW", "-j", "_NDM_SL_PROTECT")
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to fix protect: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := g.ifaceToIPSetNAT.Enable()
|
err := g.ifaceToIPSet.Enable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -68,9 +64,9 @@ func (g *Group) Disable() []error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := g.ifaceToIPSetNAT.Disable()
|
errs2 := g.ifaceToIPSet.Disable()
|
||||||
if err != nil {
|
if errs2 != nil {
|
||||||
errs = append(errs, err...)
|
errs = append(errs, errs2...)
|
||||||
}
|
}
|
||||||
|
|
||||||
g.Enabled = false
|
g.Enabled = false
|
||||||
|
543
kvas2.go
543
kvas2.go
@ -5,21 +5,16 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"kvas2-go/dns-mitm-proxy"
|
"kvas2-go/dns-proxy"
|
||||||
"kvas2-go/models"
|
"kvas2-go/models"
|
||||||
"kvas2-go/netfilter-helper"
|
"kvas2-go/netfilter-helper"
|
||||||
"kvas2-go/records"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
"github.com/vishvananda/netlink/nl"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -29,31 +24,140 @@ var (
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
MinimalTTL time.Duration
|
MinimalTTL time.Duration
|
||||||
ChainPrefix string
|
ChainPostfix string
|
||||||
IpSetPrefix string
|
IpSetPostfix string
|
||||||
LinkName string
|
|
||||||
TargetDNSServerAddress string
|
TargetDNSServerAddress string
|
||||||
ListenDNSPort uint16
|
ListenPort uint16
|
||||||
UseSoftwareRouting bool
|
UseSoftwareRouting bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type App struct {
|
type App struct {
|
||||||
Config Config
|
Config Config
|
||||||
|
|
||||||
DNSMITM *dnsMitmProxy.DNSMITMProxy
|
DNSProxy *dnsProxy.DNSProxy
|
||||||
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
NetfilterHelper4 *netfilterHelper.NetfilterHelper
|
||||||
NetfilterHelper6 *netfilterHelper.NetfilterHelper
|
Records *Records
|
||||||
Records *records.Records
|
Groups map[int]*Group
|
||||||
Groups map[uuid.UUID]*Group
|
|
||||||
|
|
||||||
Link netlink.Link
|
|
||||||
|
|
||||||
isRunning bool
|
isRunning bool
|
||||||
dnsOverrider4 *netfilterHelper.PortRemap
|
dnsOverrider4 *netfilterHelper.PortRemap
|
||||||
dnsOverrider6 *netfilterHelper.PortRemap
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) handleLink(event netlink.LinkUpdate) {
|
func (a *App) Listen(ctx context.Context) []error {
|
||||||
|
if a.isRunning {
|
||||||
|
return []error{ErrAlreadyRunning}
|
||||||
|
}
|
||||||
|
a.isRunning = true
|
||||||
|
defer func() { a.isRunning = false }()
|
||||||
|
|
||||||
|
errs := make([]error, 0)
|
||||||
|
isError := make(chan struct{})
|
||||||
|
|
||||||
|
var once sync.Once
|
||||||
|
var errsMu sync.Mutex
|
||||||
|
handleError := func(err error) {
|
||||||
|
errsMu.Lock()
|
||||||
|
defer errsMu.Unlock()
|
||||||
|
|
||||||
|
errs = append(errs, err)
|
||||||
|
once.Do(func() { close(isError) })
|
||||||
|
}
|
||||||
|
handleErrors := func(errs2 []error) {
|
||||||
|
errsMu.Lock()
|
||||||
|
defer errsMu.Unlock()
|
||||||
|
|
||||||
|
errs = append(errs, errs2...)
|
||||||
|
once.Do(func() { close(isError) })
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
if err, ok := r.(error); ok {
|
||||||
|
handleError(err)
|
||||||
|
} else {
|
||||||
|
handleError(fmt.Errorf("%v", r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
newCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPostfix), 53, a.Config.ListenPort)
|
||||||
|
err := a.dnsOverrider4.Enable()
|
||||||
|
|
||||||
|
for _, group := range a.Groups {
|
||||||
|
err = group.Enable()
|
||||||
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to enable group: %w", err))
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := a.DNSProxy.Listen(newCtx); err != nil {
|
||||||
|
handleError(fmt.Errorf("failed to initialize DNS proxy: %v", err))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
link := make(chan netlink.LinkUpdate)
|
||||||
|
done := make(chan struct{})
|
||||||
|
netlink.LinkSubscribe(link, done)
|
||||||
|
|
||||||
|
exitListenerLoop := false
|
||||||
|
listener, err := net.Listen("unix", "/opt/var/run/kvas2-go.sock")
|
||||||
|
if err != nil {
|
||||||
|
handleError(fmt.Errorf("error while serve UNIX socket: %v", err))
|
||||||
|
return errs
|
||||||
|
}
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
if exitListenerLoop {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("error while listening unix socket")
|
||||||
|
}
|
||||||
|
|
||||||
|
go func(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
args := strings.Split(string(buf[:n]), ":")
|
||||||
|
if len(args) == 3 && args[0] == "netfilter.d" {
|
||||||
|
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
|
||||||
|
if a.dnsOverrider4.Enabled {
|
||||||
|
err := a.dnsOverrider4.PutIPTable(args[2])
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, group := range a.Groups {
|
||||||
|
if group.ifaceToIPSet.Enabled {
|
||||||
|
err := group.ifaceToIPSet.PutIPTable(args[2])
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
Loop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event := <-link:
|
||||||
switch event.Change {
|
switch event.Change {
|
||||||
case 0x00000001:
|
case 0x00000001:
|
||||||
log.Debug().
|
log.Debug().
|
||||||
@ -61,17 +165,13 @@ func (a *App) handleLink(event netlink.LinkUpdate) {
|
|||||||
Str("operstatestr", event.Attrs().OperState.String()).
|
Str("operstatestr", event.Attrs().OperState.String()).
|
||||||
Int("operstate", int(event.Attrs().OperState)).
|
Int("operstate", int(event.Attrs().OperState)).
|
||||||
Msg("interface change")
|
Msg("interface change")
|
||||||
switch event.Attrs().OperState {
|
if event.Attrs().OperState != netlink.OperDown {
|
||||||
case netlink.OperUp:
|
|
||||||
ifaceName := event.Link.Attrs().Name
|
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
if group.Interface != ifaceName {
|
if group.Interface == event.Link.Attrs().Name {
|
||||||
continue
|
err = group.ifaceToIPSet.IfaceHandle()
|
||||||
}
|
|
||||||
|
|
||||||
err := group.ifaceToIPSetNAT.IfaceHandle()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Str("group", group.ID.String()).Err(err).Msg("error while handling interface up")
|
log.Error().Int("group", group.ID).Err(err).Msg("error while handling interface up")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -89,198 +189,30 @@ func (a *App) handleLink(event netlink.LinkUpdate) {
|
|||||||
Msg("interface del")
|
Msg("interface del")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) start(ctx context.Context) (err error) {
|
|
||||||
newCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
errChan := make(chan error)
|
|
||||||
|
|
||||||
/*
|
|
||||||
DNS Proxy
|
|
||||||
*/
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
addr, err := net.ResolveUDPAddr("udp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to resolve udp address: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = a.DNSMITM.ListenUDP(newCtx, addr)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to serve DNS UDP proxy: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
addr, err := net.ResolveTCPAddr("tcp", "[::]:"+strconv.Itoa(int(a.Config.ListenDNSPort)))
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to resolve tcp address: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = a.DNSMITM.ListenTCP(newCtx, addr)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to serve DNS TCP proxy: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
addrList, err := netlink.AddrList(a.Link, nl.FAMILY_ALL)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list address of interface: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
a.dnsOverrider4 = a.NetfilterHelper4.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
|
|
||||||
err = a.dnsOverrider4.Enable()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to override DNS (IPv4): %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = a.dnsOverrider4.Disable() }()
|
|
||||||
|
|
||||||
a.dnsOverrider6 = a.NetfilterHelper6.PortRemap(fmt.Sprintf("%sDNSOR", a.Config.ChainPrefix), 53, a.Config.ListenDNSPort, addrList)
|
|
||||||
err = a.dnsOverrider6.Enable()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to override DNS (IPv6): %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = a.dnsOverrider6.Disable() }()
|
|
||||||
|
|
||||||
/*
|
|
||||||
Groups
|
|
||||||
*/
|
|
||||||
|
|
||||||
for _, group := range a.Groups {
|
|
||||||
err = group.Enable()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to enable group: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
for _, group := range a.Groups {
|
|
||||||
_ = group.Disable()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
/*
|
|
||||||
Socket (for netfilter.d events)
|
|
||||||
*/
|
|
||||||
socketPath := "/opt/var/run/kvas2-go.sock"
|
|
||||||
err = os.Remove(socketPath)
|
|
||||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
||||||
return fmt.Errorf("failed to remove existed UNIX socket: %w", err)
|
|
||||||
}
|
|
||||||
socket, err := net.Listen("unix", socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error while serve UNIX socket: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = socket.Close()
|
|
||||||
_ = os.Remove(socketPath)
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
if newCtx.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := socket.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if !strings.Contains(err.Error(), "use of closed network connection") {
|
|
||||||
log.Error().Err(err).Msg("error while listening unix socket")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(conn net.Conn) {
|
|
||||||
defer func() { _ = conn.Close() }()
|
|
||||||
|
|
||||||
buf := make([]byte, 1024)
|
|
||||||
n, err := conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
args := strings.Split(string(buf[:n]), ":")
|
|
||||||
if len(args) == 3 && args[0] == "netfilter.d" {
|
|
||||||
log.Debug().Str("table", args[2]).Msg("netfilter.d event")
|
|
||||||
if a.dnsOverrider4.Enabled {
|
|
||||||
err := a.dnsOverrider4.PutIPTable(args[2])
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if a.dnsOverrider6.Enabled {
|
|
||||||
err = a.dnsOverrider6.PutIPTable(args[2])
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, group := range a.Groups {
|
|
||||||
if group.ifaceToIPSetNAT.Enabled {
|
|
||||||
err := group.ifaceToIPSetNAT.PutIPTable(args[2])
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("error while fixing iptables after netfilter.d")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
/*
|
|
||||||
Interface updates
|
|
||||||
*/
|
|
||||||
linkUpdateChannel := make(chan netlink.LinkUpdate)
|
|
||||||
linkUpdateDone := make(chan struct{})
|
|
||||||
err = netlink.LinkSubscribe(linkUpdateChannel, linkUpdateDone)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to subscribe to link updates: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
close(linkUpdateDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
/*
|
|
||||||
Global loop
|
|
||||||
*/
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case event := <-linkUpdateChannel:
|
|
||||||
a.handleLink(event)
|
|
||||||
case err := <-errChan:
|
|
||||||
return err
|
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
break Loop
|
||||||
|
case <-isError:
|
||||||
|
break Loop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (a *App) Start(ctx context.Context) (err error) {
|
exitListenerLoop = true
|
||||||
if a.isRunning {
|
|
||||||
return ErrAlreadyRunning
|
|
||||||
}
|
|
||||||
a.isRunning = true
|
|
||||||
defer func() {
|
|
||||||
a.isRunning = false
|
|
||||||
}()
|
|
||||||
|
|
||||||
defer func() {
|
close(done)
|
||||||
if r := recover(); r != nil {
|
|
||||||
var ok bool
|
errs2 := a.dnsOverrider4.Disable()
|
||||||
if err, ok = r.(error); !ok {
|
if errs2 != nil {
|
||||||
err = fmt.Errorf("%v", r)
|
handleErrors(errs2)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fmt.Errorf("recovered error: %w", err)
|
for _, group := range a.Groups {
|
||||||
|
errs2 = group.Disable()
|
||||||
|
if errs2 != nil {
|
||||||
|
handleErrors(errs2)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
err = a.start(ctx)
|
return errs
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) AddGroup(group *models.Group) error {
|
func (a *App) AddGroup(group *models.Group) error {
|
||||||
@ -288,7 +220,7 @@ func (a *App) AddGroup(group *models.Group) error {
|
|||||||
return ErrGroupIDConflict
|
return ErrGroupIDConflict
|
||||||
}
|
}
|
||||||
|
|
||||||
ipsetName := fmt.Sprintf("%s%8x", a.Config.IpSetPrefix, group.ID.ID())
|
ipsetName := fmt.Sprintf("%s%d", a.Config.IpSetPostfix, group.ID)
|
||||||
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
|
ipset, err := a.NetfilterHelper4.IPSet(ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize ipset: %w", err)
|
return fmt.Errorf("failed to initialize ipset: %w", err)
|
||||||
@ -298,19 +230,24 @@ func (a *App) AddGroup(group *models.Group) error {
|
|||||||
Group: group,
|
Group: group,
|
||||||
iptables: a.NetfilterHelper4.IPTables,
|
iptables: a.NetfilterHelper4.IPTables,
|
||||||
ipset: ipset,
|
ipset: ipset,
|
||||||
ifaceToIPSetNAT: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%8x", a.Config.ChainPrefix, group.ID.ID()), group.Interface, ipsetName, false),
|
ifaceToIPSet: a.NetfilterHelper4.IfaceToIPSet(fmt.Sprintf("%sR_%d", a.Config.ChainPostfix, group.ID), group.Interface, ipsetName, false),
|
||||||
}
|
}
|
||||||
grp.ifaceToIPSetNAT.SoftwareMode = a.Config.UseSoftwareRouting
|
a.Groups[group.ID] = grp
|
||||||
a.Groups[grp.ID] = grp
|
|
||||||
return a.SyncGroup(grp)
|
return a.SyncGroup(grp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) SyncGroup(group *Group) error {
|
func (a *App) SyncGroup(group *Group) error {
|
||||||
|
processedDomains := make(map[string]struct{})
|
||||||
|
newIpsetAddressesMap := make(map[string]time.Duration)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
addresses := make(map[string]time.Duration)
|
oldIpsetAddresses, err := group.ListIPv4()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get old ipset list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
knownDomains := a.Records.ListKnownDomains()
|
knownDomains := a.Records.ListKnownDomains()
|
||||||
for _, domain := range group.Rules {
|
for _, domain := range group.Domains {
|
||||||
if !domain.IsEnabled() {
|
if !domain.IsEnabled() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -320,24 +257,26 @@ func (a *App) SyncGroup(group *Group) error {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
domainAddresses := a.Records.GetARecords(domainName)
|
cnames := a.Records.GetCNameRecords(domainName, true)
|
||||||
for _, address := range domainAddresses {
|
if len(cnames) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, cname := range cnames {
|
||||||
|
processedDomains[cname] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
addresses := a.Records.GetARecords(domainName)
|
||||||
|
for _, address := range addresses {
|
||||||
ttl := now.Sub(address.Deadline)
|
ttl := now.Sub(address.Deadline)
|
||||||
if oldTTL, ok := addresses[string(address.Address)]; !ok || ttl > oldTTL {
|
if oldTTL, ok := newIpsetAddressesMap[string(address.Address)]; !ok || ttl > oldTTL {
|
||||||
addresses[string(address.Address)] = ttl
|
newIpsetAddressesMap[string(address.Address)] = ttl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
currentAddresses, err := group.ListIPv4()
|
for addr, ttl := range newIpsetAddressesMap {
|
||||||
if err != nil {
|
if _, exists := oldIpsetAddresses[addr]; exists {
|
||||||
return fmt.Errorf("failed to get old ipset list: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for addr, ttl := range addresses {
|
|
||||||
// TODO: Check TTL
|
|
||||||
if _, exists := currentAddresses[addr]; exists {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip := net.IP(addr)
|
ip := net.IP(addr)
|
||||||
@ -347,16 +286,11 @@ func (a *App) SyncGroup(group *Group) error {
|
|||||||
Str("address", ip.String()).
|
Str("address", ip.String()).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("failed to add address")
|
Msg("failed to add address")
|
||||||
} else {
|
|
||||||
log.Trace().
|
|
||||||
Str("address", ip.String()).
|
|
||||||
Err(err).
|
|
||||||
Msg("add address")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for addr := range currentAddresses {
|
for addr, _ := range oldIpsetAddresses {
|
||||||
if _, ok := addresses[addr]; ok {
|
if _, exists := newIpsetAddressesMap[addr]; exists {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ip := net.IP(addr)
|
ip := net.IP(addr)
|
||||||
@ -370,7 +304,7 @@ func (a *App) SyncGroup(group *Group) error {
|
|||||||
log.Trace().
|
log.Trace().
|
||||||
Str("address", ip.String()).
|
Str("address", ip.String()).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("del address")
|
Msg("add address")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -396,24 +330,24 @@ func (a *App) ListInterfaces() ([]net.Interface, error) {
|
|||||||
return interfaceNames, nil
|
return interfaceNames, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) processARecord(aRecord dns.A) {
|
func (a *App) processARecord(aRecord dnsProxy.Address) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("name", aRecord.Hdr.Name).
|
Str("name", aRecord.Name.String()).
|
||||||
Str("address", aRecord.A.String()).
|
Str("address", aRecord.Address.String()).
|
||||||
Int("ttl", int(aRecord.Hdr.Ttl)).
|
Int("ttl", int(aRecord.TTL)).
|
||||||
Msg("processing a record")
|
Msg("processing a record")
|
||||||
|
|
||||||
ttlDuration := time.Duration(aRecord.Hdr.Ttl) * time.Second
|
ttlDuration := time.Duration(aRecord.TTL) * time.Second
|
||||||
if ttlDuration < a.Config.MinimalTTL {
|
if ttlDuration < a.Config.MinimalTTL {
|
||||||
ttlDuration = a.Config.MinimalTTL
|
ttlDuration = a.Config.MinimalTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
a.Records.AddARecord(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1], aRecord.A, ttlDuration)
|
a.Records.AddARecord(aRecord.Name.String(), aRecord.Address, ttlDuration)
|
||||||
|
|
||||||
names := a.Records.GetAliases(aRecord.Hdr.Name[:len(aRecord.Hdr.Name)-1])
|
names := a.Records.GetCNameRecords(aRecord.Name.String(), true)
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
Rule:
|
Domain:
|
||||||
for _, domain := range group.Rules {
|
for _, domain := range group.Domains {
|
||||||
if !domain.IsEnabled() {
|
if !domain.IsEnabled() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -421,48 +355,47 @@ func (a *App) processARecord(aRecord dns.A) {
|
|||||||
if !domain.IsMatch(name) {
|
if !domain.IsMatch(name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// TODO: Check already existed
|
err := group.AddIPv4(aRecord.Address, ttlDuration)
|
||||||
err := group.AddIPv4(aRecord.A, ttlDuration)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().
|
log.Error().
|
||||||
Str("address", aRecord.A.String()).
|
Str("address", aRecord.Address.String()).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("failed to add address")
|
Msg("failed to add address")
|
||||||
} else {
|
} else {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("address", aRecord.A.String()).
|
Str("address", aRecord.Address.String()).
|
||||||
Str("aRecordDomain", aRecord.Hdr.Name).
|
Str("aRecordDomain", aRecord.Name.String()).
|
||||||
Str("cNameDomain", name).
|
Str("cNameDomain", name).
|
||||||
Err(err).
|
Err(err).
|
||||||
Msg("add address")
|
Msg("add address")
|
||||||
}
|
}
|
||||||
break Rule
|
break Domain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
|
func (a *App) processCNameRecord(cNameRecord dnsProxy.CName) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("name", cNameRecord.Hdr.Name).
|
Str("name", cNameRecord.Name.String()).
|
||||||
Str("cname", cNameRecord.Target).
|
Str("cname", cNameRecord.CName.String()).
|
||||||
Int("ttl", int(cNameRecord.Hdr.Ttl)).
|
Int("ttl", int(cNameRecord.TTL)).
|
||||||
Msg("processing cname record")
|
Msg("processing cname record")
|
||||||
|
|
||||||
ttlDuration := time.Duration(cNameRecord.Hdr.Ttl) * time.Second
|
ttlDuration := time.Duration(cNameRecord.TTL) * time.Second
|
||||||
if ttlDuration < a.Config.MinimalTTL {
|
if ttlDuration < a.Config.MinimalTTL {
|
||||||
ttlDuration = a.Config.MinimalTTL
|
ttlDuration = a.Config.MinimalTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
a.Records.AddCNameRecord(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1], cNameRecord.Target, ttlDuration)
|
a.Records.AddCNameRecord(cNameRecord.Name.String(), cNameRecord.CName.String(), ttlDuration)
|
||||||
|
|
||||||
// TODO: Optimization
|
// TODO: Optimization
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
aRecords := a.Records.GetARecords(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
|
aRecords := a.Records.GetARecords(cNameRecord.Name.String())
|
||||||
names := a.Records.GetAliases(cNameRecord.Hdr.Name[:len(cNameRecord.Hdr.Name)-1])
|
names := a.Records.GetCNameRecords(cNameRecord.Name.String(), true)
|
||||||
for _, group := range a.Groups {
|
for _, group := range a.Groups {
|
||||||
Rule:
|
Domain:
|
||||||
for _, domain := range group.Rules {
|
for _, domain := range group.Domains {
|
||||||
if !domain.IsEnabled() {
|
if !domain.IsEnabled() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -485,24 +418,31 @@ func (a *App) processCNameRecord(cNameRecord dns.CNAME) {
|
|||||||
Msg("add address")
|
Msg("add address")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue Rule
|
continue Domain
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) handleRecord(rr dns.RR) {
|
func (a *App) handleRecord(rr dnsProxy.ResourceRecord) {
|
||||||
switch v := rr.(type) {
|
switch v := rr.(type) {
|
||||||
case *dns.A:
|
case dnsProxy.Address:
|
||||||
a.processARecord(*v)
|
// TODO: Optimize equals domain A records
|
||||||
case *dns.CNAME:
|
a.processARecord(v)
|
||||||
a.processCNameRecord(*v)
|
case dnsProxy.CName:
|
||||||
|
a.processCNameRecord(v)
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *App) handleMessage(msg dns.Msg) {
|
func (a *App) handleMessage(msg *dnsProxy.Message) {
|
||||||
for _, rr := range msg.Answer {
|
for _, rr := range msg.AN {
|
||||||
|
a.handleRecord(rr)
|
||||||
|
}
|
||||||
|
for _, rr := range msg.NS {
|
||||||
|
a.handleRecord(rr)
|
||||||
|
}
|
||||||
|
for _, rr := range msg.AR {
|
||||||
a.handleRecord(rr)
|
a.handleRecord(rr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -514,73 +454,18 @@ func New(config Config) (*App, error) {
|
|||||||
|
|
||||||
app.Config = config
|
app.Config = config
|
||||||
|
|
||||||
app.DNSMITM = dnsMitmProxy.New()
|
app.DNSProxy = dnsProxy.New(app.Config.ListenPort, app.Config.TargetDNSServerAddress)
|
||||||
app.DNSMITM.TargetDNSServerAddress = app.Config.TargetDNSServerAddress
|
app.DNSProxy.MsgHandler = app.handleMessage
|
||||||
app.DNSMITM.TargetDNSServerPort = 53
|
|
||||||
app.DNSMITM.RequestHook = func(clientAddr net.Addr, reqMsg dns.Msg, network string) (*dns.Msg, *dns.Msg, error) {
|
|
||||||
// TODO: Need to understand why it not works in proxy mode
|
|
||||||
if len(reqMsg.Question) == 1 && reqMsg.Question[0].Qtype == dns.TypePTR {
|
|
||||||
respMsg := &dns.Msg{
|
|
||||||
MsgHdr: dns.MsgHdr{
|
|
||||||
Id: reqMsg.Id,
|
|
||||||
Response: true,
|
|
||||||
RecursionAvailable: true,
|
|
||||||
Rcode: dns.RcodeNameError,
|
|
||||||
},
|
|
||||||
Question: reqMsg.Question,
|
|
||||||
}
|
|
||||||
return nil, respMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil, nil
|
app.Records = NewRecords()
|
||||||
}
|
|
||||||
app.DNSMITM.ResponseHook = func(clientAddr net.Addr, reqMsg dns.Msg, respMsg dns.Msg, network string) (*dns.Msg, error) {
|
|
||||||
// TODO: Make it optional
|
|
||||||
var idx int
|
|
||||||
for _, a := range respMsg.Answer {
|
|
||||||
if a.Header().Rrtype == dns.TypeAAAA {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
respMsg.Answer[idx] = a
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
respMsg.Answer = respMsg.Answer[:idx]
|
|
||||||
|
|
||||||
app.handleMessage(respMsg)
|
|
||||||
|
|
||||||
return &respMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
app.Records = records.New()
|
|
||||||
app.Groups = make(map[uuid.UUID]*Group, 0)
|
|
||||||
|
|
||||||
link, err := netlink.LinkByName(app.Config.LinkName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to find link %s: %w", app.Config.LinkName, err)
|
|
||||||
}
|
|
||||||
app.Link = link
|
|
||||||
|
|
||||||
nh4, err := netfilterHelper.New(false)
|
nh4, err := netfilterHelper.New(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
|
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
|
||||||
}
|
}
|
||||||
app.NetfilterHelper4 = nh4
|
app.NetfilterHelper4 = nh4
|
||||||
err = app.NetfilterHelper4.ClearIPTables(app.Config.ChainPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to clear iptables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nh6, err := netfilterHelper.New(true)
|
app.Groups = make(map[int]*Group)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("netfilter helper init fail: %w", err)
|
|
||||||
}
|
|
||||||
app.NetfilterHelper6 = nh6
|
|
||||||
err = app.NetfilterHelper6.ClearIPTables(app.Config.ChainPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to clear iptables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
app.Groups = make(map[uuid.UUID]*Group)
|
|
||||||
|
|
||||||
return app, nil
|
return app, nil
|
||||||
}
|
}
|
||||||
|
26
main.go
26
main.go
@ -16,11 +16,10 @@ func main() {
|
|||||||
|
|
||||||
app, err := New(Config{
|
app, err := New(Config{
|
||||||
MinimalTTL: time.Hour,
|
MinimalTTL: time.Hour,
|
||||||
ChainPrefix: "KVAS2_",
|
ChainPostfix: "KVAS2_",
|
||||||
IpSetPrefix: "kvas2_",
|
IpSetPostfix: "kvas2_",
|
||||||
LinkName: "br0",
|
TargetDNSServerAddress: "127.0.0.1:53",
|
||||||
TargetDNSServerAddress: "127.0.0.1",
|
ListenPort: 7548,
|
||||||
ListenDNSPort: 7553,
|
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg("failed to initialize application")
|
log.Fatal().Err(err).Msg("failed to initialize application")
|
||||||
@ -28,23 +27,22 @@ func main() {
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
log.Info().Msg("starting service")
|
appErrsChan := make(chan []error)
|
||||||
|
|
||||||
/*
|
|
||||||
Starting app with graceful shutdown
|
|
||||||
*/
|
|
||||||
appResult := make(chan error)
|
|
||||||
go func() {
|
go func() {
|
||||||
appResult <- app.Start(ctx)
|
errs := app.Listen(ctx)
|
||||||
|
appErrsChan <- errs
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
log.Info().Msg("starting service")
|
||||||
|
|
||||||
c := make(chan os.Signal, 1)
|
c := make(chan os.Signal, 1)
|
||||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case err, _ := <-appResult:
|
case appErrs, _ := <-appErrsChan:
|
||||||
if err != nil {
|
for _, err = range appErrs {
|
||||||
log.Error().Err(err).Msg("failed to start application")
|
log.Error().Err(err).Msg("failed to start application")
|
||||||
}
|
}
|
||||||
log.Info().Msg("exiting application")
|
log.Info().Msg("exiting application")
|
||||||
|
33
models/domain.go
Normal file
33
models/domain.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/IGLOU-EU/go-wildcard/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Domain struct {
|
||||||
|
ID int
|
||||||
|
Group *Group
|
||||||
|
Type string
|
||||||
|
Domain string
|
||||||
|
Enable bool
|
||||||
|
Comment string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Domain) IsEnabled() bool {
|
||||||
|
return d.Enable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Domain) IsMatch(domainName string) bool {
|
||||||
|
switch d.Type {
|
||||||
|
case "wildcard":
|
||||||
|
return wildcard.Match(d.Domain, domainName)
|
||||||
|
case "regex":
|
||||||
|
ok, _ := regexp.MatchString(d.Domain, domainName)
|
||||||
|
return ok
|
||||||
|
case "plaintext":
|
||||||
|
return domainName == d.Domain
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
42
models/domain_test.go
Normal file
42
models/domain_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDomain_IsMatch_Plaintext(t *testing.T) {
|
||||||
|
domain := &Domain{
|
||||||
|
Type: "plaintext",
|
||||||
|
Domain: "example.com",
|
||||||
|
}
|
||||||
|
if !domain.IsMatch("example.com") {
|
||||||
|
t.Fatal("&Domain{Type: \"plaintext\", Domain: \"example.com\"}.IsMatch(\"example.com\") returns false")
|
||||||
|
}
|
||||||
|
if domain.IsMatch("noexample.com") {
|
||||||
|
t.Fatal("&Domain{Type: \"plaintext\", Domain: \"example.com\"}.IsMatch(\"noexample.com\") returns true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomain_IsMatch_Wildcard(t *testing.T) {
|
||||||
|
domain := &Domain{
|
||||||
|
Type: "wildcard",
|
||||||
|
Domain: "ex*le.com",
|
||||||
|
}
|
||||||
|
if !domain.IsMatch("example.com") {
|
||||||
|
t.Fatal("&Domain{Type: \"wildcard\", Domain: \"ex*le.com\"}.IsMatch(\"example.com\") returns false")
|
||||||
|
}
|
||||||
|
if domain.IsMatch("noexample.com") {
|
||||||
|
t.Fatal("&Domain{Type: \"wildcard\", Domain: \"ex*le.com\"}.IsMatch(\"noexample.com\") returns true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomain_IsMatch_RegEx(t *testing.T) {
|
||||||
|
domain := &Domain{
|
||||||
|
Type: "regex",
|
||||||
|
Domain: "^ex[apm]{3}le.com$",
|
||||||
|
}
|
||||||
|
if !domain.IsMatch("example.com") {
|
||||||
|
t.Fatal("&Domain{Type: \"regex\", Domain: \"^ex[apm]{3}le.com$\"}.IsMatch(\"example.com\") returns false")
|
||||||
|
}
|
||||||
|
if domain.IsMatch("noexample.com") {
|
||||||
|
t.Fatal("&Domain{Type: \"regex\", Domain: \"^ex[apm]{3}le.com$\"}.IsMatch(\"noexample.com\") returns true")
|
||||||
|
}
|
||||||
|
}
|
@ -1,11 +1,9 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
|
||||||
|
|
||||||
type Group struct {
|
type Group struct {
|
||||||
ID uuid.UUID
|
ID int
|
||||||
Name string
|
Name string
|
||||||
Interface string
|
Interface string
|
||||||
Rules []*Rule
|
|
||||||
FixProtect bool
|
FixProtect bool
|
||||||
|
Domains []*Domain
|
||||||
}
|
}
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
package models
|
|
||||||
|
|
||||||
import (
|
|
||||||
"regexp"
|
|
||||||
|
|
||||||
"github.com/IGLOU-EU/go-wildcard/v2"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Rule struct {
|
|
||||||
ID uuid.UUID
|
|
||||||
Name string
|
|
||||||
Type string
|
|
||||||
Rule string
|
|
||||||
Enable bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Rule) IsEnabled() bool {
|
|
||||||
return d.Enable
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Rule) IsMatch(domainName string) bool {
|
|
||||||
switch d.Type {
|
|
||||||
case "wildcard":
|
|
||||||
return wildcard.Match(d.Rule, domainName)
|
|
||||||
case "regex":
|
|
||||||
ok, _ := regexp.MatchString(d.Rule, domainName)
|
|
||||||
return ok
|
|
||||||
case "plaintext":
|
|
||||||
return domainName == d.Rule
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
@ -1,42 +0,0 @@
|
|||||||
package models
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestDomain_IsMatch_Plaintext(t *testing.T) {
|
|
||||||
rule := &Rule{
|
|
||||||
Type: "plaintext",
|
|
||||||
Rule: "example.com",
|
|
||||||
}
|
|
||||||
if !rule.IsMatch("example.com") {
|
|
||||||
t.Fatal("&Rule{Type: \"plaintext\", Rule: \"example.com\"}.IsMatch(\"example.com\") returns false")
|
|
||||||
}
|
|
||||||
if rule.IsMatch("noexample.com") {
|
|
||||||
t.Fatal("&Rule{Type: \"plaintext\", Rule: \"example.com\"}.IsMatch(\"noexample.com\") returns true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDomain_IsMatch_Wildcard(t *testing.T) {
|
|
||||||
rule := &Rule{
|
|
||||||
Type: "wildcard",
|
|
||||||
Rule: "ex*le.com",
|
|
||||||
}
|
|
||||||
if !rule.IsMatch("example.com") {
|
|
||||||
t.Fatal("&Rule{Type: \"wildcard\", Rule: \"ex*le.com\"}.IsMatch(\"example.com\") returns false")
|
|
||||||
}
|
|
||||||
if rule.IsMatch("noexample.com") {
|
|
||||||
t.Fatal("&Rule{Type: \"wildcard\", Rule: \"ex*le.com\"}.IsMatch(\"noexample.com\") returns true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDomain_IsMatch_RegEx(t *testing.T) {
|
|
||||||
rule := &Rule{
|
|
||||||
Type: "regex",
|
|
||||||
Rule: "^ex[apm]{3}le.com$",
|
|
||||||
}
|
|
||||||
if !rule.IsMatch("example.com") {
|
|
||||||
t.Fatal("&Rule{Type: \"regex\", Rule: \"^ex[apm]{3}le.com$\"}.IsMatch(\"example.com\") returns false")
|
|
||||||
}
|
|
||||||
if rule.IsMatch("noexample.com") {
|
|
||||||
t.Fatal("&Rule{Type: \"regex\", Rule: \"^ex[apm]{3}le.com$\"}.IsMatch(\"noexample.com\") returns true")
|
|
||||||
}
|
|
||||||
}
|
|
@ -179,7 +179,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
|
|||||||
// IPTables rules
|
// IPTables rules
|
||||||
err = r.PutIPTable("all")
|
err = r.PutIPTable("all")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mapping mark with table
|
// Mapping mark with table
|
||||||
@ -194,7 +194,7 @@ func (r *IfaceToIPSet) ForceEnable() error {
|
|||||||
|
|
||||||
err = r.IfaceHandle()
|
err = r.IfaceHandle()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Enabled = true
|
r.Enabled = true
|
||||||
|
@ -11,7 +11,7 @@ type IPSet struct {
|
|||||||
SetName string
|
SetName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *IPSet) AddIP(addr net.IP, timeout *uint32) error {
|
func (r *IPSet) Add(addr net.IP, timeout *uint32) error {
|
||||||
err := netlink.IpsetAdd(r.SetName, &netlink.IPSetEntry{
|
err := netlink.IpsetAdd(r.SetName, &netlink.IPSetEntry{
|
||||||
IP: addr,
|
IP: addr,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
@ -63,7 +63,7 @@ func (nh *NetfilterHelper) IPSet(name string) (*IPSet, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
defaultTimeout := uint32(300)
|
defaultTimeout := uint32(300)
|
||||||
err = netlink.IpsetCreate(ipset.SetName, "hash:net", netlink.IpsetCreateOptions{
|
err = netlink.IpsetCreate(ipset.SetName, "hash:ip", netlink.IpsetCreateOptions{
|
||||||
Timeout: &defaultTimeout,
|
Timeout: &defaultTimeout,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,57 +0,0 @@
|
|||||||
package netfilterHelper
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (nh *NetfilterHelper) ClearIPTables(chainPrefix string) error {
|
|
||||||
jumpToChainPrefix := fmt.Sprintf("-j %s", chainPrefix)
|
|
||||||
tableList := []string{"nat", "mangle", "filter"}
|
|
||||||
|
|
||||||
for _, table := range tableList {
|
|
||||||
chainListToDelete := make([]string, 0)
|
|
||||||
|
|
||||||
chains, err := nh.IPTables.ListChains(table)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("listing chains error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range chains {
|
|
||||||
if strings.HasPrefix(chain, chainPrefix) {
|
|
||||||
chainListToDelete = append(chainListToDelete, chain)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
rules, err := nh.IPTables.List(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("listing rules error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
ruleSlice := strings.Split(rule, " ")
|
|
||||||
if len(ruleSlice) < 2 || ruleSlice[0] != "-A" || ruleSlice[1] != chain {
|
|
||||||
// TODO: Warn
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ruleSlice = ruleSlice[2:]
|
|
||||||
|
|
||||||
if strings.Contains(strings.Join(ruleSlice, " "), jumpToChainPrefix) {
|
|
||||||
err := nh.IPTables.Delete(table, chain, ruleSlice...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("rule deletion error: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range chainListToDelete {
|
|
||||||
err := nh.IPTables.ClearAndDeleteChain(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("deleting chain error: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -2,17 +2,13 @@ package netfilterHelper
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/vishvananda/netlink"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PortRemap struct {
|
type PortRemap struct {
|
||||||
IPTables *iptables.IPTables
|
IPTables *iptables.IPTables
|
||||||
ChainName string
|
ChainName string
|
||||||
Addresses []netlink.Addr
|
|
||||||
From uint16
|
From uint16
|
||||||
To uint16
|
To uint16
|
||||||
|
|
||||||
@ -26,27 +22,11 @@ func (r *PortRemap) PutIPTable(table string) error {
|
|||||||
return fmt.Errorf("failed to clear chain: %w", err)
|
return fmt.Errorf("failed to clear chain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, addr := range r.Addresses {
|
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "--dport", strconv.Itoa(int(r.From)), "-j", "REDIRECT", "--to-port", strconv.Itoa(int(r.To)))
|
||||||
var addrIP net.IP
|
|
||||||
iptablesProtocol := r.IPTables.Proto()
|
|
||||||
if (iptablesProtocol == iptables.ProtocolIPv4 && len(addr.IP) == net.IPv4len) || (iptablesProtocol == iptables.ProtocolIPv6 && len(addr.IP) == net.IPv6len) {
|
|
||||||
addrIP = addr.IP
|
|
||||||
}
|
|
||||||
if addrIP == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "udp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create rule: %w", err)
|
return fmt.Errorf("failed to create rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.IPTables.AppendUnique("nat", r.ChainName, "-p", "tcp", "-d", addrIP.String(), "--dport", strconv.Itoa(int(r.From)), "-j", "DNAT", "--to-destination", fmt.Sprintf(":%d", r.To))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = r.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", r.ChainName)
|
err = r.IPTables.InsertUnique("nat", "PREROUTING", 1, "-j", r.ChainName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to linking chain: %w", err)
|
return fmt.Errorf("failed to linking chain: %w", err)
|
||||||
@ -97,11 +77,10 @@ func (r *PortRemap) Enable() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16, addr []netlink.Addr) *PortRemap {
|
func (nh *NetfilterHelper) PortRemap(name string, from, to uint16) *PortRemap {
|
||||||
return &PortRemap{
|
return &PortRemap{
|
||||||
IPTables: nh.IPTables,
|
IPTables: nh.IPTables,
|
||||||
ChainName: name,
|
ChainName: name,
|
||||||
Addresses: addr,
|
|
||||||
From: from,
|
From: from,
|
||||||
To: to,
|
To: to,
|
||||||
}
|
}
|
||||||
|
228
records.go
Normal file
228
records.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ARecord struct {
|
||||||
|
Address net.IP
|
||||||
|
Deadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewARecord(addr net.IP, deadline time.Time) *ARecord {
|
||||||
|
return &ARecord{
|
||||||
|
Address: addr,
|
||||||
|
Deadline: deadline,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CNameRecord struct {
|
||||||
|
Alias string
|
||||||
|
Deadline time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCNameRecord(domainName string, deadline time.Time) *CNameRecord {
|
||||||
|
return &CNameRecord{
|
||||||
|
Alias: domainName,
|
||||||
|
Deadline: deadline,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Records struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
ARecords map[string][]*ARecord
|
||||||
|
CNameRecords map[string]*CNameRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) cleanupARecords(now time.Time) {
|
||||||
|
for name, aRecords := range r.ARecords {
|
||||||
|
i := 0
|
||||||
|
for _, aRecord := range aRecords {
|
||||||
|
if now.After(aRecord.Deadline) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
aRecords[i] = aRecord
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
aRecords = aRecords[:i]
|
||||||
|
if i == 0 {
|
||||||
|
delete(r.ARecords, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) cleanupCNameRecords(now time.Time) {
|
||||||
|
for name, record := range r.CNameRecords {
|
||||||
|
if now.After(record.Deadline) {
|
||||||
|
delete(r.CNameRecords, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) getAliasedDomain(now time.Time, domainName string) string {
|
||||||
|
processedDomains := make(map[string]struct{})
|
||||||
|
for {
|
||||||
|
if _, processed := processedDomains[domainName]; processed {
|
||||||
|
// Loop detected!
|
||||||
|
return ""
|
||||||
|
} else {
|
||||||
|
processedDomains[domainName] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
cname, ok := r.CNameRecords[domainName]
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if now.After(cname.Deadline) {
|
||||||
|
delete(r.CNameRecords, domainName)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
domainName = cname.Alias
|
||||||
|
}
|
||||||
|
return domainName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) getActualARecords(now time.Time, domainName string) []*ARecord {
|
||||||
|
aRecords, ok := r.ARecords[domainName]
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for _, aRecord := range aRecords {
|
||||||
|
if now.After(aRecord.Deadline) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
aRecords[i] = aRecord
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
aRecords = aRecords[:i]
|
||||||
|
if i == 0 {
|
||||||
|
delete(r.ARecords, domainName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return aRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) getActualCNames(now time.Time, domainName string, fromEnd bool) []string {
|
||||||
|
processedDomains := make(map[string]struct{})
|
||||||
|
cNameList := make([]string, 0)
|
||||||
|
if fromEnd {
|
||||||
|
domainName = r.getAliasedDomain(now, domainName)
|
||||||
|
cNameList = append(cNameList, domainName)
|
||||||
|
}
|
||||||
|
r.cleanupCNameRecords(now)
|
||||||
|
for {
|
||||||
|
if _, processed := processedDomains[domainName]; processed {
|
||||||
|
// Loop detected!
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
processedDomains[domainName] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
found := false
|
||||||
|
for aliasFrom, aliasTo := range r.CNameRecords {
|
||||||
|
if aliasTo.Alias == domainName {
|
||||||
|
cNameList = append(cNameList, aliasFrom)
|
||||||
|
domainName = aliasFrom
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cNameList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) Cleanup() {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.cleanupARecords(now)
|
||||||
|
r.cleanupCNameRecords(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) GetCNameRecords(domainName string, fromEnd bool) []string {
|
||||||
|
r.mutex.RLock()
|
||||||
|
defer r.mutex.RUnlock()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
return r.getActualCNames(now, domainName, fromEnd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) GetARecords(domainName string) []*ARecord {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
return r.getActualARecords(now, r.getAliasedDomain(now, domainName))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) AddCNameRecord(domainName string, cName string, ttl time.Duration) {
|
||||||
|
if domainName == cName {
|
||||||
|
// Can't assing to yourself
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
delete(r.ARecords, domainName)
|
||||||
|
r.CNameRecords[domainName] = NewCNameRecord(cName, now.Add(ttl))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
delete(r.CNameRecords, domainName)
|
||||||
|
if _, ok := r.ARecords[domainName]; !ok {
|
||||||
|
r.ARecords[domainName] = make([]*ARecord, 0)
|
||||||
|
}
|
||||||
|
for _, aRecord := range r.ARecords[domainName] {
|
||||||
|
if bytes.Compare(aRecord.Address, addr) == 0 {
|
||||||
|
aRecord.Deadline = now.Add(ttl)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.ARecords[domainName] = append(r.ARecords[domainName], NewARecord(addr, now.Add(ttl)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Records) ListKnownDomains() []string {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
now := time.Now()
|
||||||
|
r.cleanupARecords(now)
|
||||||
|
r.cleanupCNameRecords(now)
|
||||||
|
|
||||||
|
domains := map[string]struct{}{}
|
||||||
|
for name, _ := range r.ARecords {
|
||||||
|
domains[name] = struct{}{}
|
||||||
|
}
|
||||||
|
for name, _ := range r.CNameRecords {
|
||||||
|
domains[name] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
domainsList := make([]string, len(domains))
|
||||||
|
i := 0
|
||||||
|
for name, _ := range domains {
|
||||||
|
domainsList[i] = name
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
return domainsList
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRecords() *Records {
|
||||||
|
return &Records{
|
||||||
|
ARecords: make(map[string][]*ARecord),
|
||||||
|
CNameRecords: make(map[string]*CNameRecord),
|
||||||
|
}
|
||||||
|
}
|
@ -1,167 +0,0 @@
|
|||||||
package records
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ARecord struct {
|
|
||||||
Address net.IP
|
|
||||||
Deadline time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type CNameRecord struct {
|
|
||||||
Alias string
|
|
||||||
Deadline time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
type Records struct {
|
|
||||||
mux sync.RWMutex
|
|
||||||
records map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) AddCNameRecord(domainName, alias string, ttl time.Duration) {
|
|
||||||
if domainName == alias {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mux.Lock()
|
|
||||||
r.records[domainName] = &CNameRecord{
|
|
||||||
Alias: alias,
|
|
||||||
Deadline: time.Now().Add(ttl),
|
|
||||||
}
|
|
||||||
r.mux.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) AddARecord(domainName string, addr net.IP, ttl time.Duration) {
|
|
||||||
r.mux.Lock()
|
|
||||||
defer r.mux.Unlock()
|
|
||||||
|
|
||||||
deadline := time.Now().Add(ttl)
|
|
||||||
|
|
||||||
aRecords, _ := r.records[domainName].([]*ARecord)
|
|
||||||
for _, aRecord := range aRecords {
|
|
||||||
if bytes.Compare(aRecord.Address, addr) != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
aRecord.Deadline = deadline
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r.records[domainName] = append(aRecords, &ARecord{
|
|
||||||
Address: addr,
|
|
||||||
Deadline: deadline,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) GetAliases(domainName string) []string {
|
|
||||||
r.mux.Lock()
|
|
||||||
defer r.mux.Unlock()
|
|
||||||
r.cleanupRecords()
|
|
||||||
|
|
||||||
domains := make(map[string]struct{})
|
|
||||||
domains[domainName] = struct{}{}
|
|
||||||
|
|
||||||
for {
|
|
||||||
var addedNew bool
|
|
||||||
for name, aRecord := range r.records {
|
|
||||||
if _, ok := domains[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
cname, ok := aRecord.(*CNameRecord)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, ok = domains[cname.Alias]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
domains[name] = struct{}{}
|
|
||||||
addedNew = true
|
|
||||||
}
|
|
||||||
if !addedNew {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
domainList := make([]string, len(domains))
|
|
||||||
idx := 0
|
|
||||||
for name, _ := range domains {
|
|
||||||
domainList[idx] = name
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
|
|
||||||
return domainList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) GetARecords(domainName string) []*ARecord {
|
|
||||||
r.mux.Lock()
|
|
||||||
defer r.mux.Unlock()
|
|
||||||
r.cleanupRecords()
|
|
||||||
|
|
||||||
loopDetect := make(map[string]struct{})
|
|
||||||
loopDetect[domainName] = struct{}{}
|
|
||||||
for {
|
|
||||||
switch v := r.records[domainName].(type) {
|
|
||||||
case *CNameRecord:
|
|
||||||
if _, ok := loopDetect[v.Alias]; ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
domainName = v.Alias
|
|
||||||
loopDetect[v.Alias] = struct{}{}
|
|
||||||
case []*ARecord:
|
|
||||||
return v
|
|
||||||
default:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) ListKnownDomains() []string {
|
|
||||||
r.mux.Lock()
|
|
||||||
defer r.mux.Unlock()
|
|
||||||
r.cleanupRecords()
|
|
||||||
|
|
||||||
domainsList := make([]string, len(r.records))
|
|
||||||
i := 0
|
|
||||||
for name, _ := range r.records {
|
|
||||||
domainsList[i] = name
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return domainsList
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *Records) cleanupRecords() {
|
|
||||||
now := time.Now()
|
|
||||||
for name, records := range r.records {
|
|
||||||
switch v := records.(type) {
|
|
||||||
case []*ARecord:
|
|
||||||
idx := 0
|
|
||||||
for _, aRecord := range v {
|
|
||||||
if now.After(aRecord.Deadline) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
v[idx] = aRecord
|
|
||||||
idx++
|
|
||||||
}
|
|
||||||
if idx == 0 {
|
|
||||||
delete(r.records, name)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
r.records[name] = v[:idx]
|
|
||||||
case *CNameRecord:
|
|
||||||
if !now.After(v.Deadline) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delete(r.records, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() *Records {
|
|
||||||
return &Records{
|
|
||||||
records: make(map[string]interface{}),
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,109 +0,0 @@
|
|||||||
package records
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"slices"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoop(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddCNameRecord("1", "2", time.Minute)
|
|
||||||
r.AddCNameRecord("2", "1", time.Minute)
|
|
||||||
if r.GetARecords("1") != nil {
|
|
||||||
t.Fatal("loop detected")
|
|
||||||
}
|
|
||||||
if r.GetARecords("2") != nil {
|
|
||||||
t.Fatal("loop detected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCName(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
|
|
||||||
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
|
||||||
records := r.GetARecords("gateway.example.com")
|
|
||||||
if records == nil {
|
|
||||||
t.Fatal("no records")
|
|
||||||
}
|
|
||||||
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
|
||||||
t.Fatal("cname mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestA(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddARecord("example.com", []byte{1, 2, 3, 4}, time.Minute)
|
|
||||||
records := r.GetARecords("example.com")
|
|
||||||
if records == nil {
|
|
||||||
t.Fatal("no records")
|
|
||||||
}
|
|
||||||
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
|
||||||
t.Fatal("cname mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeprecated(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddARecord("example.com", []byte{1, 2, 3, 4}, -time.Minute)
|
|
||||||
records := r.GetARecords("example.com")
|
|
||||||
if records != nil {
|
|
||||||
t.Fatal("deprecated records")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotExistedA(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
records := r.GetARecords("example.com")
|
|
||||||
if records != nil {
|
|
||||||
t.Fatal("not existed records")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotExistedCNameAlias(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
|
||||||
records := r.GetARecords("gateway.example.com")
|
|
||||||
if records != nil {
|
|
||||||
t.Fatal("not existed records")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReplacing(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddCNameRecord("gateway.example.com", "example.com", time.Minute)
|
|
||||||
r.AddARecord("gateway.example.com", []byte{1, 2, 3, 4}, time.Minute)
|
|
||||||
records := r.GetARecords("gateway.example.com")
|
|
||||||
if bytes.Compare(records[0].Address, []byte{1, 2, 3, 4}) != 0 {
|
|
||||||
t.Fatal("mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAliases(t *testing.T) {
|
|
||||||
r := New()
|
|
||||||
r.AddARecord("1", []byte{1, 2, 3, 4}, time.Minute)
|
|
||||||
r.AddCNameRecord("2", "1", time.Minute)
|
|
||||||
r.AddCNameRecord("3", "2", time.Minute)
|
|
||||||
r.AddCNameRecord("4", "2", time.Minute)
|
|
||||||
r.AddCNameRecord("5", "1", time.Minute)
|
|
||||||
aliases := r.GetAliases("1")
|
|
||||||
if aliases == nil {
|
|
||||||
t.Fatal("no aliases")
|
|
||||||
}
|
|
||||||
if !slices.Contains(aliases, "1") {
|
|
||||||
t.Fatal("no 1")
|
|
||||||
}
|
|
||||||
if !slices.Contains(aliases, "2") {
|
|
||||||
t.Fatal("no 2")
|
|
||||||
}
|
|
||||||
if !slices.Contains(aliases, "3") {
|
|
||||||
t.Fatal("no 3")
|
|
||||||
}
|
|
||||||
if !slices.Contains(aliases, "4") {
|
|
||||||
t.Fatal("no 4")
|
|
||||||
}
|
|
||||||
if !slices.Contains(aliases, "5") {
|
|
||||||
t.Fatal("no 5")
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user